feat: Add linage info to tree search
This commit is contained in:
parent
93ecd0d5ad
commit
de93309393
|
@ -1,7 +1,7 @@
|
|||
from abc import abstractmethod
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Optional, Self
|
||||
import collections, unittest
|
||||
|
||||
from pantograph.server import Server, TacticFailure, ServerError
|
||||
|
@ -11,15 +11,15 @@ from pantograph.expr import Expr, Tactic, GoalState
|
|||
@dataclass
|
||||
class SearchState:
|
||||
|
||||
state: GoalState
|
||||
parent: Optional[int]
|
||||
goal_state: GoalState
|
||||
parent: Optional[Self]
|
||||
parent_goal_id: Optional[int]
|
||||
priorities: list[float]
|
||||
|
||||
def __post_init__(self):
|
||||
assert len(self.priorities) == len(self.state.goals)
|
||||
self.solved = [False for _ in self.state.goals]
|
||||
self.trials = [0 for _ in self.state.goals]
|
||||
assert len(self.priorities) == len(self.goal_state.goals)
|
||||
self.solved = [False for _ in self.goal_state.goals]
|
||||
self.trials = [0 for _ in self.goal_state.goals]
|
||||
|
||||
@property
|
||||
def next_goal_id(self) -> int:
|
||||
|
@ -89,7 +89,7 @@ class Agent:
|
|||
time_start = time.time()
|
||||
|
||||
initial_state = SearchState(
|
||||
state=goal_state,
|
||||
goal_state,
|
||||
parent=None,
|
||||
parent_goal_id=None,
|
||||
priorities=[0.0 for _ in goal_state.goals]
|
||||
|
@ -120,7 +120,7 @@ class Agent:
|
|||
tactic = None
|
||||
else:
|
||||
# Generate tactic for this goal
|
||||
tactic = self.next_tactic(search_state.state, goal_id)
|
||||
tactic = self.next_tactic(search_state.goal_state, goal_id)
|
||||
|
||||
if verbose:
|
||||
print(f"Next tactic: {tactic}")
|
||||
|
@ -143,18 +143,18 @@ class Agent:
|
|||
|
||||
try:
|
||||
search_state.trials[goal_id] += 1
|
||||
state = search_state.state
|
||||
goal_state = search_state.goal_state
|
||||
if verbose:
|
||||
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
|
||||
next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
|
||||
print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}")
|
||||
next_goal_state = server.goal_tactic(goal_state, goal_id, tactic)
|
||||
# Generate priorities for the next goal state
|
||||
priorities = [0.0 for _ in next_goal_state.goals] \
|
||||
if len(next_goal_state.goals) <= 1 else \
|
||||
self.guidance(next_goal_state)
|
||||
parent = len(search_stack) - 1
|
||||
next_state = SearchState(
|
||||
state=next_goal_state,
|
||||
parent=parent,
|
||||
goal_state=next_goal_state,
|
||||
parent=search_state,
|
||||
parent_goal_id=goal_id,
|
||||
priorities=priorities
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue