Add mcts agent

This commit is contained in:
Simon 2024-10-21 17:36:20 +01:00 committed by GitHub
parent 70e2f2e83e
commit c6358056fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 252 additions and 2 deletions

View File

@ -1,9 +1,11 @@
import random
from abc import abstractmethod from abc import abstractmethod
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional, List
from typing_extensions import Self
import collections, unittest import collections, unittest
from math import log, sqrt
from pantograph.server import Server, TacticFailure, ServerError from pantograph.server import Server, TacticFailure, ServerError
from pantograph.expr import Expr, Tactic, GoalState from pantograph.expr import Expr, Tactic, GoalState
@ -15,11 +17,19 @@ class SearchState:
parent: Optional[int] parent: Optional[int]
parent_goal_id: Optional[int] parent_goal_id: Optional[int]
priorities: list[float] priorities: list[float]
children: Optional[List[Self]] = None
tested_tactics: Optional[List[Tactic]] = None
total_value: Optional[float] = None
def __post_init__(self): def __post_init__(self):
assert len(self.priorities) == len(self.state.goals) assert len(self.priorities) == len(self.state.goals)
self.solved = [False for _ in self.state.goals] self.solved = [False for _ in self.state.goals]
self.trials = [0 for _ in self.state.goals] self.trials = [0 for _ in self.state.goals]
self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics
self.children = [] if self.children is None else self.children
self.visit_count = 1
self.exhausted = False
self.subtree_exhausted = False
@property @property
def next_goal_id(self) -> int: def next_goal_id(self) -> int:
@ -180,6 +190,148 @@ class Agent:
) )
class MCTSAgent(Agent):
    """
    An agent interface for proof search using Monte Carlo tree search.

    Subclasses provide the tactic generator (`next_tactic`), the value
    estimate (`estimate`) and the tree policy (`select`). `search` drives the
    loop and `backup` propagates the estimated value along the selected path.
    """

    @abstractmethod
    def next_tactic(
        self,
        state: GoalState,
        goal_id: int,
        tested: Optional[List[Tactic]] = None,
    ) -> Optional[Tactic]:
        """
        Implement this function to generate the next tactic for a goal, given
        the tactics already tested on the node. Return `None` when no further
        tactic is available; `search` then marks the node as exhausted.
        """

    @abstractmethod
    def reset(self):
        """
        Called after search
        """

    @abstractmethod
    def estimate(self, state: SearchState) -> SearchState:
        """
        Implement this function to estimate the value of a state.

        `search` reads `total_value` from the returned state, so the
        implementation must set it.
        """

    @abstractmethod
    def select(self, state: SearchState) -> list[SearchState]:
        """
        Implement this function to select the best node within the subtree of the state.
        Returns the path to the selected node from the given state.
        """

    def backup(self, states: list[SearchState], value: float):
        """
        Back up `value` along `states`, the path returned by `select`.

        Every state on the path accumulates `value`, gets one more visit, and
        has its `subtree_exhausted` flag recomputed from its children.
        """
        for state in states:
            state.total_value += value
            state.visit_count += 1
            # A subtree is exhausted only if the node itself has no tactics
            # left and every child subtree is exhausted as well.
            state.subtree_exhausted = (
                all(child.subtree_exhausted for child in state.children)
                and state.exhausted
            )

    def search(self,
               server: Server,
               goal_state: GoalState,
               max_steps: int = 100,
               max_trials_per_goal: int = 5,
               verbose: bool = False) -> SearchResult:
        """
        Executes proof search on this state.

        Each iteration selects a node, asks `next_tactic` for an untested
        tactic on that node's next goal, applies it on the server and, on
        success, attaches the resulting state as a child and backs up its
        estimated value.

        Returns a successful `SearchResult` as soon as a selected node is
        solved. Otherwise returns an unsuccessful result once `max_steps`
        iterations have run or the whole tree is exhausted; `steps` is the
        number of iterations actually started.

        Raises `RuntimeError` if the server reports a `ServerError` while
        executing a tactic.
        """
        assert server.is_automatic(), "Search must be run in automatic mode"

        n_goals_root = len(goal_state.goals)
        time_start = time.time()

        search_root = self.estimate(SearchState(
            state=goal_state,
            parent=None,
            parent_goal_id=None,
            priorities=[0.0 for _ in goal_state.goals],
        ))

        steps_taken = max_steps
        for i_step in range(max_steps):
            if search_root.subtree_exhausted:
                # Nothing is left to expand anywhere in the tree: `select`
                # could only hand back the root again, so further iterations
                # would just burn `select`/`next_tactic` calls.
                if verbose:
                    print("Search tree exhausted")
                steps_taken = i_step
                break

            search_trajectory = self.select(search_root)
            search_state = search_trajectory[-1]
            assert isinstance(search_state, SearchState)

            if search_state.is_solved:
                # NOTE(review): `reset` is not invoked on this path, only on
                # failure below — confirm whether that is intended.
                return SearchResult(
                    n_goals_root=n_goals_root,
                    duration=time.time() - time_start,
                    success=True,
                    steps=i_step,
                )

            # Find the unsolved goal with the highest priority
            goal_id = search_state.next_goal_id

            if search_state.trials[goal_id] > max_trials_per_goal:
                # Trial budget for this goal is spent: treat the node as
                # having no further tactics.
                tactic = None
            else:
                # Generate tactic for this goal
                tactic = self.next_tactic(
                    search_state.state, goal_id, search_state.tested_tactics)

            if verbose:
                print(f"Next tactic: {tactic}")
            if not tactic:
                # resets the feedback
                self.tactic_feedback = None
                search_state.exhausted = True
                search_state.subtree_exhausted = all(
                    child.subtree_exhausted for child in search_state.children)
                continue

            assert tactic not in search_state.tested_tactics, "Tactic already seen!"
            search_state.tested_tactics.append(tactic)

            try:
                search_state.trials[goal_id] += 1
                if verbose:
                    print(f"{search_state.state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
                next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
                # Generate priorities for the next goal state; guidance is
                # only consulted when there is more than one goal to rank.
                priorities = [0.0 for _ in next_goal_state.goals] \
                    if len(next_goal_state.goals) <= 1 else \
                    self.guidance(next_goal_state)
                next_state = self.estimate(SearchState(
                    state=next_goal_state,
                    # Placeholder index: the tree is linked through
                    # `children`, not through `parent`.
                    parent=-1,
                    parent_goal_id=goal_id,
                    priorities=priorities,
                ))
                search_state.children.append(next_state)
                self.backup(search_trajectory, next_state.total_value)
            except TacticFailure as t:
                if verbose:
                    print(f"Tactic failed: {t}")
                self.tactic_feedback = str(t)  # This should most likely be feedback per node
                # try the next tactic. this one failed
            except ServerError as e:
                raise RuntimeError(f"While executing tactic: {tactic}") from e
        else:
            if verbose:
                print("Search iteration limit exhausted")

        self.reset()
        return SearchResult(
            n_goals_root=n_goals_root,
            duration=time.time() - time_start,
            success=False,
            steps=steps_taken,
        )
class DumbAgent(Agent): class DumbAgent(Agent):
def __init__(self): def __init__(self):
@ -221,6 +373,79 @@ class DumbAgent(Agent):
self.goal_tactic_id_map[key] = i + 1 self.goal_tactic_id_map[key] = i + 1
return tactics[i] return tactics[i]
class DumbMCTSAgent(MCTSAgent):
    """
    A minimal MCTS agent used for testing: random value estimates, UCB
    selection and a fixed tactic list chosen from the shape of the goal.
    """

    def __init__(self):
        super().__init__()

        # Per (state id, goal id): index of the next tactic to try.
        self.goal_tactic_id_map = collections.defaultdict(lambda: 0)
        # Tactics for goals that start with a universal quantifier.
        self.intros = [
            "intro",
        ]
        # Tactics for compound goals (target contains a space).
        self.tactics = [
            "intro h",
            "cases h",
            "apply Or.inl",
            "apply Or.inr",
        ]
        # Tactics for atomic goals (target contains no space).
        self.no_space_tactics = [
            "assumption",
        ]
        # Exploration weight of the UCB score.
        self.c = 0.6

    def estimate(self, state: SearchState) -> SearchState:
        """
        Assign a random value from `random.random()` to the state and return it.
        """
        state.total_value = random.random()
        return state

    def select(self, state: SearchState) -> list[SearchState]:
        """
        UCB scoring with taking the current state as one option, i.e. one child.

        Walks down from `state`, at each node picking the non-exhausted child
        with the highest UCB score. The walk stops at the current node when it
        has no selectable child, or when it still has untested tactics and
        scores higher than its best child. Returns the path walked.
        """
        state_trajectory = [state]
        current_state = state
        current_state_ucb = (
            state.total_value / state.visit_count
            + self.c * sqrt(log(state.visit_count) / state.visit_count)
        )
        while current_state.children:
            children = current_state.children
            log_parent_visits = log(current_state.visit_count)
            ucbs = [
                child.total_value / child.visit_count
                + self.c * sqrt(log_parent_visits / child.visit_count)
                for child in children
            ]
            # Children whose whole subtree is exhausted cannot be expanded.
            child_idcs = [
                idx for idx, child in enumerate(children)
                if not child.subtree_exhausted
            ]
            if not child_idcs:
                return state_trajectory
            # `max` keeps the first index on ties.
            child_idx = max(child_idcs, key=ucbs.__getitem__)
            if ucbs[child_idx] < current_state_ucb and not current_state.exhausted:
                return state_trajectory
            current_state_ucb = ucbs[child_idx]
            current_state = children[child_idx]
            state_trajectory.append(current_state)
        return state_trajectory

    def next_tactic(
        self,
        state: GoalState,
        goal_id: int,
        tested: Optional[List[Tactic]] = None
    ) -> Optional[Tactic]:
        """
        Return the next untested tactic for the goal, or `None` when the
        tactic list for this kind of goal is used up.

        `tested` lists tactics already tried on the node; `None` is treated
        as "nothing tested yet".
        """
        key = (state.state_id, goal_id)
        i = self.goal_tactic_id_map[key]

        target = state.goals[goal_id].target
        # An empty prefix would match every target and make the other two
        # tactic lists unreachable; the intended prefix is the quantifier.
        if target.startswith('∀'):
            tactics = self.intros
        elif ' ' in target:
            tactics = self.tactics
        else:
            tactics = self.no_space_tactics

        if i >= len(tactics):
            return None
        self.goal_tactic_id_map[key] = i + 1

        already_tested = () if tested is None else tested
        for tactic in tactics[i:]:
            if tactic not in already_tested:
                return tactic
        return None
class TestSearch(unittest.TestCase): class TestSearch(unittest.TestCase):
@ -246,6 +471,31 @@ class TestSearch(unittest.TestCase):
verbose=False) verbose=False)
self.assertTrue(flag) self.assertTrue(flag)
class TestMCTSSearch(unittest.TestCase):
    """
    End-to-end checks that `DumbMCTSAgent` finds proofs for small propositions.
    """

    def test_solve(self):
        """A goal that needs only introductions and `assumption`."""
        server = Server()
        agent = DumbMCTSAgent()
        goal_state = server.goal_start("∀ (p q: Prop), p -> p")
        result = agent.search(
            server=server,
            goal_state=goal_state,
            verbose=False)
        # `search` returns a SearchResult object; assert on its outcome,
        # not on the truthiness of the object itself.
        self.assertTrue(result.success)

    def test_solve_big(self):
        """A goal that requires case analysis, given a larger step budget."""
        server = Server()
        agent = DumbMCTSAgent()
        goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
        result = agent.search(
            server=server,
            goal_state=goal_state,
            max_steps=200,
            verbose=True)
        self.assertTrue(result.success)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()