Merge pull request #42 from lenianiva/feat/search
feat: Add linage info to tree search
This commit is contained in:
commit
92a50b00cc
|
@ -1,7 +1,7 @@
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional, Self
|
||||||
import collections, unittest
|
import collections, unittest
|
||||||
|
|
||||||
from pantograph.server import Server, TacticFailure, ServerError
|
from pantograph.server import Server, TacticFailure, ServerError
|
||||||
|
@ -11,16 +11,16 @@ from pantograph.expr import Expr, Tactic, GoalState
|
||||||
@dataclass
|
@dataclass
|
||||||
class SearchState:
|
class SearchState:
|
||||||
|
|
||||||
state: GoalState
|
goal_state: GoalState
|
||||||
parent: Optional[int]
|
parent: Optional[Self]
|
||||||
parent_goal_id: Optional[int]
|
parent_goal_id: Optional[int]
|
||||||
priorities: list[float]
|
priorities: list[float]
|
||||||
tactic_feedback: Optional[str] = None
|
tactic_feedback: Optional[str] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert len(self.priorities) == len(self.state.goals)
|
assert len(self.priorities) == len(self.goal_state.goals)
|
||||||
self.solved = [False for _ in self.state.goals]
|
self.solved = [False for _ in self.goal_state.goals]
|
||||||
self.trials = [0 for _ in self.state.goals]
|
self.trials = [0 for _ in self.goal_state.goals]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def next_goal_id(self) -> int:
|
def next_goal_id(self) -> int:
|
||||||
|
@ -89,7 +89,7 @@ class Agent:
|
||||||
time_start = time.time()
|
time_start = time.time()
|
||||||
|
|
||||||
initial_state = SearchState(
|
initial_state = SearchState(
|
||||||
state=goal_state,
|
goal_state,
|
||||||
parent=None,
|
parent=None,
|
||||||
parent_goal_id=None,
|
parent_goal_id=None,
|
||||||
priorities=[0.0 for _ in goal_state.goals]
|
priorities=[0.0 for _ in goal_state.goals]
|
||||||
|
@ -120,7 +120,7 @@ class Agent:
|
||||||
tactic = None
|
tactic = None
|
||||||
else:
|
else:
|
||||||
# Generate tactic for this goal
|
# Generate tactic for this goal
|
||||||
tactic = self.next_tactic(search_state.state, goal_id)
|
tactic = self.next_tactic(search_state.goal_state, goal_id)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Next tactic: {tactic}")
|
print(f"Next tactic: {tactic}")
|
||||||
|
@ -143,18 +143,18 @@ class Agent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
search_state.trials[goal_id] += 1
|
search_state.trials[goal_id] += 1
|
||||||
state = search_state.state
|
goal_state = search_state.goal_state
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
|
print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}")
|
||||||
next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
|
next_goal_state = server.goal_tactic(goal_state, goal_id, tactic)
|
||||||
# Generate priorities for the next goal state
|
# Generate priorities for the next goal state
|
||||||
priorities = [0.0 for _ in next_goal_state.goals] \
|
priorities = [0.0 for _ in next_goal_state.goals] \
|
||||||
if len(next_goal_state.goals) <= 1 else \
|
if len(next_goal_state.goals) <= 1 else \
|
||||||
self.guidance(next_goal_state)
|
self.guidance(next_goal_state)
|
||||||
parent = len(search_stack) - 1
|
parent = len(search_stack) - 1
|
||||||
next_state = SearchState(
|
next_state = SearchState(
|
||||||
state=next_goal_state,
|
goal_state=next_goal_state,
|
||||||
parent=parent,
|
parent=search_state,
|
||||||
parent_goal_id=goal_id,
|
parent_goal_id=goal_id,
|
||||||
priorities=priorities
|
priorities=priorities
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue