feat: Tactic feedback per state
This commit is contained in:
parent
c6358056fe
commit
2e377d2a7e
|
@ -2,8 +2,7 @@ import random
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, List
|
from typing import Optional, Self, List
|
||||||
from typing_extensions import Self
|
|
||||||
import collections, unittest
|
import collections, unittest
|
||||||
from math import log, sqrt
|
from math import log, sqrt
|
||||||
from pantograph.server import Server, TacticFailure, ServerError
|
from pantograph.server import Server, TacticFailure, ServerError
|
||||||
|
@ -12,25 +11,26 @@ from pantograph.expr import Expr, Tactic, GoalState
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SearchState:
|
class SearchState:
|
||||||
|
goal_state: GoalState
|
||||||
state: GoalState
|
parent: Optional[Self]
|
||||||
parent: Optional[int]
|
|
||||||
parent_goal_id: Optional[int]
|
parent_goal_id: Optional[int]
|
||||||
priorities: list[float]
|
priorities: list[float]
|
||||||
children: Optional[List[Self]] = None
|
children: Optional[List[Self]] = None
|
||||||
tested_tactics: Optional[List[Tactic]] = None
|
tested_tactics: Optional[List[Tactic]] = None
|
||||||
total_value: Optional[float] = None
|
total_value: Optional[float] = None
|
||||||
|
tactic_feedback: Optional[str] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert len(self.priorities) == len(self.state.goals)
|
assert len(self.priorities) == len(self.goal_state.goals)
|
||||||
self.solved = [False for _ in self.state.goals]
|
self.solved = [False for _ in self.goal_state.goals]
|
||||||
self.trials = [0 for _ in self.state.goals]
|
self.trials = [0 for _ in self.goal_state.goals]
|
||||||
self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics
|
self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics
|
||||||
self.children = [] if self.children is None else self.children
|
self.children = [] if self.children is None else self.children
|
||||||
self.visit_count = 1
|
self.visit_count = 1
|
||||||
self.exhausted = False
|
self.exhausted = False
|
||||||
self.subtree_exhausted = False
|
self.subtree_exhausted = False
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def next_goal_id(self) -> int:
|
def next_goal_id(self) -> int:
|
||||||
goal_id, _ = max(
|
goal_id, _ = max(
|
||||||
|
@ -58,7 +58,6 @@ class Agent:
|
||||||
"""
|
"""
|
||||||
An agent interface for proof search
|
An agent interface for proof search
|
||||||
"""
|
"""
|
||||||
tactic_feedback: Optional[str] = None
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def next_tactic(
|
def next_tactic(
|
||||||
|
@ -99,7 +98,7 @@ class Agent:
|
||||||
time_start = time.time()
|
time_start = time.time()
|
||||||
|
|
||||||
initial_state = SearchState(
|
initial_state = SearchState(
|
||||||
state=goal_state,
|
goal_state,
|
||||||
parent=None,
|
parent=None,
|
||||||
parent_goal_id=None,
|
parent_goal_id=None,
|
||||||
priorities=[0.0 for _ in goal_state.goals]
|
priorities=[0.0 for _ in goal_state.goals]
|
||||||
|
@ -130,13 +129,13 @@ class Agent:
|
||||||
tactic = None
|
tactic = None
|
||||||
else:
|
else:
|
||||||
# Generate tactic for this goal
|
# Generate tactic for this goal
|
||||||
tactic = self.next_tactic(search_state.state, goal_id)
|
tactic = self.next_tactic(search_state.goal_state, goal_id)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Next tactic: {tactic}")
|
print(f"Next tactic: {tactic}")
|
||||||
if not tactic:
|
if not tactic:
|
||||||
# resets the feedback
|
# resets the feedback
|
||||||
self.tactic_feedback = None
|
search_state.tactic_feedback = None
|
||||||
# pop the current state and continue to the next
|
# pop the current state and continue to the next
|
||||||
search_stack.pop(-1)
|
search_stack.pop(-1)
|
||||||
if not search_stack:
|
if not search_stack:
|
||||||
|
@ -153,18 +152,18 @@ class Agent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
search_state.trials[goal_id] += 1
|
search_state.trials[goal_id] += 1
|
||||||
state = search_state.state
|
goal_state = search_state.goal_state
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
|
print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}")
|
||||||
next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
|
next_goal_state = server.goal_tactic(goal_state, goal_id, tactic)
|
||||||
# Generate priorities for the next goal state
|
# Generate priorities for the next goal state
|
||||||
priorities = [0.0 for _ in next_goal_state.goals] \
|
priorities = [0.0 for _ in next_goal_state.goals] \
|
||||||
if len(next_goal_state.goals) <= 1 else \
|
if len(next_goal_state.goals) <= 1 else \
|
||||||
self.guidance(next_goal_state)
|
self.guidance(next_goal_state)
|
||||||
parent = len(search_stack) - 1
|
parent = len(search_stack) - 1
|
||||||
next_state = SearchState(
|
next_state = SearchState(
|
||||||
state=next_goal_state,
|
goal_state=next_goal_state,
|
||||||
parent=parent,
|
parent=search_state,
|
||||||
parent_goal_id=goal_id,
|
parent_goal_id=goal_id,
|
||||||
priorities=priorities
|
priorities=priorities
|
||||||
)
|
)
|
||||||
|
@ -173,7 +172,7 @@ class Agent:
|
||||||
except TacticFailure as t:
|
except TacticFailure as t:
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Tactic failed: {t}")
|
print(f"Tactic failed: {t}")
|
||||||
self.tactic_feedback = str(t)
|
search_state.tactic_feedback = str(t)
|
||||||
# try the next tactic. this one failed
|
# try the next tactic. this one failed
|
||||||
except ServerError as e:
|
except ServerError as e:
|
||||||
raise RuntimeError(f"While executing tactic: {tactic}") from e
|
raise RuntimeError(f"While executing tactic: {tactic}") from e
|
||||||
|
@ -250,7 +249,7 @@ class MCTSAgent(Agent):
|
||||||
time_start = time.time()
|
time_start = time.time()
|
||||||
|
|
||||||
initial_state = SearchState(
|
initial_state = SearchState(
|
||||||
state=goal_state,
|
goal_state=goal_state,
|
||||||
parent=None,
|
parent=None,
|
||||||
parent_goal_id=None,
|
parent_goal_id=None,
|
||||||
priorities=[0.0 for _ in goal_state.goals]
|
priorities=[0.0 for _ in goal_state.goals]
|
||||||
|
@ -279,13 +278,13 @@ class MCTSAgent(Agent):
|
||||||
tactic = None
|
tactic = None
|
||||||
else:
|
else:
|
||||||
# Generate tactic for this goal
|
# Generate tactic for this goal
|
||||||
tactic = self.next_tactic(search_state.state, goal_id, search_state.tested_tactics)
|
tactic = self.next_tactic(search_state.goal_state, goal_id, search_state.tested_tactics)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Next tactic: {tactic}")
|
print(f"Next tactic: {tactic}")
|
||||||
if not tactic:
|
if not tactic:
|
||||||
# resets the feedback
|
# resets the feedback
|
||||||
self.tactic_feedback = None
|
search_state.tactic_feedback = None
|
||||||
search_state.exhausted = True
|
search_state.exhausted = True
|
||||||
search_state.subtree_exhausted = all(child.subtree_exhausted for child in search_state.children)
|
search_state.subtree_exhausted = all(child.subtree_exhausted for child in search_state.children)
|
||||||
continue
|
continue
|
||||||
|
@ -294,17 +293,17 @@ class MCTSAgent(Agent):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
search_state.trials[goal_id] += 1
|
search_state.trials[goal_id] += 1
|
||||||
state = search_state.state
|
state = search_state.goal_state
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
|
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.goal_state.goals[goal_id]}")
|
||||||
next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
|
next_goal_state = server.goal_tactic(search_state.goal_state, goal_id, tactic)
|
||||||
# Generate priorities for the next goal state
|
# Generate priorities for the next goal state
|
||||||
priorities = [0.0 for _ in next_goal_state.goals] \
|
priorities = [0.0 for _ in next_goal_state.goals] \
|
||||||
if len(next_goal_state.goals) <= 1 else \
|
if len(next_goal_state.goals) <= 1 else \
|
||||||
self.guidance(next_goal_state)
|
self.guidance(next_goal_state)
|
||||||
parent = -1
|
parent = -1
|
||||||
next_state = SearchState(
|
next_state = SearchState(
|
||||||
state=next_goal_state,
|
goal_state=next_goal_state,
|
||||||
parent=parent,
|
parent=parent,
|
||||||
parent_goal_id=goal_id,
|
parent_goal_id=goal_id,
|
||||||
priorities=priorities
|
priorities=priorities
|
||||||
|
@ -315,7 +314,7 @@ class MCTSAgent(Agent):
|
||||||
except TacticFailure as t:
|
except TacticFailure as t:
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Tactic failed: {t}")
|
print(f"Tactic failed: {t}")
|
||||||
self.tactic_feedback = str(t) # This should most likely be feedback per node
|
search_state.tactic_feedback = str(t)
|
||||||
# try the next tactic. this one failed
|
# try the next tactic. this one failed
|
||||||
except ServerError as e:
|
except ServerError as e:
|
||||||
raise RuntimeError(f"While executing tactic: {tactic}") from e
|
raise RuntimeError(f"While executing tactic: {tactic}") from e
|
||||||
|
|
Loading…
Reference in New Issue