feat: Add linage info to tree search

This commit is contained in:
Leni Aniva 2024-10-28 10:34:09 -07:00
parent 93ecd0d5ad
commit de93309393
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
1 changed files with 13 additions and 13 deletions

View File

@ -1,7 +1,7 @@
from abc import abstractmethod
import time
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Self
import collections, unittest
from pantograph.server import Server, TacticFailure, ServerError
@ -11,15 +11,15 @@ from pantograph.expr import Expr, Tactic, GoalState
@dataclass
class SearchState:
state: GoalState
parent: Optional[int]
goal_state: GoalState
parent: Optional[Self]
parent_goal_id: Optional[int]
priorities: list[float]
def __post_init__(self):
assert len(self.priorities) == len(self.state.goals)
self.solved = [False for _ in self.state.goals]
self.trials = [0 for _ in self.state.goals]
assert len(self.priorities) == len(self.goal_state.goals)
self.solved = [False for _ in self.goal_state.goals]
self.trials = [0 for _ in self.goal_state.goals]
@property
def next_goal_id(self) -> int:
@ -89,7 +89,7 @@ class Agent:
time_start = time.time()
initial_state = SearchState(
state=goal_state,
goal_state,
parent=None,
parent_goal_id=None,
priorities=[0.0 for _ in goal_state.goals]
@ -120,7 +120,7 @@ class Agent:
tactic = None
else:
# Generate tactic for this goal
tactic = self.next_tactic(search_state.state, goal_id)
tactic = self.next_tactic(search_state.goal_state, goal_id)
if verbose:
print(f"Next tactic: {tactic}")
@ -143,18 +143,18 @@ class Agent:
try:
search_state.trials[goal_id] += 1
state = search_state.state
goal_state = search_state.goal_state
if verbose:
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}")
next_goal_state = server.goal_tactic(goal_state, goal_id, tactic)
# Generate priorities for the next goal state
priorities = [0.0 for _ in next_goal_state.goals] \
if len(next_goal_state.goals) <= 1 else \
self.guidance(next_goal_state)
parent = len(search_stack) - 1
next_state = SearchState(
state=next_goal_state,
parent=parent,
goal_state=next_goal_state,
parent=search_state,
parent_goal_id=goal_id,
priorities=priorities
)