From c6358056fece51eb30042573993ef6e39086bafc Mon Sep 17 00:00:00 2001 From: Simon <80467011+sorgfresser@users.noreply.github.com> Date: Mon, 21 Oct 2024 17:36:20 +0100 Subject: [PATCH 1/3] Add mcts agent --- pantograph/search.py | 254 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 252 insertions(+), 2 deletions(-) diff --git a/pantograph/search.py b/pantograph/search.py index b6cd63e..eba37ab 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -1,9 +1,11 @@ +import random from abc import abstractmethod import time from dataclasses import dataclass -from typing import Optional +from typing import Optional, List +from typing_extensions import Self import collections, unittest - +from math import log, sqrt from pantograph.server import Server, TacticFailure, ServerError from pantograph.expr import Expr, Tactic, GoalState @@ -15,11 +17,19 @@ class SearchState: parent: Optional[int] parent_goal_id: Optional[int] priorities: list[float] + children: Optional[List[Self]] = None + tested_tactics: Optional[List[Tactic]] = None + total_value: Optional[float] = None def __post_init__(self): assert len(self.priorities) == len(self.state.goals) self.solved = [False for _ in self.state.goals] self.trials = [0 for _ in self.state.goals] + self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics + self.children = [] if self.children is None else self.children + self.visit_count = 1 + self.exhausted = False + self.subtree_exhausted = False @property def next_goal_id(self) -> int: @@ -180,6 +190,148 @@ class Agent: ) +class MCTSAgent(Agent): + """ + An agent interface for proof search using monte carlo tree search + """ + + @abstractmethod + def next_tactic( + self, + state: GoalState, + goal_id: int, + tested: Optional[List[Tactic]] = None, + ) -> Optional[Tactic]: + """ + Implement this function to generate the next tactic for a goal given tactics already tested + """ + + @abstractmethod + def reset(self): + """ + Called after search + """ + + @abstractmethod + def estimate(self, state: SearchState) -> SearchState: + """ + Implement this function to estimate the value of a state + """ + + @abstractmethod + def select(self, state: SearchState) -> list[SearchState]: + """ + Implement this function to select the best node within the subtree of the state. + Returns the path to the selected node from the given state. + """ + + def backup(self, states: list[SearchState], value: float): + """ + Backup value of the state at the end of the states list. + """ + for state in states: + state.total_value += value + state.visit_count += 1 + state.subtree_exhausted = all(child.subtree_exhausted for child in state.children) and state.exhausted + + def search(self, + server: Server, + goal_state: GoalState, + max_steps: int = 100, + max_trials_per_goal: int = 5, + verbose: bool = False) -> SearchResult: + """ + Executes proof search on this state + """ + + assert server.is_automatic(), "Search must be run in automatic mode" + + n_goals_root = len(goal_state.goals) + time_start = time.time() + + initial_state = SearchState( + state=goal_state, + parent=None, + parent_goal_id=None, + priorities=[0.0 for _ in goal_state.goals] + ) + initial_state = self.estimate(initial_state) + search_root = initial_state + + for i_step in range(max_steps): + search_trajectory = self.select(search_root) + search_state = search_trajectory[-1] + assert isinstance(search_state, SearchState) + + if search_state.is_solved: + return SearchResult( + n_goals_root=n_goals_root, + duration=time.time() - time_start, + success=True, + steps=i_step, + ) + + # Find the unsolved goal with the highest priority + goal_id = search_state.next_goal_id + + if search_state.trials[goal_id] > max_trials_per_goal: + # force halt the search + tactic = None + else: + # Generate tactic for this goal + tactic = self.next_tactic(search_state.state, goal_id, search_state.tested_tactics) + + if verbose: + print(f"Next tactic: {tactic}") + if not tactic: + # resets the feedback + self.tactic_feedback = None + search_state.exhausted = True + search_state.subtree_exhausted = all(child.subtree_exhausted for child in search_state.children) + continue + assert tactic not in search_state.tested_tactics, "Tactic already seen!" + search_state.tested_tactics.append(tactic) + + try: + search_state.trials[goal_id] += 1 + state = search_state.state + if verbose: + print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}") + next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic) + # Generate priorities for the next goal state + priorities = [0.0 for _ in next_goal_state.goals] \ + if len(next_goal_state.goals) <= 1 else \ + self.guidance(next_goal_state) + parent = -1 + next_state = SearchState( + state=next_goal_state, + parent=parent, + parent_goal_id=goal_id, + priorities=priorities + ) + next_state = self.estimate(next_state) + search_state.children.append(next_state) + self.backup(search_trajectory, next_state.total_value) + except TacticFailure as t: + if verbose: + print(f"Tactic failed: {t}") + self.tactic_feedback = str(t) # This should most likely be feedback per node + # try the next tactic. this one failed + except ServerError as e: + raise RuntimeError(f"While executing tactic: {tactic}") from e + + if verbose: + print("Search iteration limit exhausted") + + self.reset() + return SearchResult( + n_goals_root=n_goals_root, + duration=time.time() - time_start, + success=False, + steps=max_steps, + ) + + class DumbAgent(Agent): def __init__(self): @@ -221,6 +373,79 @@ class DumbAgent(Agent): self.goal_tactic_id_map[key] = i + 1 return tactics[i] +class DumbMCTSAgent(MCTSAgent): + def __init__(self): + super().__init__() + + self.goal_tactic_id_map = collections.defaultdict(lambda : 0) + self.intros = [ + "intro", + ] + self.tactics = [ + "intro h", + "cases h", + "apply Or.inl", + "apply Or.inr", + ] + self.no_space_tactics = [ + "assumption", + ] + self.c = 0.6 + + def estimate(self, state: SearchState) -> SearchState: + state.total_value = random.random() + return state + + def select(self, state: SearchState) -> list[SearchState]: + """ + UCB scoring with taking the current state as one option, i.e. one child + """ + state_trajectory = [state] + current_state = state + current_state_ucb = (state.total_value / state.visit_count) + self.c * sqrt((log(state.visit_count) / state.visit_count)) + while current_state.children: + avg_val = [child.total_value / child.visit_count for child in current_state.children] + visit_portions = [sqrt(log(current_state.visit_count) / child.visit_count) for child in current_state.children] + ucbs = [avg + self.c * visit for avg, visit in zip(avg_val, visit_portions, strict=True)] + child_idcs = [idx for idx in range(len(current_state.children)) if not current_state.children[idx].subtree_exhausted] + if not child_idcs: + return state_trajectory + child_idx = child_idcs[0] + for i in child_idcs: + if ucbs[i] > ucbs[child_idx]: + child_idx = i + if ucbs[child_idx] < current_state_ucb and not current_state.exhausted: + return state_trajectory + current_state_ucb = ucbs[child_idx] + current_state = current_state.children[child_idx] + state_trajectory.append(current_state) + return state_trajectory + + def next_tactic( + self, + state: GoalState, + goal_id: int, + tested: Optional[List[Tactic]] = None + ) -> Optional[Tactic]: + key = (state.state_id, goal_id) + i = self.goal_tactic_id_map[key] + target = state.goals[goal_id].target + if target.startswith('∀'): + tactics = self.intros + elif ' ' in target: + tactics = self.tactics + else: + tactics = self.no_space_tactics + + if i >= len(tactics): + return None + self.goal_tactic_id_map[key] = i + 1 + while tactics[i] in tested: + i += 1 + if i >= len(tactics): + return None + return tactics[i] + class TestSearch(unittest.TestCase): @@ -246,6 +471,31 @@ class TestSearch(unittest.TestCase): verbose=False) self.assertTrue(flag) +class TestMCTSSearch(unittest.TestCase): + + def test_solve(self): + + server = Server() + agent = DumbMCTSAgent() + goal_state = server.goal_start("∀ (p q: Prop), p -> p") + flag = agent.search( + server=server, + goal_state=goal_state, + verbose=False) + #flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True) + self.assertTrue(flag) + def test_solve_big(self): + + server = Server() + agent = DumbMCTSAgent() + goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p") + flag = agent.search( + server=server, + goal_state=goal_state, + max_steps=200, + verbose=True) + self.assertTrue(flag) + if __name__ == '__main__': unittest.main() From 2e377d2a7ee4f22f6bdd9da7fa99dace9807419c Mon Sep 17 00:00:00 2001 From: Simon <80467011+sorgfresser@users.noreply.github.com> Date: Sat, 2 Nov 2024 10:35:57 +0000 Subject: [PATCH 2/3] feat: Tactic feedback per state --- pantograph/search.py | 51 ++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/pantograph/search.py b/pantograph/search.py index eba37ab..679223b 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -2,8 +2,7 @@ import random from abc import abstractmethod import time from dataclasses import dataclass -from typing import Optional, List -from typing_extensions import Self +from typing import Optional, Self, List import collections, unittest from math import log, sqrt from pantograph.server import Server, TacticFailure, ServerError @@ -12,25 +11,26 @@ from pantograph.expr import Expr, Tactic, GoalState @dataclass class SearchState: - - state: GoalState - parent: Optional[int] + goal_state: GoalState + parent: Optional[Self] parent_goal_id: Optional[int] priorities: list[float] children: Optional[List[Self]] = None tested_tactics: Optional[List[Tactic]] = None total_value: Optional[float] = None + tactic_feedback: Optional[str] = None def __post_init__(self): - assert len(self.priorities) == len(self.state.goals) - self.solved = [False for _ in self.state.goals] - self.trials = [0 for _ in self.state.goals] + assert len(self.priorities) == len(self.goal_state.goals) + self.solved = [False for _ in self.goal_state.goals] + self.trials = [0 for _ in self.goal_state.goals] self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics self.children = [] if self.children is None else self.children self.visit_count = 1 self.exhausted = False self.subtree_exhausted = False + @property def next_goal_id(self) -> int: goal_id, _ = max( @@ -58,7 +58,6 @@ class Agent: """ An agent interface for proof search """ - tactic_feedback: Optional[str] = None @abstractmethod def next_tactic( @@ -99,7 +98,7 @@ class Agent: time_start = time.time() initial_state = SearchState( - state=goal_state, + goal_state, parent=None, parent_goal_id=None, priorities=[0.0 for _ in goal_state.goals] @@ -130,13 +129,13 @@ class Agent: tactic = None else: # Generate tactic for this goal - tactic = self.next_tactic(search_state.state, goal_id) + tactic = self.next_tactic(search_state.goal_state, goal_id) if verbose: print(f"Next tactic: {tactic}") if not tactic: # resets the feedback - self.tactic_feedback = None + search_state.tactic_feedback = None # pop the current state and continue to the next search_stack.pop(-1) if not search_stack: @@ -153,18 +152,18 @@ class Agent: try: search_state.trials[goal_id] += 1 - state = search_state.state + goal_state = search_state.goal_state if verbose: - print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}") - next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic) + print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}") + next_goal_state = server.goal_tactic(goal_state, goal_id, tactic) # Generate priorities for the next goal state priorities = [0.0 for _ in next_goal_state.goals] \ if len(next_goal_state.goals) <= 1 else \ self.guidance(next_goal_state) parent = len(search_stack) - 1 next_state = SearchState( - state=next_goal_state, - parent=parent, + goal_state=next_goal_state, + parent=search_state, parent_goal_id=goal_id, priorities=priorities ) @@ -173,7 +172,7 @@ class Agent: except TacticFailure as t: if verbose: print(f"Tactic failed: {t}") - self.tactic_feedback = str(t) + search_state.tactic_feedback = str(t) # try the next tactic. this one failed except ServerError as e: raise RuntimeError(f"While executing tactic: {tactic}") from e @@ -250,7 +249,7 @@ class MCTSAgent(Agent): time_start = time.time() initial_state = SearchState( - state=goal_state, + goal_state=goal_state, parent=None, parent_goal_id=None, priorities=[0.0 for _ in goal_state.goals] @@ -279,13 +278,13 @@ class MCTSAgent(Agent): tactic = None else: # Generate tactic for this goal - tactic = self.next_tactic(search_state.state, goal_id, search_state.tested_tactics) + tactic = self.next_tactic(search_state.goal_state, goal_id, search_state.tested_tactics) if verbose: print(f"Next tactic: {tactic}") if not tactic: # resets the feedback - self.tactic_feedback = None + search_state.tactic_feedback = None search_state.exhausted = True search_state.subtree_exhausted = all(child.subtree_exhausted for child in search_state.children) continue @@ -294,17 +293,17 @@ class MCTSAgent(Agent): try: search_state.trials[goal_id] += 1 - state = search_state.state + state = search_state.goal_state if verbose: - print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}") - next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic) + print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.goal_state.goals[goal_id]}") + next_goal_state = server.goal_tactic(search_state.goal_state, goal_id, tactic) # Generate priorities for the next goal state priorities = [0.0 for _ in next_goal_state.goals] \ if len(next_goal_state.goals) <= 1 else \ self.guidance(next_goal_state) parent = -1 next_state = SearchState( - state=next_goal_state, + goal_state=next_goal_state, parent=parent, parent_goal_id=goal_id, priorities=priorities @@ -315,7 +314,7 @@ class MCTSAgent(Agent): except TacticFailure as t: if verbose: print(f"Tactic failed: {t}") - self.tactic_feedback = str(t) # This should most likely be feedback per node + search_state.tactic_feedback = str(t) # try the next tactic. this one failed except ServerError as e: raise RuntimeError(f"While executing tactic: {tactic}") from e From 0ba6927a1ef7e838a48f1b0d3269912e79930c4e Mon Sep 17 00:00:00 2001 From: Simon <80467011+sorgfresser@users.noreply.github.com> Date: Sat, 2 Nov 2024 19:17:36 +0000 Subject: [PATCH 3/3] fix: Set verbosity to False on MCTS tests --- pantograph/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pantograph/search.py b/pantograph/search.py index 75ef85c..a04ae32 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -492,7 +492,7 @@ class TestMCTSSearch(unittest.TestCase): server=server, goal_state=goal_state, max_steps=200, - verbose=True) + verbose=False) self.assertTrue(flag)