From 82d9f9200e4ad8ba3bd055225e952609c8983f0f Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Fri, 4 Oct 2024 18:45:13 -0700 Subject: [PATCH] refactor: Pass in `informal_{stmt,proof}` directly --- experiments/minif2f/main.py | 14 +++++++++----- experiments/minif2f/model/gen_tactic.py | 3 ++- experiments/minif2f/model/llm_agent.py | 10 ++++++---- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/experiments/minif2f/main.py b/experiments/minif2f/main.py index a81ae2b..1ddf467 100755 --- a/experiments/minif2f/main.py +++ b/experiments/minif2f/main.py @@ -24,8 +24,8 @@ def read_test_data(use_valid: bool): def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]: command = entry["formal_statement"] print(command) - informal_stmt = entry["informal_stmt"] - informal_proof = entry["informal_proof"] + agent.informal_stmt = entry["informal_stmt"] + agent.informal_proof = entry["informal_proof"] goal_states = server.load_sorry(command) @@ -37,8 +37,6 @@ def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goa return agent.search( server=server, goal_state=goal_state, - informal_stmt=informal_stmt, - informal_proof=informal_proof, verbose=True, max_steps=max_steps, max_trials_per_goal=max_trials_per_goal @@ -87,7 +85,13 @@ def run_eval(args): use_llm=args.use_llm, feedback_turns=args.feedback_turns, ) - result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal) + result = try_test_data( + server, + agent, + datum, + max_steps=args.max_steps, + max_trials_per_goal=args.max_trials_per_goal, + ) print(colored(f"Result on {datum['id']}: {result}", "blue")) #server.gc() if result is None: diff --git a/experiments/minif2f/model/gen_tactic.py b/experiments/minif2f/model/gen_tactic.py index 8c115bc..d0ed476 100644 --- a/experiments/minif2f/model/gen_tactic.py +++ b/experiments/minif2f/model/gen_tactic.py @@ -97,7 +97,8 @@ def multi_turn_question(s, question_1, question_2): @sgl.function def select_tactic( s, server, state, goal_id, - informal_stmt: str = "", informal_proof: str = "", + informal_stmt: str = "", + informal_proof: str = "", feedback_turns: int = 5): s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.") diff --git a/experiments/minif2f/model/llm_agent.py b/experiments/minif2f/model/llm_agent.py index d662a8b..9c069a8 100644 --- a/experiments/minif2f/model/llm_agent.py +++ b/experiments/minif2f/model/llm_agent.py @@ -38,12 +38,14 @@ class LLMAgent(Agent): else: self.tactics = [] + self.informal_stmt = "" + self.informal_proof = "" + def next_tactic( self, state: GoalState, goal_id: int, - informal_stmt: str = "", - informal_proof: str = "") -> Optional[Tactic]: + ) -> Optional[Tactic]: key = (state.state_id, goal_id) i = self.goal_tactic_id_map[key] @@ -59,8 +61,8 @@ class LLMAgent(Agent): server=self.server, state=state, goal_id=goal_id, - informal_stmt=informal_stmt, - informal_proof=informal_proof, + informal_stmt=self.informal_stmt, + informal_proof=self.informal_proof, feedback_turns=self.feedback_turns) tactic, new_state = s.ret_value for m in s.messages():