Add mcts agent
This commit is contained in:
parent
70e2f2e83e
commit
c6358056fe
|
@ -1,9 +1,11 @@
|
||||||
|
import random
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
|
from typing_extensions import Self
|
||||||
import collections, unittest
|
import collections, unittest
|
||||||
|
from math import log, sqrt
|
||||||
from pantograph.server import Server, TacticFailure, ServerError
|
from pantograph.server import Server, TacticFailure, ServerError
|
||||||
from pantograph.expr import Expr, Tactic, GoalState
|
from pantograph.expr import Expr, Tactic, GoalState
|
||||||
|
|
||||||
|
@ -15,11 +17,19 @@ class SearchState:
|
||||||
parent: Optional[int]
|
parent: Optional[int]
|
||||||
parent_goal_id: Optional[int]
|
parent_goal_id: Optional[int]
|
||||||
priorities: list[float]
|
priorities: list[float]
|
||||||
|
children: Optional[List[Self]] = None
|
||||||
|
tested_tactics: Optional[List[Tactic]] = None
|
||||||
|
total_value: Optional[float] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert len(self.priorities) == len(self.state.goals)
|
assert len(self.priorities) == len(self.state.goals)
|
||||||
self.solved = [False for _ in self.state.goals]
|
self.solved = [False for _ in self.state.goals]
|
||||||
self.trials = [0 for _ in self.state.goals]
|
self.trials = [0 for _ in self.state.goals]
|
||||||
|
self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics
|
||||||
|
self.children = [] if self.children is None else self.children
|
||||||
|
self.visit_count = 1
|
||||||
|
self.exhausted = False
|
||||||
|
self.subtree_exhausted = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def next_goal_id(self) -> int:
|
def next_goal_id(self) -> int:
|
||||||
|
@ -180,6 +190,148 @@ class Agent:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MCTSAgent(Agent):
    """
    An agent interface for proof search using Monte Carlo tree search.

    Subclasses supply tactic generation (`next_tactic`), state value
    estimation (`estimate`) and node selection (`select`); this class
    provides the value backup rule and the main search loop.
    """

    @abstractmethod
    def next_tactic(
            self,
            state: GoalState,
            goal_id: int,
            tested: Optional[List[Tactic]] = None,
        ) -> Optional[Tactic]:
        """
        Implement this function to generate the next tactic for a goal
        given tactics already tested. Return None when no candidate
        tactic remains for this goal.
        """

    @abstractmethod
    def reset(self):
        """
        Called after search
        """

    @abstractmethod
    def estimate(self, state: SearchState) -> SearchState:
        """
        Implement this function to estimate the value of a state.
        Must populate `state.total_value` and return the state.
        """

    @abstractmethod
    def select(self, state: SearchState) -> list[SearchState]:
        """
        Implement this function to select the best node within the subtree of the state.
        Returns the path to the selected node from the given state.
        """

    def backup(self, states: list[SearchState], value: float):
        """
        Backup value of the state at the end of the states list.

        Adds `value` to every node on the trajectory, increments visit
        counts, and refreshes each node's `subtree_exhausted` flag.
        """
        for state in states:
            state.total_value += value
            state.visit_count += 1
            # A subtree is exhausted once the node itself has no untried
            # tactics left AND every child subtree is exhausted.
            state.subtree_exhausted = all(
                child.subtree_exhausted for child in state.children
            ) and state.exhausted

    def search(self,
               server: Server,
               goal_state: GoalState,
               max_steps: int = 100,
               max_trials_per_goal: int = 5,
               verbose: bool = False) -> SearchResult:
        """
        Executes proof search on this state.

        :param server: Pantograph server; must be in automatic mode.
        :param goal_state: root goal state to prove.
        :param max_steps: maximum number of MCTS iterations.
        :param max_trials_per_goal: per-goal cap on tactic attempts before a
            node is force-halted (marked exhausted).
        :param verbose: print per-step progress.
        :return: a SearchResult describing success/failure and statistics.
        :raises RuntimeError: when the server fails while executing a tactic.
        """

        assert server.is_automatic(), "Search must be run in automatic mode"

        n_goals_root = len(goal_state.goals)
        time_start = time.time()

        initial_state = SearchState(
            state=goal_state,
            parent=None,
            parent_goal_id=None,
            priorities=[0.0 for _ in goal_state.goals]
        )
        initial_state = self.estimate(initial_state)
        search_root = initial_state

        for i_step in range(max_steps):
            # Walk down the tree to the most promising node.
            search_trajectory = self.select(search_root)
            search_state = search_trajectory[-1]
            assert isinstance(search_state, SearchState)

            if search_state.is_solved:
                # Fix: `reset` is documented as "Called after search" and is
                # invoked on the failure exit below; call it on the success
                # path too so agent state does not leak across searches.
                self.reset()
                return SearchResult(
                    n_goals_root=n_goals_root,
                    duration=time.time() - time_start,
                    success=True,
                    steps=i_step,
                )

            # Find the unsolved goal with the highest priority
            goal_id = search_state.next_goal_id

            if search_state.trials[goal_id] > max_trials_per_goal:
                # force halt the search
                tactic = None
            else:
                # Generate tactic for this goal
                tactic = self.next_tactic(search_state.state, goal_id, search_state.tested_tactics)

            if verbose:
                print(f"Next tactic: {tactic}")
            if not tactic:
                # resets the feedback
                self.tactic_feedback = None
                # No tactic left for this node: mark it exhausted and let the
                # next selection pass route around it.
                search_state.exhausted = True
                search_state.subtree_exhausted = all(child.subtree_exhausted for child in search_state.children)
                continue
            assert tactic not in search_state.tested_tactics, "Tactic already seen!"
            search_state.tested_tactics.append(tactic)

            try:
                search_state.trials[goal_id] += 1
                state = search_state.state
                if verbose:
                    print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
                next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
                # Generate priorities for the next goal state
                priorities = [0.0 for _ in next_goal_state.goals] \
                    if len(next_goal_state.goals) <= 1 else \
                    self.guidance(next_goal_state)
                # NOTE(review): placeholder parent index; the tree is linked
                # via `children`, so `parent` is unused here — confirm.
                parent = -1
                next_state = SearchState(
                    state=next_goal_state,
                    parent=parent,
                    parent_goal_id=goal_id,
                    priorities=priorities
                )
                next_state = self.estimate(next_state)
                search_state.children.append(next_state)
                # Propagate the new leaf's value back up the trajectory.
                self.backup(search_trajectory, next_state.total_value)
            except TacticFailure as t:
                if verbose:
                    print(f"Tactic failed: {t}")
                self.tactic_feedback = str(t) # This should most likely be feedback per node
                # try the next tactic. this one failed
            except ServerError as e:
                raise RuntimeError(f"While executing tactic: {tactic}") from e

        if verbose:
            print("Search iteration limit exhausted")

        self.reset()
        return SearchResult(
            n_goals_root=n_goals_root,
            duration=time.time() - time_start,
            success=False,
            steps=max_steps,
        )
