diff --git a/examples_search/miniF2F_search.py b/examples_search/miniF2F_search.py
index ee8a3f0..8afc453 100755
--- a/examples_search/miniF2F_search.py
+++ b/examples_search/miniF2F_search.py
@@ -3,7 +3,7 @@
 import subprocess, json, argparse
 from typing import Optional
 from pathlib import Path
-from pantograph.server import Server
+from pantograph.server import Server, ServerError
 from pantograph.search import SearchResult
 from pantograph.search_llm import LLMAgent
 
@@ -17,22 +17,47 @@ def read_test_data(use_valid: bool):
     with open(jsonl_path, 'r') as f:
         return [json.loads(l) for l in list(f)]
 
-def try_test_data(server, agent, entry: dict, max_steps: int) -> Optional[SearchResult]:
+def inplace_to_statement(expr: str) -> str:
+    bracket = 0
+    i = 0
+    while i < len(expr):
+        if expr[i] == ':' and bracket == 0:
+            break
+        elif expr[i] == '(':
+            bracket += 1
+        elif expr[i] == ')':
+            bracket -= 1
+        i += 1
+    if i == 0:
+        return expr[1:]
+    if i == len(expr):
+        return expr
+
+    return 'forall ' + expr[:i] + ' , ' + expr[i+1:]
+
+
+def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
     e = entry["formal_statement"]
+    print(e)
     informal_stmt = entry["informal_stmt"]
     informal_proof = entry["informal_proof"]
     key_position = e.find('theorem')
-    if key_position == -1:
+    if key_position != 0:
         # Can't output anything for this one
         return None
     e = e[key_position:]
+    # remove the tail := sorry
+    e, tail = e.rsplit(':=', 1)
+    # remove the head
     key_theorem, name, e = e.split(' ', 2)
-    e, tail = e.split(':=', 1)
-    target = "forall " + ','.join(e.rsplit(':', 1))
+    target = inplace_to_statement(e.strip())
     print(f"Target: {target}")
-    agent = LLMAgent(server)
-    return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True, max_steps=max_steps)
+    try:
+        return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True,
+                            max_steps=max_steps, max_trials_per_goal=max_trials_per_goal)
+    except ServerError as e:
+        return None
 
 
 def output_file_name(datum, use_hammer: bool, use_llm: bool):
     name = datum["id"]
@@ -53,6 +78,7 @@ if __name__ == '__main__':
     parser.add_argument('--validation', action='store_true')
     parser.add_argument('--use-llm', action='store_true')
     parser.add_argument('-s', '--max-steps', default=200)
+    parser.add_argument('-t', '--max-trials-per-goal', default=4)
     args = parser.parse_args()
 
     project_path, lean_path = get_project_and_lean_path()
@@ -60,17 +86,20 @@ if __name__ == '__main__':
     print(f"$LEAN_PATH: {lean_path}")
 
     test_data = read_test_data(args.validation)
-    server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
-    agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
     for datum in test_data:
         file_name = output_file_name(datum, args.use_hammer, args.use_llm)
+        placeholder_file_name = file_name.with_suffix('.placeholder')
         if file_name.is_file():
             print(f"Skipping {datum['id']}")
             continue
-        result = try_test_data(server, agent, datum, max_steps=args.max_steps)
+        server = Server(imports=["Example"], project_path=project_path, lean_path=lean_path, options=["maxHeartbeats=0"])
+        agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
+        result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal)
         if result is None:
-            with open(file_name + '-placeholder', 'w') as f:
+            with open(placeholder_file_name, 'w') as f:
                 json.dump({ 'id': datum['id'] }, f)
         else:
+            if placeholder_file_name.is_file():
+                placeholder_file_name.unlink()
             with open(file_name, 'w') as f:
                 json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)
diff --git a/pantograph/search.py b/pantograph/search.py
index fdca450..ee65468 100644
--- a/pantograph/search.py
+++ b/pantograph/search.py
@@ -63,7 +63,7 @@ class Agent:
                informal_stmt: str = "",
                informal_proof: str = "",
                max_steps: int = 100,
-               max_trial_per_goal: int = 5,
+               max_trials_per_goal: int = 5,
                verbose: bool = False) -> SearchResult:
 
         search_stack = [SearchState(state=server.goal_start(target),
@@ -99,7 +99,7 @@ class Agent:
             # Find the unsolved goal with the highest priority
             goal_id = search_state.next_goal_id
 
-            if search_state.trials[goal_id] > max_trial_per_goal:
+            if search_state.trials[goal_id] > max_trials_per_goal:
                 # force halt the search
                 tactic = None
             else:
@@ -118,6 +118,7 @@ class Agent:
                 continue
 
             try:
+                search_state.trials[goal_id] += 1
                 state = search_state.state
                 if verbose:
                     print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
diff --git a/pantograph/search_llm.py b/pantograph/search_llm.py
index 6016cba..041ecac 100644
--- a/pantograph/search_llm.py
+++ b/pantograph/search_llm.py
@@ -13,7 +13,9 @@ class LLMAgent(Agent):
         super().__init__()
         self.n_trials = 5
         self.server = server
-        sgl.set_default_backend(sgl.OpenAI("gpt-4"))
+
+        if use_llm:
+            sgl.set_default_backend(sgl.OpenAI("gpt-4"))
 
         self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
 
@@ -37,6 +39,7 @@ class LLMAgent(Agent):
         if i >= len(self.tactics) and not self.use_llm:
             return None
         elif i >= len(self.tactics):
+            assert self.use_llm
             new_state = None
             for ii in range(self.n_trials):
                 print(f"===============trail {str(ii)}============")