diff --git a/pantograph/search.py b/pantograph/search.py index f6c98ef..2b25685 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -1,7 +1,7 @@ from abc import abstractmethod import time from dataclasses import dataclass -from typing import Optional +from typing import Optional, Self import collections, unittest from pantograph.server import Server, TacticFailure, ServerError @@ -11,16 +11,16 @@ from pantograph.expr import Expr, Tactic, GoalState @dataclass class SearchState: - state: GoalState - parent: Optional[int] + goal_state: GoalState + parent: Optional[Self] parent_goal_id: Optional[int] priorities: list[float] tactic_feedback: Optional[str] = None def __post_init__(self): - assert len(self.priorities) == len(self.state.goals) - self.solved = [False for _ in self.state.goals] - self.trials = [0 for _ in self.state.goals] + assert len(self.priorities) == len(self.goal_state.goals) + self.solved = [False for _ in self.goal_state.goals] + self.trials = [0 for _ in self.goal_state.goals] @property def next_goal_id(self) -> int: @@ -89,7 +89,7 @@ class Agent: time_start = time.time() initial_state = SearchState( - state=goal_state, + goal_state, parent=None, parent_goal_id=None, priorities=[0.0 for _ in goal_state.goals] @@ -120,7 +120,7 @@ class Agent: tactic = None else: # Generate tactic for this goal - tactic = self.next_tactic(search_state.state, goal_id) + tactic = self.next_tactic(search_state.goal_state, goal_id) if verbose: print(f"Next tactic: {tactic}") @@ -143,18 +143,18 @@ class Agent: try: search_state.trials[goal_id] += 1 - state = search_state.state + goal_state = search_state.goal_state if verbose: - print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}") - next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic) + print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}") + next_goal_state = server.goal_tactic(goal_state, goal_id, tactic) # Generate priorities for the next goal state priorities = [0.0 for _ in next_goal_state.goals] \ if len(next_goal_state.goals) <= 1 else \ self.guidance(next_goal_state) parent = len(search_stack) - 1 next_state = SearchState( - state=next_goal_state, - parent=parent, + goal_state=next_goal_state, + parent=search_state, parent_goal_id=goal_id, priorities=priorities )