|
||||||
|
|
||||||
|
|
||||||
class DumbAgent(Agent):
|
class DumbAgent(Agent):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -221,6 +373,79 @@ class DumbAgent(Agent):
|
||||||
self.goal_tactic_id_map[key] = i + 1
|
self.goal_tactic_id_map[key] = i + 1
|
||||||
return tactics[i]
|
return tactics[i]
|
||||||
|
|
||||||
|
class DumbMCTSAgent(MCTSAgent):
    """
    A naive MCTS agent: random value estimates, UCB1 selection, and
    tactic proposals drawn from fixed lists keyed on the goal's target
    syntax.

    NOTE(review): the abstract `reset` is not overridden here —
    presumably inherited from `Agent`; confirm.
    """

    def __init__(self):
        super().__init__()

        # Per-(state_id, goal_id) cursor into the candidate tactic list.
        self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
        self.intros = [
            "intro",
        ]
        self.tactics = [
            "intro h",
            "cases h",
            "apply Or.inl",
            "apply Or.inr",
        ]
        self.no_space_tactics = [
            "assumption",
        ]
        # UCB exploration constant.
        self.c = 0.6

    def estimate(self, state: SearchState) -> SearchState:
        """Assign a uniformly random value estimate to the state."""
        state.total_value = random.random()
        return state

    def select(self, state: SearchState) -> list[SearchState]:
        """
        UCB scoring with taking the current state as one option, i.e. one child
        """
        state_trajectory = [state]
        current_state = state
        # UCB of stopping at the current node itself.
        current_state_ucb = (state.total_value / state.visit_count) + self.c * sqrt((log(state.visit_count) / state.visit_count))
        while current_state.children:
            avg_val = [child.total_value / child.visit_count for child in current_state.children]
            visit_portions = [sqrt(log(current_state.visit_count) / child.visit_count) for child in current_state.children]
            ucbs = [avg + self.c * visit for avg, visit in zip(avg_val, visit_portions, strict=True)]
            # Only children whose subtree still has something to try.
            child_idcs = [idx for idx in range(len(current_state.children)) if not current_state.children[idx].subtree_exhausted]
            if not child_idcs:
                return state_trajectory
            # Arg-max of UCB over the non-exhausted children.
            child_idx = child_idcs[0]
            for i in child_idcs:
                if ucbs[i] > ucbs[child_idx]:
                    child_idx = i
            # Prefer expanding the current node when it still has untried
            # tactics and beats the best child's score.
            if ucbs[child_idx] < current_state_ucb and not current_state.exhausted:
                return state_trajectory
            current_state_ucb = ucbs[child_idx]
            current_state = current_state.children[child_idx]
            state_trajectory.append(current_state)
        return state_trajectory

    def next_tactic(
            self,
            state: GoalState,
            goal_id: int,
            tested: Optional[List[Tactic]] = None
        ) -> Optional[Tactic]:
        """
        Return the next untested candidate tactic for the goal, or None
        when the candidate list is exhausted.
        """
        # Fix: `tested` defaults to None; treat it as an empty list instead
        # of crashing with `TypeError: argument of type 'NoneType' ...`.
        tested = [] if tested is None else tested
        key = (state.state_id, goal_id)
        i = self.goal_tactic_id_map[key]
        target = state.goals[goal_id].target
        if target.startswith('∀'):
            tactics = self.intros
        elif ' ' in target:
            tactics = self.tactics
        else:
            tactics = self.no_space_tactics

        if i >= len(tactics):
            return None
        # Skip candidates already tested on this node.
        while tactics[i] in tested:
            i += 1
            if i >= len(tactics):
                return None
        # Fix: advance the cursor past the tactic actually returned.
        # Previously the cursor was written before the skip loop, so
        # skipped entries would be proposed again on the next call.
        self.goal_tactic_id_map[key] = i + 1
        return tactics[i]
