feat: Tactic feedback per state
This commit is contained in:
parent
c6358056fe
commit
2e377d2a7e
|
@ -2,8 +2,7 @@ import random
|
|||
from abc import abstractmethod
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List
|
||||
from typing_extensions import Self
|
||||
from typing import Optional, Self, List
|
||||
import collections, unittest
|
||||
from math import log, sqrt
|
||||
from pantograph.server import Server, TacticFailure, ServerError
|
||||
|
@ -12,25 +11,26 @@ from pantograph.expr import Expr, Tactic, GoalState
|
|||
|
||||
@dataclass
|
||||
class SearchState:
|
||||
|
||||
state: GoalState
|
||||
parent: Optional[int]
|
||||
goal_state: GoalState
|
||||
parent: Optional[Self]
|
||||
parent_goal_id: Optional[int]
|
||||
priorities: list[float]
|
||||
children: Optional[List[Self]] = None
|
||||
tested_tactics: Optional[List[Tactic]] = None
|
||||
total_value: Optional[float] = None
|
||||
tactic_feedback: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert len(self.priorities) == len(self.state.goals)
|
||||
self.solved = [False for _ in self.state.goals]
|
||||
self.trials = [0 for _ in self.state.goals]
|
||||
assert len(self.priorities) == len(self.goal_state.goals)
|
||||
self.solved = [False for _ in self.goal_state.goals]
|
||||
self.trials = [0 for _ in self.goal_state.goals]
|
||||
self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics
|
||||
self.children = [] if self.children is None else self.children
|
||||
self.visit_count = 1
|
||||
self.exhausted = False
|
||||
self.subtree_exhausted = False
|
||||
|
||||
|
||||
@property
|
||||
def next_goal_id(self) -> int:
|
||||
goal_id, _ = max(
|
||||
|
@ -58,7 +58,6 @@ class Agent:
|
|||
"""
|
||||
An agent interface for proof search
|
||||
"""
|
||||
tactic_feedback: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def next_tactic(
|
||||
|
@ -99,7 +98,7 @@ class Agent:
|
|||
time_start = time.time()
|
||||
|
||||
initial_state = SearchState(
|
||||
state=goal_state,
|
||||
goal_state,
|
||||
parent=None,
|
||||
parent_goal_id=None,
|
||||
priorities=[0.0 for _ in goal_state.goals]
|
||||
|
@ -130,13 +129,13 @@ class Agent:
|
|||
tactic = None
|
||||
else:
|
||||
# Generate tactic for this goal
|
||||
tactic = self.next_tactic(search_state.state, goal_id)
|
||||
tactic = self.next_tactic(search_state.goal_state, goal_id)
|
||||
|
||||
if verbose:
|
||||
print(f"Next tactic: {tactic}")
|
||||
if not tactic:
|
||||
# resets the feedback
|
||||
self.tactic_feedback = None
|
||||
search_state.tactic_feedback = None
|
||||
# pop the current state and continue to the next
|
||||
search_stack.pop(-1)
|
||||
if not search_stack:
|
||||
|
@ -153,18 +152,18 @@ class Agent:
|
|||
|
||||
try:
|
||||
search_state.trials[goal_id] += 1
|
||||
state = search_state.state
|
||||
goal_state = search_state.goal_state
|
||||
if verbose:
|
||||
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
|
||||
next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
|
||||
print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}")
|
||||
next_goal_state = server.goal_tactic(goal_state, goal_id, tactic)
|
||||
# Generate priorities for the next goal state
|
||||
priorities = [0.0 for _ in next_goal_state.goals] \
|
||||
if len(next_goal_state.goals) <= 1 else \
|
||||
self.guidance(next_goal_state)
|
||||
parent = len(search_stack) - 1
|
||||
next_state = SearchState(
|
||||
state=next_goal_state,
|
||||
parent=parent,
|
||||
goal_state=next_goal_state,
|
||||
parent=search_state,
|
||||
parent_goal_id=goal_id,
|
||||
priorities=priorities
|
||||
)
|
||||
|
@ -173,7 +172,7 @@ class Agent:
|
|||
except TacticFailure as t:
|
||||
if verbose:
|
||||
print(f"Tactic failed: {t}")
|
||||
self.tactic_feedback = str(t)
|
||||
search_state.tactic_feedback = str(t)
|
||||
# try the next tactic. this one failed
|
||||
except ServerError as e:
|
||||
raise RuntimeError(f"While executing tactic: {tactic}") from e
|
||||
|
@ -250,7 +249,7 @@ class MCTSAgent(Agent):
|
|||
time_start = time.time()
|
||||
|
||||
initial_state = SearchState(
|
||||
state=goal_state,
|
||||
goal_state=goal_state,
|
||||
parent=None,
|
||||
parent_goal_id=None,
|
||||
priorities=[0.0 for _ in goal_state.goals]
|
||||
|
@ -279,13 +278,13 @@ class MCTSAgent(Agent):
|
|||
tactic = None
|
||||
else:
|
||||
# Generate tactic for this goal
|
||||
tactic = self.next_tactic(search_state.state, goal_id, search_state.tested_tactics)
|
||||
tactic = self.next_tactic(search_state.goal_state, goal_id, search_state.tested_tactics)
|
||||
|
||||
if verbose:
|
||||
print(f"Next tactic: {tactic}")
|
||||
if not tactic:
|
||||
# resets the feedback
|
||||
self.tactic_feedback = None
|
||||
search_state.tactic_feedback = None
|
||||
search_state.exhausted = True
|
||||
search_state.subtree_exhausted = all(child.subtree_exhausted for child in search_state.children)
|
||||
continue
|
||||
|
@ -294,17 +293,17 @@ class MCTSAgent(Agent):
|
|||
|
||||
try:
|
||||
search_state.trials[goal_id] += 1
|
||||
state = search_state.state
|
||||
state = search_state.goal_state
|
||||
if verbose:
|
||||
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
|
||||
next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
|
||||
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.goal_state.goals[goal_id]}")
|
||||
next_goal_state = server.goal_tactic(search_state.goal_state, goal_id, tactic)
|
||||
# Generate priorities for the next goal state
|
||||
priorities = [0.0 for _ in next_goal_state.goals] \
|
||||
if len(next_goal_state.goals) <= 1 else \
|
||||
self.guidance(next_goal_state)
|
||||
parent = -1
|
||||
next_state = SearchState(
|
||||
state=next_goal_state,
|
||||
goal_state=next_goal_state,
|
||||
parent=parent,
|
||||
parent_goal_id=goal_id,
|
||||
priorities=priorities
|
||||
|
@ -315,7 +314,7 @@ class MCTSAgent(Agent):
|
|||
except TacticFailure as t:
|
||||
if verbose:
|
||||
print(f"Tactic failed: {t}")
|
||||
self.tactic_feedback = str(t) # This should most likely be feedback per node
|
||||
search_state.tactic_feedback = str(t)
|
||||
# try the next tactic. this one failed
|
||||
except ServerError as e:
|
||||
raise RuntimeError(f"While executing tactic: {tactic}") from e
|
||||
|
|
Loading…
Reference in New Issue