From ce633fecdaefdde3b47ed08638dbcb98fab90956 Mon Sep 17 00:00:00 2001
From: Leni Aniva
Date: Wed, 5 Jun 2024 14:19:18 -0700
Subject: [PATCH] feat: Add ablation testing

---
 examples_search/miniF2F_search.py | 34 ++++++++++++++++++++++++++-----
 pantograph/search.py              | 13 ++++++++----
 pantograph/search_llm.py          | 25 +++++++++++++++--------
 3 files changed, 55 insertions(+), 17 deletions(-)

diff --git a/examples_search/miniF2F_search.py b/examples_search/miniF2F_search.py
index 8a1d907..d9e4b2a 100755
--- a/examples_search/miniF2F_search.py
+++ b/examples_search/miniF2F_search.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
-import subprocess, json
+import subprocess, json, argparse
 from pathlib import Path
 
 from pantograph.server import Server
+from pantograph.search import SearchResult
 from pantograph.search_llm import LLMAgent
 
 def get_project_and_lean_path():
@@ -15,7 +16,7 @@ def read_test_data():
     with open(jsonl_path, 'r') as f:
         return [json.loads(l) for l in list(f)]
 
-def try_test_data(server, agent, entry) -> bool:
+def try_test_data(server, agent, entry: dict, max_steps: int) -> SearchResult:
     e = entry["formal_statement"]
     informal_stmt = entry["informal_stmt"]
     informal_proof = entry["informal_proof"]
@@ -24,14 +25,37 @@
     target = "forall " + ','.join(e.rsplit(':', 1))
     print(f"Target: {target}")
     agent = LLMAgent(server)
-    return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True)
+    return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True, max_steps=max_steps)
+
+def output_file_name(datum, use_hammer: bool, use_llm: bool):
+    name = datum["id"]
+    folder = 'output'
+    if use_hammer:
+        folder += '-hammer'
+    if use_llm:
+        folder += '-llm'
+    folder = Path(__file__).parent / folder
+    folder.mkdir(exist_ok=True, parents=True)
+    return folder / f"{name}.json"
 
 
 if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        prog='MiniF2F Search',
+        description='Executes LLM on MiniF2F Search')
+    parser.add_argument('--use-hammer', action='store_true')
+    parser.add_argument('--use-llm', action='store_true')
+    parser.add_argument('-s', '--max-steps', default=1000)
+    args = parser.parse_args()
+
     project_path, lean_path = get_project_and_lean_path()
     print(f"$PWD: {project_path}")
     print(f"$LEAN_PATH: {lean_path}")
     test_data = read_test_data()
     server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
-    agent = LLMAgent(server)
-    try_test_data(server, agent, test_data[0])
+    agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
+    for datum in test_data[:1]:
+        result = try_test_data(server, agent, datum, max_steps=args.max_steps)
+        file_name = output_file_name(datum, args.use_hammer, args.use_llm)
+        with open(file_name, 'w') as f:
+            json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)
diff --git a/pantograph/search.py b/pantograph/search.py
index ffa6f5c..e61f1bd 100644
--- a/pantograph/search.py
+++ b/pantograph/search.py
@@ -26,6 +26,11 @@ class SearchState:
     def is_solved(self) -> bool:
         return all(self.solved)
 
+@dataclass(frozen=True)
+class SearchResult:
+
+    success: bool
+    steps: int
 
 class Agent:
 
@@ -51,7 +56,7 @@ class Agent:
                informal_stmt: str = "",
                informal_proof: str = "",
                max_steps: int = 1000,
-               verbose: bool = False) -> bool:
+               verbose: bool = False) -> SearchResult:
 
         search_stack = [SearchState(state=server.goal_start(target),
                                     parent=None,
@@ -76,7 +81,7 @@ class Agent:
                     if verbose:
                         print("Search complete: Root state solved")
                     self.reset()
-                    return True
+                    return SearchResult(success=True, steps=i_step)
 
                 search_stack.pop(-1)
                 assert not search_stack[search_state.parent].solved[search_state.parent_goal_id]
@@ -97,7 +102,7 @@ class Agent:
                     if verbose:
                         print("Tactic list has been exhausted")
                     self.reset()
-                    return False
+                    return SearchResult(success=False, steps=i_step)
                 continue
 
             try:
@@ -123,7 +128,7 @@ class Agent:
             print("Search iteration limit exhausted")
 
         self.reset()
-        return False
+        return SearchResult(success=False, steps=max_steps)
 
 
 class DumbAgent(Agent):
diff --git a/pantograph/search_llm.py b/pantograph/search_llm.py
index 9b50b27..6016cba 100644
--- a/pantograph/search_llm.py
+++ b/pantograph/search_llm.py
@@ -8,26 +8,35 @@ import sglang as sgl
 
 class LLMAgent(Agent):
 
-    def __init__(self, server):
+    def __init__(self, server,
+                 use_hammer=True, use_llm=True):
         super().__init__()
         self.n_trials = 5
         self.server = server
 
         sgl.set_default_backend(sgl.OpenAI("gpt-4"))
 
         self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
-        self.tactics = [
-            "aesop",
-            #"simp",
-            #"rfl",
-            #"decide",
-        ]
+
+        self.use_hammer = use_hammer
+        self.use_llm = use_llm
+        if use_hammer:
+            self.tactics = [
+                "aesop",
+                #"simp",
+                #"rfl",
+                #"decide",
+            ]
+        else:
+            self.tactics = []
 
     def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]:
         key = (state.state_id, goal_id)
         i = self.goal_tactic_id_map[key]
         target = state.goals[goal_id].target
-        if i >= len(self.tactics):
+        if i >= len(self.tactics) and not self.use_llm:
+            return None
+        elif i >= len(self.tactics):
             new_state = None
             for ii in range(self.n_trials):
                 print(f"===============trail {str(ii)}============")
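
Usage note (not part of the commit): a minimal sketch of how the ablation modes introduced by this patch could be driven directly from Python, mirroring what examples_search/miniF2F_search.py now does via --use-hammer and --use-llm. The project/lean paths, the target expression, and the step limit below are placeholders, and running it requires an sglang/OpenAI setup because LLMAgent configures a gpt-4 backend.

    # Illustrative ablation driver; paths and the target goal are placeholders.
    from pantograph.server import Server
    from pantograph.search_llm import LLMAgent

    server = Server(imports=["Mathlib"],
                    project_path="/path/to/lean/project",   # placeholder
                    lean_path="/path/to/lean/toolchain")    # placeholder

    ablations = [
        {"use_hammer": True,  "use_llm": False},  # hammer-only (aesop)
        {"use_hammer": False, "use_llm": True},   # LLM-only
        {"use_hammer": True,  "use_llm": True},   # combined
    ]
    for flags in ablations:
        agent = LLMAgent(server, **flags)
        # search() returns the SearchResult added in pantograph/search.py
        result = agent.search(server=server,
                              target="forall (p q: Prop), Or p q -> Or q p",  # placeholder goal
                              max_steps=200,
                              verbose=True)
        print(flags, result.success, result.steps)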