diff --git a/examples_search/miniF2F_search.py b/examples_search/miniF2F_search.py
index 0bde569..ee8a3f0 100755
--- a/examples_search/miniF2F_search.py
+++ b/examples_search/miniF2F_search.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 import subprocess, json, argparse
+from typing import Optional
 from pathlib import Path
 from pantograph.server import Server
 from pantograph.search import SearchResult
@@ -16,10 +17,16 @@ def read_test_data(use_valid: bool):
     with open(jsonl_path, 'r') as f:
         return [json.loads(l) for l in list(f)]
 
-def try_test_data(server, agent, entry: dict, max_steps: int) -> SearchResult:
+def try_test_data(server, agent, entry: dict, max_steps: int) -> Optional[SearchResult]:
     e = entry["formal_statement"]
     informal_stmt = entry["informal_stmt"]
     informal_proof = entry["informal_proof"]
+
+    key_position = e.find('theorem')
+    if key_position == -1:
+        # Can't output anything for this one
+        return None
+    e = e[key_position:]
     key_theorem, name, e = e.split(' ', 2)
     e, tail = e.split(':=', 1)
     target = "forall " + ','.join(e.rsplit(':', 1))
@@ -45,7 +52,7 @@ if __name__ == '__main__':
     parser.add_argument('--use-hammer', action='store_true')
     parser.add_argument('--validation', action='store_true')
     parser.add_argument('--use-llm', action='store_true')
-    parser.add_argument('-s', '--max-steps', default=1000)
+    parser.add_argument('-s', '--max-steps', type=int, default=200)
     args = parser.parse_args()
 
     project_path, lean_path = get_project_and_lean_path()
@@ -56,7 +63,15 @@ if __name__ == '__main__':
     server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
     agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
     for datum in test_data:
-        result = try_test_data(server, agent, datum, max_steps=args.max_steps)
         file_name = output_file_name(datum, args.use_hammer, args.use_llm)
-        with open(file_name, 'w') as f:
-            json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)
+        if file_name.is_file():
+            print(f"Skipping {datum['id']}")
+            continue
+        result = try_test_data(server, agent, datum, max_steps=args.max_steps)
+        if result is None:
+            # file_name is a Path (is_file() is called on it above); Path + str raises TypeError, so derive the sibling name
+            with open(file_name.with_name(file_name.name + '-placeholder'), 'w') as f:
+                json.dump({ 'id': datum['id'] }, f)
+        else:
+            with open(file_name, 'w') as f:
+                json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)
diff --git a/pantograph/search.py b/pantograph/search.py
index e61f1bd..fdca450 100644
--- a/pantograph/search.py
+++ b/pantograph/search.py
@@ -17,6 +17,13 @@ class SearchState:
     def __post_init__(self):
         assert len(self.priorities) == len(self.state.goals)
         self.solved = [False for _ in self.state.goals]
+        self.trials = [0 for _ in self.state.goals]
+
+    @property
+    def next_goal_id(self) -> int:
+        goal_id, _ = max([(i, prio) for i, prio in enumerate(self.priorities) if not self.solved[i]],
+                         key=lambda x:x[1])
+        return goal_id
 
     @property
     def is_root(self) -> bool:
@@ -55,7 +62,8 @@ class Agent:
                target: Expr,
                informal_stmt: str = "",
                informal_proof: str = "",
-               max_steps: int = 1000,
+               max_steps: int = 100,
+               max_trial_per_goal: int = 5,
                verbose: bool = False) -> SearchResult:
 
         search_stack = [SearchState(state=server.goal_start(target),
@@ -89,11 +97,15 @@ class Agent:
                 continue
 
             # Find the unsolved goal with the highest priority
-            goal_id, _ = max([(i, prio) for i, prio in enumerate(search_state.priorities) if not search_state.solved[i]],
-                             key=lambda x:x[1])
+            goal_id = search_state.next_goal_id
+
+            if search_state.trials[goal_id] > max_trial_per_goal:
+                # force halt the search
+                tactic = None
+            else:
+                # Generate tactic for this goal
+                tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
 
-            # Generate tactic for this goal
-            tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
             print("????next tactic: ", tactic)
             if not tactic:
                 # pop the current state and continue to the next
diff --git a/pantograph/server.py b/pantograph/server.py
index 109386e..2952fe8 100644
--- a/pantograph/server.py
+++ b/pantograph/server.py
@@ -98,6 +98,7 @@ class Server:
     def goal_start(self, expr: Expr) -> GoalState:
         result = self.run('goal.start', {"expr": str(expr)})
         if "error" in result:
+            print(f"Cannot start goal: {expr}")
             raise ServerError(result["desc"])
         return GoalState(state_id=result["stateId"], goals=[Goal.sentence(expr)], _sentinel=self.to_remove_goal_states)
 