# Pantograph/pantograph/search.py
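"""
Proof search agents for Pantograph.

Defines the `SearchState` and `SearchResult` containers, a depth-first
`Agent` interface, a Monte Carlo tree search variant (`MCTSAgent`), and two
toy implementations (`DumbAgent`, `DumbMCTSAgent`) exercised by the unit
tests at the bottom of this file.
"""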

import random
from abc import abstractmethod
import time
from dataclasses import dataclass
from typing import Optional, Self, List
import collections, unittest
from math import log, sqrt
from pantograph.server import Server, TacticFailure, ServerError
from pantograph.expr import Expr, Tactic, GoalState
@dataclass
class SearchState:
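    """A node in the proof search tree: a goal state plus search bookkeeping."""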
goal_state: GoalState
parent: Optional[Self]
parent_goal_id: Optional[int]
priorities: list[float]
children: Optional[List[Self]] = None
tested_tactics: Optional[List[Tactic]] = None
total_value: Optional[float] = None
tactic_feedback: Optional[str] = None
def __post_init__(self):
assert len(self.priorities) == len(self.goal_state.goals)
self.solved = [False for _ in self.goal_state.goals]
self.trials = [0 for _ in self.goal_state.goals]
self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics
self.children = [] if self.children is None else self.children
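        # MCTS bookkeeping: `visit_count` counts backups through this node,
        # `exhausted` means the agent has no further tactics to propose here, and
        # `subtree_exhausted` means this node and all of its children are exhausted.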
self.visit_count = 1
self.exhausted = False
self.subtree_exhausted = False
@property
def next_goal_id(self) -> int:
goal_id, _ = max(
((i, prio) for i, prio in enumerate(self.priorities) if not self.solved[i]),
key=lambda x: x[1])
return goal_id
@property
def is_root(self) -> bool:
return self.parent is None
@property
def is_solved(self) -> bool:
return all(self.solved)
@dataclass(frozen=True)
class SearchResult:
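    """Summary of a completed proof search."""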
n_goals_root: int
duration: float
success: bool
steps: int
class Agent:
"""
An agent interface for proof search
"""
@abstractmethod
def next_tactic(
self,
state: GoalState,
goal_id: int,
) -> Optional[Tactic]:
"""
Implement this function to generate the next tactic for a goal
"""
@abstractmethod
def guidance(self, state: GoalState) -> list[float]:
"""
Return a list of priorities determining which goal should be searched
first. This will not be called on states with one or zero goals.
"""
return [0.0 for _ in state.goals]
@abstractmethod
def reset(self):
"""
Called after search
"""
def search(self,
server: Server,
goal_state: GoalState,
max_steps: int = 100,
max_trials_per_goal: int = 5,
verbose: bool = False) -> SearchResult:
"""
        Executes a depth-first proof search from this goal state, backtracking
        whenever a goal runs out of tactic candidates
"""
assert server.is_automatic(), "Search must be run in automatic mode"
n_goals_root = len(goal_state.goals)
time_start = time.time()
initial_state = SearchState(
goal_state,
parent=None,
parent_goal_id=None,
priorities=[0.0 for _ in goal_state.goals]
)
search_stack = [initial_state]
for i_step in range(max_steps):
assert search_stack, "No states in search stack"
if verbose:
print(f"I={i_step}: len(S) = {len(search_stack)}")
search_state = search_stack[-1]
assert isinstance(search_state, SearchState)
if search_state.is_solved:
return SearchResult(
n_goals_root=n_goals_root,
duration=time.time() - time_start,
success=True,
steps=i_step,
)
# Find the unsolved goal with the highest priority
goal_id = search_state.next_goal_id
if search_state.trials[goal_id] > max_trials_per_goal:
# force halt the search
tactic = None
else:
# Generate tactic for this goal
tactic = self.next_tactic(search_state.goal_state, goal_id)
if verbose:
print(f"Next tactic: {tactic}")
if not tactic:
# resets the feedback
search_state.tactic_feedback = None
# pop the current state and continue to the next
search_stack.pop(-1)
if not search_stack:
if verbose:
print("Search stack has been exhausted")
self.reset()
return SearchResult(
n_goals_root=n_goals_root,
duration=time.time() - time_start,
success=False,
steps=i_step,
)
continue
try:
search_state.trials[goal_id] += 1
goal_state = search_state.goal_state
if verbose:
print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}")
next_goal_state = server.goal_tactic(goal_state, goal_id, tactic)
2024-06-03 21:52:43 -07:00
# Generate priorities for the next goal state
priorities = [0.0 for _ in next_goal_state.goals] \
if len(next_goal_state.goals) <= 1 else \
self.guidance(next_goal_state)
next_state = SearchState(
goal_state=next_goal_state,
parent=search_state,
parent_goal_id=goal_id,
priorities=priorities
)
search_stack.append(next_state)
except TacticFailure as t:
if verbose:
print(f"Tactic failed: {t}")
search_state.tactic_feedback = str(t)
# try the next tactic. this one failed
except ServerError as e:
raise RuntimeError(f"While executing tactic: {tactic}") from e
if verbose:
print("Search iteration limit exhausted")
self.reset()
return SearchResult(
n_goals_root=n_goals_root,
duration=time.time() - time_start,
success=False,
steps=max_steps,
)
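
# A typical search run (a minimal sketch mirroring the tests below; `DumbAgent`
# is one concrete `Agent` defined further down in this file):
#
#     server = Server()
#     agent = DumbAgent()
#     goal_state = server.goal_start("∀ (p q: Prop), p -> p")
#     result = agent.search(server=server, goal_state=goal_state)
#     assert result.success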
class MCTSAgent(Agent):
"""
    An agent interface for proof search using Monte Carlo tree search (MCTS)
"""
@abstractmethod
def next_tactic(
self,
state: GoalState,
goal_id: int,
tested: Optional[List[Tactic]] = None,
) -> Optional[Tactic]:
"""
Implement this function to generate the next tactic for a goal given tactics already tested
"""
@abstractmethod
def reset(self):
"""
Called after search
"""
@abstractmethod
def estimate(self, state: SearchState) -> SearchState:
"""
        Implement this function to estimate the value of a state. It must set
        `total_value` on the state and return the state.
"""
@abstractmethod
def select(self, state: SearchState) -> list[SearchState]:
"""
Implement this function to select the best node within the subtree of the state.
Returns the path to the selected node from the given state.
"""
def backup(self, states: list[SearchState], value: float):
"""
        Propagate `value` along the selected trajectory `states`, updating
        accumulated values, visit counts, and exhaustion flags.
"""
for state in states:
state.total_value += value
state.visit_count += 1
state.subtree_exhausted = all(child.subtree_exhausted for child in state.children) and state.exhausted
def search(self,
server: Server,
goal_state: GoalState,
max_steps: int = 100,
max_trials_per_goal: int = 5,
verbose: bool = False) -> SearchResult:
"""
        Executes Monte Carlo tree search from this goal state
"""
assert server.is_automatic(), "Search must be run in automatic mode"
n_goals_root = len(goal_state.goals)
time_start = time.time()
initial_state = SearchState(
goal_state=goal_state,
parent=None,
parent_goal_id=None,
priorities=[0.0 for _ in goal_state.goals]
)
initial_state = self.estimate(initial_state)
search_root = initial_state
for i_step in range(max_steps):
search_trajectory = self.select(search_root)
search_state = search_trajectory[-1]
assert isinstance(search_state, SearchState)
if search_state.is_solved:
return SearchResult(
n_goals_root=n_goals_root,
duration=time.time() - time_start,
success=True,
steps=i_step,
)
# Find the unsolved goal with the highest priority
goal_id = search_state.next_goal_id
if search_state.trials[goal_id] > max_trials_per_goal:
# force halt the search
tactic = None
else:
# Generate tactic for this goal
tactic = self.next_tactic(search_state.goal_state, goal_id, search_state.tested_tactics)
if verbose:
print(f"Next tactic: {tactic}")
if not tactic:
# resets the feedback
search_state.tactic_feedback = None
                # No tactic left for this goal: mark the node exhausted and update its subtree flag
                search_state.exhausted = True
                search_state.subtree_exhausted = all(child.subtree_exhausted for child in search_state.children)
continue
assert tactic not in search_state.tested_tactics, "Tactic already seen!"
search_state.tested_tactics.append(tactic)
try:
search_state.trials[goal_id] += 1
state = search_state.goal_state
if verbose:
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.goal_state.goals[goal_id]}")
next_goal_state = server.goal_tactic(search_state.goal_state, goal_id, tactic)
# Generate priorities for the next goal state
priorities = [0.0 for _ in next_goal_state.goals] \
if len(next_goal_state.goals) <= 1 else \
self.guidance(next_goal_state)
                next_state = SearchState(
                    goal_state=next_goal_state,
                    parent=search_state,
                    parent_goal_id=goal_id,
                    priorities=priorities
                )
next_state = self.estimate(next_state)
search_state.children.append(next_state)
self.backup(search_trajectory, next_state.total_value)
except TacticFailure as t:
if verbose:
print(f"Tactic failed: {t}")
search_state.tactic_feedback = str(t)
# try the next tactic. this one failed
except ServerError as e:
raise RuntimeError(f"While executing tactic: {tactic}") from e
if verbose:
print("Search iteration limit exhausted")
self.reset()
return SearchResult(
n_goals_root=n_goals_root,
duration=time.time() - time_start,
success=False,
steps=max_steps,
)
class DumbAgent(Agent):
def __init__(self):
super().__init__()
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
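        # Hard-coded tactic pools: `intro` for ∀-goals, a few propositional
        # tactics for goals containing a space, and `assumption` otherwise.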
self.intros = [
"intro",
]
self.tactics = [
"intro h",
"cases h",
"apply Or.inl",
"apply Or.inr",
]
self.no_space_tactics = [
"assumption",
]
def next_tactic(
self,
state: GoalState,
goal_id: int,
) -> Optional[Tactic]:
key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key]
target = state.goals[goal_id].target
        if target.startswith('∀'):
tactics = self.intros
elif ' ' in target:
tactics = self.tactics
else:
tactics = self.no_space_tactics
if i >= len(tactics):
return None
self.goal_tactic_id_map[key] = i + 1
return tactics[i]
class DumbMCTSAgent(MCTSAgent):
def __init__(self):
super().__init__()
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
self.intros = [
"intro",
]
self.tactics = [
"intro h",
"cases h",
"apply Or.inl",
"apply Or.inr",
]
self.no_space_tactics = [
"assumption",
]
        self.c = 0.6  # exploration constant for UCB selection
def estimate(self, state: SearchState) -> SearchState:
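        """Assign a random value estimate; the dumb agent has no value model."""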
state.total_value = random.random()
return state
def select(self, state: SearchState) -> list[SearchState]:
"""
        UCB selection that also scores the current state itself as one option, i.e. as one more child
"""
state_trajectory = [state]
current_state = state
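        # UCB1 score used below: mean value (total_value / visit_count) plus
        # c * sqrt(log(parent visits) / child visits); the current state scores
        # itself with its own visit count on both sides.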
current_state_ucb = (state.total_value / state.visit_count) + self.c * sqrt((log(state.visit_count) / state.visit_count))
while current_state.children:
avg_val = [child.total_value / child.visit_count for child in current_state.children]
visit_portions = [sqrt(log(current_state.visit_count) / child.visit_count) for child in current_state.children]
ucbs = [avg + self.c * visit for avg, visit in zip(avg_val, visit_portions, strict=True)]
child_idcs = [idx for idx in range(len(current_state.children)) if not current_state.children[idx].subtree_exhausted]
if not child_idcs:
return state_trajectory
child_idx = child_idcs[0]
for i in child_idcs:
if ucbs[i] > ucbs[child_idx]:
child_idx = i
if ucbs[child_idx] < current_state_ucb and not current_state.exhausted:
return state_trajectory
current_state_ucb = ucbs[child_idx]
current_state = current_state.children[child_idx]
state_trajectory.append(current_state)
return state_trajectory
def next_tactic(
self,
state: GoalState,
goal_id: int,
tested: Optional[List[Tactic]] = None
) -> Optional[Tactic]:
        if tested is None:
            tested = []
        key = (state.state_id, goal_id)
        i = self.goal_tactic_id_map[key]
        target = state.goals[goal_id].target
        if target.startswith('∀'):
tactics = self.intros
elif ' ' in target:
tactics = self.tactics
else:
tactics = self.no_space_tactics
if i >= len(tactics):
return None
self.goal_tactic_id_map[key] = i + 1
while tactics[i] in tested:
i += 1
if i >= len(tactics):
return None
return tactics[i]
class TestSearch(unittest.TestCase):
def test_solve(self):
server = Server()
agent = DumbAgent()
goal_state = server.goal_start("∀ (p q: Prop), p -> p")
flag = agent.search(
server=server,
goal_state=goal_state,
verbose=False)
#flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
        self.assertTrue(flag.success)
def test_solve_big(self):
server = Server()
agent = DumbAgent()
goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
flag = agent.search(
server=server,
goal_state=goal_state,
verbose=False)
        self.assertTrue(flag.success)
class TestMCTSSearch(unittest.TestCase):
def test_solve(self):
server = Server()
agent = DumbMCTSAgent()
goal_state = server.goal_start("∀ (p q: Prop), p -> p")
flag = agent.search(
server=server,
goal_state=goal_state,
verbose=False)
#flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
        self.assertTrue(flag.success)
def test_solve_big(self):
server = Server()
agent = DumbMCTSAgent()
goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
flag = agent.search(
server=server,
goal_state=goal_state,
max_steps=200,
verbose=False)
        self.assertTrue(flag.success)
if __name__ == '__main__':
unittest.main()