import random from abc import abstractmethod import time from dataclasses import dataclass from typing import Optional, List from typing_extensions import Self import collections, unittest from math import log, sqrt from pantograph.server import Server, TacticFailure, ServerError from pantograph.expr import Expr, Tactic, GoalState @dataclass class SearchState: state: GoalState parent: Optional[int] parent_goal_id: Optional[int] priorities: list[float] children: Optional[List[Self]] = None tested_tactics: Optional[List[Tactic]] = None total_value: Optional[float] = None def __post_init__(self): assert len(self.priorities) == len(self.state.goals) self.solved = [False for _ in self.state.goals] self.trials = [0 for _ in self.state.goals] self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics self.children = [] if self.children is None else self.children self.visit_count = 1 self.exhausted = False self.subtree_exhausted = False @property def next_goal_id(self) -> int: goal_id, _ = max( ((i, prio) for i, prio in enumerate(self.priorities) if not self.solved[i]), key=lambda x: x[1]) return goal_id @property def is_root(self) -> bool: return self.parent is None @property def is_solved(self) -> bool: return all(self.solved) @dataclass(frozen=True) class SearchResult: n_goals_root: int duration: float success: bool steps: int class Agent: """ An agent interface for proof search """ tactic_feedback: Optional[str] = None @abstractmethod def next_tactic( self, state: GoalState, goal_id: int, ) -> Optional[Tactic]: """ Implement this function to generate the next tactic for a goal """ @abstractmethod def guidance(self, state: GoalState) -> list[float]: """ Return a list of priorities determining which goal should be searched first. This will not be called on states with one or zero goals. """ return [0.0 for _ in state.goals] @abstractmethod def reset(self): """ Called after search """ def search(self, server: Server, goal_state: GoalState, max_steps: int = 100, max_trials_per_goal: int = 5, verbose: bool = False) -> SearchResult: """ Executes proof search on this state """ assert server.is_automatic(), "Search must be run in automatic mode" n_goals_root = len(goal_state.goals) time_start = time.time() initial_state = SearchState( state=goal_state, parent=None, parent_goal_id=None, priorities=[0.0 for _ in goal_state.goals] ) search_stack = [initial_state] for i_step in range(max_steps): assert search_stack, "No states in search stack" if verbose: print(f"I={i_step}: len(S) = {len(search_stack)}") search_state = search_stack[-1] assert isinstance(search_state, SearchState) if search_state.is_solved: return SearchResult( n_goals_root=n_goals_root, duration=time.time() - time_start, success=True, steps=i_step, ) # Find the unsolved goal with the highest priority goal_id = search_state.next_goal_id if search_state.trials[goal_id] > max_trials_per_goal: # force halt the search tactic = None else: # Generate tactic for this goal tactic = self.next_tactic(search_state.state, goal_id) if verbose: print(f"Next tactic: {tactic}") if not tactic: # resets the feedback self.tactic_feedback = None # pop the current state and continue to the next search_stack.pop(-1) if not search_stack: if verbose: print("Search stack has been exhausted") self.reset() return SearchResult( n_goals_root=n_goals_root, duration=time.time() - time_start, success=False, steps=i_step, ) continue try: search_state.trials[goal_id] += 1 state = search_state.state if verbose: print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}") next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic) # Generate priorities for the next goal state priorities = [0.0 for _ in next_goal_state.goals] \ if len(next_goal_state.goals) <= 1 else \ self.guidance(next_goal_state) parent = len(search_stack) - 1 next_state = SearchState( state=next_goal_state, parent=parent, parent_goal_id=goal_id, priorities=priorities ) search_stack.append(next_state) except TacticFailure as t: if verbose: print(f"Tactic failed: {t}") self.tactic_feedback = str(t) # try the next tactic. this one failed except ServerError as e: raise RuntimeError(f"While executing tactic: {tactic}") from e if verbose: print("Search iteration limit exhausted") self.reset() return SearchResult( n_goals_root=n_goals_root, duration=time.time() - time_start, success=False, steps=max_steps, ) class MCTSAgent(Agent): """ An agent interface for proof search using monte carlo tree search """ @abstractmethod def next_tactic( self, state: GoalState, goal_id: int, tested: Optional[List[Tactic]] = None, ) -> Optional[Tactic]: """ Implement this function to generate the next tactic for a goal given tactics already tested """ @abstractmethod def reset(self): """ Called after search """ @abstractmethod def estimate(self, state: SearchState) -> SearchState: """ Implement this function to estimate the value of a state """ @abstractmethod def select(self, state: SearchState) -> list[SearchState]: """ Implement this function to select the best node within the subtree of the state. Returns the path to the selected node from the given state. """ def backup(self, states: list[SearchState], value: float): """ Backup value of the state at the end of the states list. """ for state in states: state.total_value += value state.visit_count += 1 state.subtree_exhausted = all(child.subtree_exhausted for child in state.children) and state.exhausted def search(self, server: Server, goal_state: GoalState, max_steps: int = 100, max_trials_per_goal: int = 5, verbose: bool = False) -> SearchResult: """ Executes proof search on this state """ assert server.is_automatic(), "Search must be run in automatic mode" n_goals_root = len(goal_state.goals) time_start = time.time() initial_state = SearchState( state=goal_state, parent=None, parent_goal_id=None, priorities=[0.0 for _ in goal_state.goals] ) initial_state = self.estimate(initial_state) search_root = initial_state for i_step in range(max_steps): search_trajectory = self.select(search_root) search_state = search_trajectory[-1] assert isinstance(search_state, SearchState) if search_state.is_solved: return SearchResult( n_goals_root=n_goals_root, duration=time.time() - time_start, success=True, steps=i_step, ) # Find the unsolved goal with the highest priority goal_id = search_state.next_goal_id if search_state.trials[goal_id] > max_trials_per_goal: # force halt the search tactic = None else: # Generate tactic for this goal tactic = self.next_tactic(search_state.state, goal_id, search_state.tested_tactics) if verbose: print(f"Next tactic: {tactic}") if not tactic: # resets the feedback self.tactic_feedback = None search_state.exhausted = True search_state.subtree_exhausted = all(child.subtree_exhausted for child in search_state.children) continue assert tactic not in search_state.tested_tactics, "Tactic already seen!" search_state.tested_tactics.append(tactic) try: search_state.trials[goal_id] += 1 state = search_state.state if verbose: print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}") next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic) # Generate priorities for the next goal state priorities = [0.0 for _ in next_goal_state.goals] \ if len(next_goal_state.goals) <= 1 else \ self.guidance(next_goal_state) parent = -1 next_state = SearchState( state=next_goal_state, parent=parent, parent_goal_id=goal_id, priorities=priorities ) next_state = self.estimate(next_state) search_state.children.append(next_state) self.backup(search_trajectory, next_state.total_value) except TacticFailure as t: if verbose: print(f"Tactic failed: {t}") self.tactic_feedback = str(t) # This should most likely be feedback per node # try the next tactic. this one failed except ServerError as e: raise RuntimeError(f"While executing tactic: {tactic}") from e if verbose: print("Search iteration limit exhausted") self.reset() return SearchResult( n_goals_root=n_goals_root, duration=time.time() - time_start, success=False, steps=max_steps, ) class DumbAgent(Agent): def __init__(self): super().__init__() self.goal_tactic_id_map = collections.defaultdict(lambda : 0) self.intros = [ "intro", ] self.tactics = [ "intro h", "cases h", "apply Or.inl", "apply Or.inr", ] self.no_space_tactics = [ "assumption", ] def next_tactic( self, state: GoalState, goal_id: int, ) -> Optional[Tactic]: key = (state.state_id, goal_id) i = self.goal_tactic_id_map[key] target = state.goals[goal_id].target if target.startswith('∀'): tactics = self.intros elif ' ' in target: tactics = self.tactics else: tactics = self.no_space_tactics if i >= len(tactics): return None self.goal_tactic_id_map[key] = i + 1 return tactics[i] class DumbMCTSAgent(MCTSAgent): def __init__(self): super().__init__() self.goal_tactic_id_map = collections.defaultdict(lambda : 0) self.intros = [ "intro", ] self.tactics = [ "intro h", "cases h", "apply Or.inl", "apply Or.inr", ] self.no_space_tactics = [ "assumption", ] self.c = 0.6 def estimate(self, state: SearchState) -> SearchState: state.total_value = random.random() return state def select(self, state: SearchState) -> list[SearchState]: """ UCB scoring with taking the current state as one option, i.e. one child """ state_trajectory = [state] current_state = state current_state_ucb = (state.total_value / state.visit_count) + self.c * sqrt((log(state.visit_count) / state.visit_count)) while current_state.children: avg_val = [child.total_value / child.visit_count for child in current_state.children] visit_portions = [sqrt(log(current_state.visit_count) / child.visit_count) for child in current_state.children] ucbs = [avg + self.c * visit for avg, visit in zip(avg_val, visit_portions, strict=True)] child_idcs = [idx for idx in range(len(current_state.children)) if not current_state.children[idx].subtree_exhausted] if not child_idcs: return state_trajectory child_idx = child_idcs[0] for i in child_idcs: if ucbs[i] > ucbs[child_idx]: child_idx = i if ucbs[child_idx] < current_state_ucb and not current_state.exhausted: return state_trajectory current_state_ucb = ucbs[child_idx] current_state = current_state.children[child_idx] state_trajectory.append(current_state) return state_trajectory def next_tactic( self, state: GoalState, goal_id: int, tested: Optional[List[Tactic]] = None ) -> Optional[Tactic]: key = (state.state_id, goal_id) i = self.goal_tactic_id_map[key] target = state.goals[goal_id].target if target.startswith('∀'): tactics = self.intros elif ' ' in target: tactics = self.tactics else: tactics = self.no_space_tactics if i >= len(tactics): return None self.goal_tactic_id_map[key] = i + 1 while tactics[i] in tested: i += 1 if i >= len(tactics): return None return tactics[i] class TestSearch(unittest.TestCase): def test_solve(self): server = Server() agent = DumbAgent() goal_state = server.goal_start("∀ (p q: Prop), p -> p") flag = agent.search( server=server, goal_state=goal_state, verbose=False) #flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True) self.assertTrue(flag) def test_solve_big(self): server = Server() agent = DumbAgent() goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p") flag = agent.search( server=server, goal_state=goal_state, verbose=False) self.assertTrue(flag) class TestMCTSSearch(unittest.TestCase): def test_solve(self): server = Server() agent = DumbMCTSAgent() goal_state = server.goal_start("∀ (p q: Prop), p -> p") flag = agent.search( server=server, goal_state=goal_state, verbose=False) #flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True) self.assertTrue(flag) def test_solve_big(self): server = Server() agent = DumbMCTSAgent() goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p") flag = agent.search( server=server, goal_state=goal_state, max_steps=200, verbose=True) self.assertTrue(flag) if __name__ == '__main__': unittest.main()