|
||||||
|
|
||||||
|
|
||||||
class TestSearch(unittest.TestCase):
|
class TestSearch(unittest.TestCase):
|
||||||
|
|
||||||
|
@ -246,6 +471,31 @@ class TestSearch(unittest.TestCase):
|
||||||
verbose=False)
|
verbose=False)
|
||||||
self.assertTrue(flag)
|
self.assertTrue(flag)
|
||||||
|
|
||||||
|
class TestMCTSSearch(unittest.TestCase):
    """End-to-end tests of DumbMCTSAgent against a live Pantograph server."""

    def test_solve(self):
        """MCTS search should close a trivial implication goal."""
        server = Server()
        agent = DumbMCTSAgent()
        goal_state = server.goal_start("∀ (p q: Prop), p -> p")
        result = agent.search(
            server=server,
            goal_state=goal_state,
            verbose=False)
        # Fix: `search` returns a SearchResult dataclass, which is always
        # truthy, so `assertTrue(result)` could never fail. Assert on the
        # `success` field instead.
        self.assertTrue(result.success)

    def test_solve_big(self):
        """MCTS search should solve Or-commutativity within 200 steps."""
        server = Server()
        agent = DumbMCTSAgent()
        goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
        result = agent.search(
            server=server,
            goal_state=goal_state,
            max_steps=200,
            verbose=True)
        # Fix: assert on the dataclass field, not the always-truthy object.
        self.assertTrue(result.success)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue