From 2e377d2a7ee4f22f6bdd9da7fa99dace9807419c Mon Sep 17 00:00:00 2001 From: Simon <80467011+sorgfresser@users.noreply.github.com> Date: Sat, 2 Nov 2024 10:35:57 +0000 Subject: [PATCH] feat: Tactic feedback per state --- pantograph/search.py | 51 ++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/pantograph/search.py b/pantograph/search.py index eba37ab..679223b 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -2,8 +2,7 @@ import random from abc import abstractmethod import time from dataclasses import dataclass -from typing import Optional, List -from typing_extensions import Self +from typing import Optional, Self, List import collections, unittest from math import log, sqrt from pantograph.server import Server, TacticFailure, ServerError @@ -12,25 +11,26 @@ from pantograph.expr import Expr, Tactic, GoalState @dataclass class SearchState: - - state: GoalState - parent: Optional[int] + goal_state: GoalState + parent: Optional[Self] parent_goal_id: Optional[int] priorities: list[float] children: Optional[List[Self]] = None tested_tactics: Optional[List[Tactic]] = None total_value: Optional[float] = None + tactic_feedback: Optional[str] = None def __post_init__(self): - assert len(self.priorities) == len(self.state.goals) - self.solved = [False for _ in self.state.goals] - self.trials = [0 for _ in self.state.goals] + assert len(self.priorities) == len(self.goal_state.goals) + self.solved = [False for _ in self.goal_state.goals] + self.trials = [0 for _ in self.goal_state.goals] self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics self.children = [] if self.children is None else self.children self.visit_count = 1 self.exhausted = False self.subtree_exhausted = False + @property def next_goal_id(self) -> int: goal_id, _ = max( @@ -58,7 +58,6 @@ class Agent: """ An agent interface for proof search """ - tactic_feedback: Optional[str] = None @abstractmethod def next_tactic( @@ -99,7 +98,7 @@ class Agent: time_start = time.time() initial_state = SearchState( - state=goal_state, + goal_state, parent=None, parent_goal_id=None, priorities=[0.0 for _ in goal_state.goals] @@ -130,13 +129,13 @@ class Agent: tactic = None else: # Generate tactic for this goal - tactic = self.next_tactic(search_state.state, goal_id) + tactic = self.next_tactic(search_state.goal_state, goal_id) if verbose: print(f"Next tactic: {tactic}") if not tactic: # resets the feedback - self.tactic_feedback = None + search_state.tactic_feedback = None # pop the current state and continue to the next search_stack.pop(-1) if not search_stack: @@ -153,18 +152,18 @@ class Agent: try: search_state.trials[goal_id] += 1 - state = search_state.state + goal_state = search_state.goal_state if verbose: - print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}") - next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic) + print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}") + next_goal_state = server.goal_tactic(goal_state, goal_id, tactic) # Generate priorities for the next goal state priorities = [0.0 for _ in next_goal_state.goals] \ if len(next_goal_state.goals) <= 1 else \ self.guidance(next_goal_state) parent = len(search_stack) - 1 next_state = SearchState( - state=next_goal_state, - parent=parent, + goal_state=next_goal_state, + parent=search_state, parent_goal_id=goal_id, priorities=priorities ) @@ -173,7 +172,7 @@ class Agent: except TacticFailure as t: if verbose: print(f"Tactic failed: {t}") - self.tactic_feedback = str(t) + search_state.tactic_feedback = str(t) # try the next tactic. this one failed except ServerError as e: raise RuntimeError(f"While executing tactic: {tactic}") from e @@ -250,7 +249,7 @@ class MCTSAgent(Agent): time_start = time.time() initial_state = SearchState( - state=goal_state, + goal_state=goal_state, parent=None, parent_goal_id=None, priorities=[0.0 for _ in goal_state.goals] @@ -279,13 +278,13 @@ class MCTSAgent(Agent): tactic = None else: # Generate tactic for this goal - tactic = self.next_tactic(search_state.state, goal_id, search_state.tested_tactics) + tactic = self.next_tactic(search_state.goal_state, goal_id, search_state.tested_tactics) if verbose: print(f"Next tactic: {tactic}") if not tactic: # resets the feedback - self.tactic_feedback = None + search_state.tactic_feedback = None search_state.exhausted = True search_state.subtree_exhausted = all(child.subtree_exhausted for child in search_state.children) continue @@ -294,17 +293,17 @@ class MCTSAgent(Agent): try: search_state.trials[goal_id] += 1 - state = search_state.state + state = search_state.goal_state if verbose: - print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}") - next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic) + print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.goal_state.goals[goal_id]}") + next_goal_state = server.goal_tactic(search_state.goal_state, goal_id, tactic) # Generate priorities for the next goal state priorities = [0.0 for _ in next_goal_state.goals] \ if len(next_goal_state.goals) <= 1 else \ self.guidance(next_goal_state) parent = -1 next_state = SearchState( - state=next_goal_state, + goal_state=next_goal_state, parent=parent, parent_goal_id=goal_id, priorities=priorities @@ -315,7 +314,7 @@ class MCTSAgent(Agent): except TacticFailure as t: if verbose: print(f"Tactic failed: {t}") - self.tactic_feedback = str(t) # This should most likely be feedback per node + search_state.tactic_feedback = str(t) # try the next tactic. this one failed except ServerError as e: raise RuntimeError(f"While executing tactic: {tactic}") from e