feat: Tactic feedback per state

This commit is contained in:
Simon 2024-11-02 10:35:57 +00:00 committed by GitHub
parent c6358056fe
commit 2e377d2a7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 25 additions and 26 deletions

View File

@ -2,8 +2,7 @@ import random
from abc import abstractmethod
import time
from dataclasses import dataclass
from typing import Optional, List
from typing_extensions import Self
from typing import Optional, Self, List
import collections, unittest
from math import log, sqrt
from pantograph.server import Server, TacticFailure, ServerError
@ -12,25 +11,26 @@ from pantograph.expr import Expr, Tactic, GoalState
@dataclass
class SearchState:
state: GoalState
parent: Optional[int]
goal_state: GoalState
parent: Optional[Self]
parent_goal_id: Optional[int]
priorities: list[float]
children: Optional[List[Self]] = None
tested_tactics: Optional[List[Tactic]] = None
total_value: Optional[float] = None
tactic_feedback: Optional[str] = None
def __post_init__(self):
assert len(self.priorities) == len(self.state.goals)
self.solved = [False for _ in self.state.goals]
self.trials = [0 for _ in self.state.goals]
assert len(self.priorities) == len(self.goal_state.goals)
self.solved = [False for _ in self.goal_state.goals]
self.trials = [0 for _ in self.goal_state.goals]
self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics
self.children = [] if self.children is None else self.children
self.visit_count = 1
self.exhausted = False
self.subtree_exhausted = False
@property
def next_goal_id(self) -> int:
goal_id, _ = max(
@ -58,7 +58,6 @@ class Agent:
"""
An agent interface for proof search
"""
tactic_feedback: Optional[str] = None
@abstractmethod
def next_tactic(
@ -99,7 +98,7 @@ class Agent:
time_start = time.time()
initial_state = SearchState(
state=goal_state,
goal_state,
parent=None,
parent_goal_id=None,
priorities=[0.0 for _ in goal_state.goals]
@ -130,13 +129,13 @@ class Agent:
tactic = None
else:
# Generate tactic for this goal
tactic = self.next_tactic(search_state.state, goal_id)
tactic = self.next_tactic(search_state.goal_state, goal_id)
if verbose:
print(f"Next tactic: {tactic}")
if not tactic:
# resets the feedback
self.tactic_feedback = None
search_state.tactic_feedback = None
# pop the current state and continue to the next
search_stack.pop(-1)
if not search_stack:
@ -153,18 +152,18 @@ class Agent:
try:
search_state.trials[goal_id] += 1
state = search_state.state
goal_state = search_state.goal_state
if verbose:
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}")
next_goal_state = server.goal_tactic(goal_state, goal_id, tactic)
# Generate priorities for the next goal state
priorities = [0.0 for _ in next_goal_state.goals] \
if len(next_goal_state.goals) <= 1 else \
self.guidance(next_goal_state)
parent = len(search_stack) - 1
next_state = SearchState(
state=next_goal_state,
parent=parent,
goal_state=next_goal_state,
parent=search_state,
parent_goal_id=goal_id,
priorities=priorities
)
@ -173,7 +172,7 @@ class Agent:
except TacticFailure as t:
if verbose:
print(f"Tactic failed: {t}")
self.tactic_feedback = str(t)
search_state.tactic_feedback = str(t)
# try the next tactic. this one failed
except ServerError as e:
raise RuntimeError(f"While executing tactic: {tactic}") from e
@ -250,7 +249,7 @@ class MCTSAgent(Agent):
time_start = time.time()
initial_state = SearchState(
state=goal_state,
goal_state=goal_state,
parent=None,
parent_goal_id=None,
priorities=[0.0 for _ in goal_state.goals]
@ -279,13 +278,13 @@ class MCTSAgent(Agent):
tactic = None
else:
# Generate tactic for this goal
tactic = self.next_tactic(search_state.state, goal_id, search_state.tested_tactics)
tactic = self.next_tactic(search_state.goal_state, goal_id, search_state.tested_tactics)
if verbose:
print(f"Next tactic: {tactic}")
if not tactic:
# resets the feedback
self.tactic_feedback = None
search_state.tactic_feedback = None
search_state.exhausted = True
search_state.subtree_exhausted = all(child.subtree_exhausted for child in search_state.children)
continue
@ -294,17 +293,17 @@ class MCTSAgent(Agent):
try:
search_state.trials[goal_id] += 1
state = search_state.state
state = search_state.goal_state
if verbose:
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.goal_state.goals[goal_id]}")
next_goal_state = server.goal_tactic(search_state.goal_state, goal_id, tactic)
# Generate priorities for the next goal state
priorities = [0.0 for _ in next_goal_state.goals] \
if len(next_goal_state.goals) <= 1 else \
self.guidance(next_goal_state)
parent = -1
next_state = SearchState(
state=next_goal_state,
goal_state=next_goal_state,
parent=parent,
parent_goal_id=goal_id,
priorities=priorities
@ -315,7 +314,7 @@ class MCTSAgent(Agent):
except TacticFailure as t:
if verbose:
print(f"Tactic failed: {t}")
self.tactic_feedback = str(t) # This should most likely be feedback per node
search_state.tactic_feedback = str(t)
# try the next tactic. this one failed
except ServerError as e:
raise RuntimeError(f"While executing tactic: {tactic}") from e