From 6d60651ed1001c42bb2c926be5247fd66bd0e95f Mon Sep 17 00:00:00 2001
From: Chuyue Sun
Date: Wed, 5 Jun 2024 11:39:08 -0700
Subject: [PATCH] add informal hints for search agent

---
 examples_search/miniF2F_search.py | 4 +++-
 pantograph/gen_tactic.py          | 7 +++++--
 pantograph/search.py              | 8 +++++---
 pantograph/search_llm.py          | 5 +++--
 4 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/examples_search/miniF2F_search.py b/examples_search/miniF2F_search.py
index 3c57019..8a1d907 100755
--- a/examples_search/miniF2F_search.py
+++ b/examples_search/miniF2F_search.py
@@ -17,12 +17,14 @@ def read_test_data():
 
 def try_test_data(server, agent, entry) -> bool:
     e = entry["formal_statement"]
+    informal_stmt = entry["informal_stmt"]
+    informal_proof = entry["informal_proof"]
     key_theorem, name, e = e.split(' ', 2)
     e, tail = e.split(':=', 1)
     target = "forall " + ','.join(e.rsplit(':', 1))
     print(f"Target: {target}")
     agent = LLMAgent(server)
-    return agent.search(server=server, target=target, verbose=True)
+    return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True)
 
 if __name__ == '__main__':
     project_path, lean_path = get_project_and_lean_path()
diff --git a/pantograph/gen_tactic.py b/pantograph/gen_tactic.py
index 8762340..a1c7ee8 100644
--- a/pantograph/gen_tactic.py
+++ b/pantograph/gen_tactic.py
@@ -88,7 +88,7 @@ def multi_turn_question(s, question_1, question_2):
 
 
 @sgl.function
-def select_tactic(s, server, state, goal_id, feedback_turns = 5):
+def select_tactic(s, server, state, goal_id,informal_stmt="", informal_proof="", feedback_turns = 5):
 
     s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
     s += sgl.user(LEAN4_REWRITE)
@@ -96,7 +96,10 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
     s += sgl.assistant("```intros a b h```")
     s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
     s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
-    s += sgl.user("The current proof state: " + str(state))
+    if informal_stmt and informal_proof:
+        s += sgl.user("informal theorem statement: "+ informal_stmt)
+        s += sgl.user("informal proof: " + informal_proof)
+    s += sgl.user("The current proof state: " + str(state) + "")
     for i in range(feedback_turns):
         with s.copy() as tmp:
             tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
diff --git a/pantograph/search.py b/pantograph/search.py
index c48c6e4..7f36673 100644
--- a/pantograph/search.py
+++ b/pantograph/search.py
@@ -29,7 +29,7 @@ class SearchState:
 
 class Agent:
 
-    def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
+    def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
         """
         Implement this function to generate the next tactic for a goal
         """
@@ -48,6 +48,8 @@ class Agent:
     def search(self,
                server: Server,
                target: Expr,
+               informal_stmt: str = "",
+               informal_proof: str = "",
                max_steps: int = 1000,
                verbose: bool = False) -> bool:
 
@@ -84,7 +86,7 @@ class Agent:
                 key=lambda x:x[1])
 
             # Generate tactic for this goal
-            tactic = self.next_tactic(search_state.state, goal_id)
+            tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
             if not tactic:
                 # pop the current state and continue to the next
                 search_stack.pop(-1)
@@ -140,7 +142,7 @@ class DumbAgent(Agent):
             "assumption",
         ]
 
-    def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
+    def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
         key = (state.state_id, goal_id)
         i = self.goal_tactic_id_map[key]
 
diff --git a/pantograph/search_llm.py b/pantograph/search_llm.py
index b4a6ce1..89ea810 100644
--- a/pantograph/search_llm.py
+++ b/pantograph/search_llm.py
@@ -21,6 +21,7 @@ class LLMAgent(Agent):
         self.tactics = [
             "intro h",
             "cases h",
+            "simp",
            "apply Or.inl",
            "apply Or.inr",
        ]
@@ -28,7 +29,7 @@
             "assumption",
         ]
 
-    def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
+    def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
         key = (state.state_id, goal_id)
         i = self.goal_tactic_id_map[key]
 
@@ -47,7 +48,7 @@
         new_state = None
         for ii in range(self.n_trials):
             print(f"===============trail {str(ii)}============")
-            s = select_tactic.run(server = self.server, state=state, goal_id = goal_id)
+            s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof)
             tactic, new_state = s.ret_value
             for m in s.messages():
                 print(m["role"], ":", m["content"])
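
With the patch applied, the informal hints thread from the dataset entry through Agent.search and next_tactic into the select_tactic prompt. Below is a minimal usage sketch, not part of the patch: it assumes pantograph.server.Server can be constructed with defaults (as in other Pantograph examples), and the target and informal strings are made-up placeholders rather than real miniF2F data.

    from pantograph.server import Server
    from pantograph.search_llm import LLMAgent

    server = Server()  # assumption: default construction suffices for a plain Prop goal
    agent = LLMAgent(server)

    # The informal hints are optional; when both are non-empty strings,
    # select_tactic adds them to the prompt ahead of the proof state.
    ok = agent.search(
        server=server,
        target="forall (p q: Prop), Or p q -> Or q p",  # placeholder goal
        informal_stmt="If p or q holds, then q or p holds.",  # placeholder hint
        informal_proof="Case on the disjunction and use the opposite Or introduction rule.",
        verbose=True,
    )
    print("Proof found:", ok)

Because the hints are passed as plain positional arguments to next_tactic, implementations that do not use them, such as DumbAgent's, keep working: they accept the extra arguments and simply ignore them.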