feat: Tactic feedback per state

Simon 2024-11-02 10:35:57 +00:00 committed by GitHub
parent c6358056fe
commit 2e377d2a7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 25 additions and 26 deletions


@@ -2,8 +2,7 @@ import random
 from abc import abstractmethod
 import time
 from dataclasses import dataclass
-from typing import Optional, List
-from typing_extensions import Self
+from typing import Optional, Self, List
 import collections, unittest
 from math import log, sqrt
 from pantograph.server import Server, TacticFailure, ServerError
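
Note on the import hunk above: typing.Self only exists on Python 3.11 and later, while typing_extensions.Self is the backport this change drops. If older interpreters still need to run this module, a guarded import along these lines would keep the annotation available (a sketch, not part of the commit):

import sys

# Sketch: prefer the standard-library Self on 3.11+, fall back to the backport.
if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self
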
@@ -12,25 +11,26 @@ from pantograph.expr import Expr, Tactic, GoalState

 @dataclass
 class SearchState:

-    state: GoalState
-    parent: Optional[int]
+    goal_state: GoalState
+    parent: Optional[Self]
     parent_goal_id: Optional[int]
     priorities: list[float]
     children: Optional[List[Self]] = None
     tested_tactics: Optional[List[Tactic]] = None
     total_value: Optional[float] = None
+    tactic_feedback: Optional[str] = None

     def __post_init__(self):
-        assert len(self.priorities) == len(self.state.goals)
-        self.solved = [False for _ in self.state.goals]
-        self.trials = [0 for _ in self.state.goals]
+        assert len(self.priorities) == len(self.goal_state.goals)
+        self.solved = [False for _ in self.goal_state.goals]
+        self.trials = [0 for _ in self.goal_state.goals]
         self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics
         self.children = [] if self.children is None else self.children
         self.visit_count = 1
         self.exhausted = False
         self.subtree_exhausted = False

     @property
     def next_goal_id(self) -> int:
         goal_id, _ = max(
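
The hunk above carries the main change: the Lean goal state now lives in goal_state, parent points at the parent SearchState itself rather than an index into the search stack, and each node stores its own tactic_feedback. A minimal sketch of how nodes are wired together under the new fields, assuming goal_state and next_goal_state are GoalState values obtained from the Pantograph server (for example via server.goal_start and server.goal_tactic):

# Sketch only: goal_state / next_goal_state are assumed to come from the server;
# the error string at the end is hypothetical.
root = SearchState(
    goal_state=goal_state,
    parent=None,
    parent_goal_id=None,
    priorities=[0.0 for _ in goal_state.goals],
)
child = SearchState(
    goal_state=next_goal_state,
    parent=root,            # direct reference to the parent node, not a stack index
    parent_goal_id=0,       # which goal of the parent the tactic was applied to
    priorities=[0.0 for _ in next_goal_state.goals],
)
root.children.append(child)
child.tactic_feedback = "unknown identifier 'foo'"  # feedback now lives on the node
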
@@ -58,7 +58,6 @@ class Agent:
     """
     An agent interface for proof search
     """
-    tactic_feedback: Optional[str] = None

     @abstractmethod
     def next_tactic(
@@ -99,7 +98,7 @@ class Agent:
         time_start = time.time()

         initial_state = SearchState(
-            state=goal_state,
+            goal_state,
             parent=None,
             parent_goal_id=None,
             priorities=[0.0 for _ in goal_state.goals]
@@ -130,13 +129,13 @@ class Agent:
                 tactic = None
             else:
                 # Generate tactic for this goal
-                tactic = self.next_tactic(search_state.state, goal_id)
+                tactic = self.next_tactic(search_state.goal_state, goal_id)

             if verbose:
                 print(f"Next tactic: {tactic}")
             if not tactic:
                 # resets the feedback
-                self.tactic_feedback = None
+                search_state.tactic_feedback = None
                 # pop the current state and continue to the next
                 search_stack.pop(-1)
                 if not search_stack:
@@ -153,18 +152,18 @@ class Agent:

             try:
                 search_state.trials[goal_id] += 1
-                state = search_state.state
+                goal_state = search_state.goal_state
                 if verbose:
-                    print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
-                next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
+                    print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}")
+                next_goal_state = server.goal_tactic(goal_state, goal_id, tactic)
                 # Generate priorities for the next goal state
                 priorities = [0.0 for _ in next_goal_state.goals] \
                     if len(next_goal_state.goals) <= 1 else \
                     self.guidance(next_goal_state)
                 parent = len(search_stack) - 1
                 next_state = SearchState(
-                    state=next_goal_state,
-                    parent=parent,
+                    goal_state=next_goal_state,
+                    parent=search_state,
                     parent_goal_id=goal_id,
                     priorities=priorities
                 )
@@ -173,7 +172,7 @@ class Agent:
             except TacticFailure as t:
                 if verbose:
                     print(f"Tactic failed: {t}")
-                self.tactic_feedback = str(t)
+                search_state.tactic_feedback = str(t)
                 # try the next tactic. this one failed
             except ServerError as e:
                 raise RuntimeError(f"While executing tactic: {tactic}") from e
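
With the handler above writing failure messages onto the node that produced them, and parent now being a node reference in Agent.search, the feedback accumulated along one line of attack can be gathered by walking up the tree. A sketch (the helper name is illustrative and assumes parent is always None or a SearchState, as in Agent.search above):

def feedback_along_path(node: SearchState) -> list[str]:
    """Collect tactic feedback from this node up to the root."""
    feedback = []
    current = node
    while current is not None:
        if current.tactic_feedback is not None:
            feedback.append(current.tactic_feedback)
        current = current.parent
    return feedback
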
@@ -250,7 +249,7 @@ class MCTSAgent(Agent):
         time_start = time.time()

         initial_state = SearchState(
-            state=goal_state,
+            goal_state=goal_state,
             parent=None,
             parent_goal_id=None,
             priorities=[0.0 for _ in goal_state.goals]
@@ -279,13 +278,13 @@ class MCTSAgent(Agent):
                 tactic = None
             else:
                 # Generate tactic for this goal
-                tactic = self.next_tactic(search_state.state, goal_id, search_state.tested_tactics)
+                tactic = self.next_tactic(search_state.goal_state, goal_id, search_state.tested_tactics)

             if verbose:
                 print(f"Next tactic: {tactic}")
             if not tactic:
                 # resets the feedback
-                self.tactic_feedback = None
+                search_state.tactic_feedback = None
                 search_state.exhausted = True
                 search_state.subtree_exhausted = all(child.subtree_exhausted for child in search_state.children)
                 continue
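
The branch above is how a node gets retired in the MCTS search: when next_tactic returns None, the node's own feedback is cleared, the node is marked exhausted, and subtree_exhausted becomes true only once every child subtree is exhausted as well. An illustrative agent (not part of this project, signature inferred from the call site above) that signals exhaustion by returning None once its fixed candidate pool has been tried:

class FixedPoolAgent(MCTSAgent):
    # Hypothetical candidate tactics, purely for illustration.
    # Other methods required by the Agent interface, e.g. guidance, are omitted.
    CANDIDATES = ["intro h", "simp", "omega", "exact?"]

    def next_tactic(self, goal_state, goal_id, tested_tactics):
        for tactic in self.CANDIDATES:
            if tactic not in tested_tactics:
                return tactic
        return None  # pool exhausted; the search loop then marks this node exhausted
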
@@ -294,17 +293,17 @@ class MCTSAgent(Agent):

             try:
                 search_state.trials[goal_id] += 1
-                state = search_state.state
+                state = search_state.goal_state
                 if verbose:
-                    print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
-                next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
+                    print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.goal_state.goals[goal_id]}")
+                next_goal_state = server.goal_tactic(search_state.goal_state, goal_id, tactic)
                 # Generate priorities for the next goal state
                 priorities = [0.0 for _ in next_goal_state.goals] \
                     if len(next_goal_state.goals) <= 1 else \
                     self.guidance(next_goal_state)
                 parent = -1
                 next_state = SearchState(
-                    state=next_goal_state,
+                    goal_state=next_goal_state,
                     parent=parent,
                     parent_goal_id=goal_id,
                     priorities=priorities
@@ -315,7 +314,7 @@ class MCTSAgent(Agent):
             except TacticFailure as t:
                 if verbose:
                     print(f"Tactic failed: {t}")
-                self.tactic_feedback = str(t) # This should most likely be feedback per node
+                search_state.tactic_feedback = str(t)
                 # try the next tactic. this one failed
             except ServerError as e:
                 raise RuntimeError(f"While executing tactic: {tactic}") from e
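
Taken together, the search loops now record tactic failures on the SearchState node where they happened instead of on the shared Agent instance. A hypothetical end-to-end run, assuming a running Lean server and the illustrative agent sketched earlier (the goal expression and keyword arguments are examples, not prescribed by this commit):

server = Server()
goal_state = server.goal_start("forall (p q: Prop), Or p q -> Or q p")
agent = FixedPoolAgent()
result = agent.search(server=server, goal_state=goal_state, verbose=True)
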