Merge pull request #42 from lenianiva/feat/search

feat: Add linage info to tree search
This commit is contained in:
Leni Aniva 2024-10-30 18:31:42 -07:00 committed by GitHub
commit 92a50b00cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 13 additions and 13 deletions

View File

@ -1,7 +1,7 @@
from abc import abstractmethod from abc import abstractmethod
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional, Self
import collections, unittest import collections, unittest
from pantograph.server import Server, TacticFailure, ServerError from pantograph.server import Server, TacticFailure, ServerError
@ -11,16 +11,16 @@ from pantograph.expr import Expr, Tactic, GoalState
@dataclass @dataclass
class SearchState: class SearchState:
state: GoalState goal_state: GoalState
parent: Optional[int] parent: Optional[Self]
parent_goal_id: Optional[int] parent_goal_id: Optional[int]
priorities: list[float] priorities: list[float]
tactic_feedback: Optional[str] = None tactic_feedback: Optional[str] = None
def __post_init__(self): def __post_init__(self):
assert len(self.priorities) == len(self.state.goals) assert len(self.priorities) == len(self.goal_state.goals)
self.solved = [False for _ in self.state.goals] self.solved = [False for _ in self.goal_state.goals]
self.trials = [0 for _ in self.state.goals] self.trials = [0 for _ in self.goal_state.goals]
@property @property
def next_goal_id(self) -> int: def next_goal_id(self) -> int:
@ -89,7 +89,7 @@ class Agent:
time_start = time.time() time_start = time.time()
initial_state = SearchState( initial_state = SearchState(
state=goal_state, goal_state,
parent=None, parent=None,
parent_goal_id=None, parent_goal_id=None,
priorities=[0.0 for _ in goal_state.goals] priorities=[0.0 for _ in goal_state.goals]
@ -120,7 +120,7 @@ class Agent:
tactic = None tactic = None
else: else:
# Generate tactic for this goal # Generate tactic for this goal
tactic = self.next_tactic(search_state.state, goal_id) tactic = self.next_tactic(search_state.goal_state, goal_id)
if verbose: if verbose:
print(f"Next tactic: {tactic}") print(f"Next tactic: {tactic}")
@ -143,18 +143,18 @@ class Agent:
try: try:
search_state.trials[goal_id] += 1 search_state.trials[goal_id] += 1
state = search_state.state goal_state = search_state.goal_state
if verbose: if verbose:
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}") print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}")
next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic) next_goal_state = server.goal_tactic(goal_state, goal_id, tactic)
# Generate priorities for the next goal state # Generate priorities for the next goal state
priorities = [0.0 for _ in next_goal_state.goals] \ priorities = [0.0 for _ in next_goal_state.goals] \
if len(next_goal_state.goals) <= 1 else \ if len(next_goal_state.goals) <= 1 else \
self.guidance(next_goal_state) self.guidance(next_goal_state)
parent = len(search_stack) - 1 parent = len(search_stack) - 1
next_state = SearchState( next_state = SearchState(
state=next_goal_state, goal_state=next_goal_state,
parent=parent, parent=search_state,
parent_goal_id=goal_id, parent_goal_id=goal_id,
priorities=priorities priorities=priorities
) )