diff --git a/experiments/minif2f/main.py b/experiments/minif2f/main.py index ea1fd50..296c661 100755 --- a/experiments/minif2f/main.py +++ b/experiments/minif2f/main.py @@ -3,9 +3,10 @@ import subprocess, json, argparse from typing import Optional from pathlib import Path -from pantograph.server import Server, ServerError +from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS from pantograph.search import SearchResult from model.llm_agent import LLMAgent +from model.options import CORE_OPTIONS PATH_EXPERIMENT = Path(__file__).parent.resolve() @@ -72,8 +73,18 @@ def run_eval(args): if file_name.is_file(): print(f"Skipping {datum['id']}") continue - server = Server(imports=["MiniF2F"], project_path=project_path, lean_path=lean_path) - agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm) + server = Server( + imports=["Mathlib", "Aesop"], + project_path=project_path, + lean_path=lean_path, + core_options=CORE_OPTIONS, + ) + agent = LLMAgent( + server, + use_hammer=args.use_hammer, + use_llm=args.use_llm, + feedback_turns=args.feedback_turns, + ) result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal) #server.gc() if result is None: @@ -87,8 +98,9 @@ def run_eval(args): if __name__ == '__main__': parser = argparse.ArgumentParser( - prog='MiniF2F Search', - description='Executes LLM on MiniF2F Search') + prog='MiniF2F Search', + description='Executes LLM on MiniF2F Search', + ) parser.add_argument('--use-hammer', action='store_true') parser.add_argument( '--dry-run', @@ -96,8 +108,9 @@ if __name__ == '__main__': help="List the data used, but don't run") parser.add_argument('--validation', action='store_true') parser.add_argument('--use-llm', action='store_true') - parser.add_argument('-s', '--max-steps', default=50) - parser.add_argument('-t', '--max-trials-per-goal', default=2) + parser.add_argument('--max-steps', default=50, type=int) +
parser.add_argument('--max-trials-per-goal', default=2, type=int) + parser.add_argument('--feedback-turns', default=2, type=int) args = parser.parse_args() if args.dry_run: diff --git a/experiments/minif2f/model/gen_tactic.py b/experiments/minif2f/model/gen_tactic.py index a1c7ee8..1689d28 100644 --- a/experiments/minif2f/model/gen_tactic.py +++ b/experiments/minif2f/model/gen_tactic.py @@ -1,7 +1,12 @@ +""" +Tactic generation functions for the LLM agent +""" from pantograph.server import Server, ServerError, TacticFailure from pantograph.expr import Variable, Goal, TacticCalc -import unittest import sglang as sgl +from termcolor import colored +import unittest +from .options import CORE_OPTIONS LEAN4_INTRO = '''/-- A sequence `u` of real numbers converges to `l` if `∀ ε > 0, ∃ N, ∀ n ≥ N, |u_n - l| ≤ ε`. This condition will be spelled `seq_limit u l`. -/ @@ -78,6 +83,8 @@ example (n : Nat) (h : n = 0) (t : Tuple α n) : Tuple α 0 := by exact t ''' +PREFIX_CURRENT_GOAL = "The current goal: " + @sgl.function def multi_turn_question(s, question_1, question_2): s += sgl.system("You are a helpful assistant.") @@ -88,34 +95,42 @@ def multi_turn_question(s, question_1, question_2): @sgl.function -def select_tactic(s, server, state, goal_id,informal_stmt="", informal_proof="", feedback_turns = 5): - +def select_tactic( + s, server, state, goal_id, + informal_stmt: str = "", informal_proof: str = "", + feedback_turns: int = 5): + s += sgl.system("You are an expert in Lean. 
Choose the next ONE tactic to run given the current proof state and goals.") s += sgl.user(LEAN4_REWRITE) - s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])") - s += sgl.assistant("```intros a b h```") - s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])") - s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")') + #s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])") + #s += sgl.assistant("```intros a b h```") + #s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])") + #s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")') + s += sgl.user(f"{PREFIX_CURRENT_GOAL}p : Prop\n⊢ ∀ (q: Prop), Or p q -> Or q p") + s += sgl.assistant('```\nintro q\n```') + s += sgl.user(f"{PREFIX_CURRENT_GOAL}a b c : Nat\n⊢ a + b + c = a + c + b") + s += sgl.assistant('```\nrw [Nat.add_assoc, Nat.add_comm b, ← Nat.add_assoc]\n```') if informal_stmt and informal_proof: - s += sgl.user("informal theorem statement: "+ informal_stmt) + s += sgl.user("informal theorem statement: " + informal_stmt) s += sgl.user("informal proof: " + informal_proof) - s += sgl.user("The current proof state: " + str(state) + "") + s += sgl.user(f"{PREFIX_CURRENT_GOAL}{state.goals[goal_id]}") for i in range(feedback_turns): with s.copy() as tmp: tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) # print("==tmp===") # print(tmp["tactic"]) - tactic = 
extract_code_from_llm_output(tmp["tactic"]) - s += sgl.assistant("```"+tactic+"```") + tactic = extract_code_from_llm_output(tmp["tactic"]).strip() + s += sgl.assistant(f"```\n{tactic}\n```") success, new_state = apply_tactic(server, state, goal_id, tactic) # print("===execute===") # print(success, new_state ) if not success: + print(colored("[Tactic]", "red"), tactic) with s.user(): - s += "This answer got Lean compile error:\n" + str(new_state) + "\n" + s += f"This answer got a Lean compile error:\n{new_state}\n" s += "Please try again by taking the Lean compiler feedback." - else: + print(colored("[Tactic]", "green"), tactic) return tactic, new_state return None, None @@ -127,7 +142,7 @@ def apply_tactic(server, state, goal_id, tactic): except TacticFailure as e: return False, e return True, new_state - + def extract_code_from_llm_output(reply): i = reply.find("```lean") if i != -1: @@ -149,7 +164,7 @@ class TestServerSGL(unittest.TestCase): n_trails = 5 sgl.set_default_backend(sgl.OpenAI("gpt-4")) - server = Server() + server = Server(core_options=CORE_OPTIONS) state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b") print("==========state0============") print(state0) @@ -187,7 +202,7 @@ class TestServerSGL(unittest.TestCase): print("\n-- new state --\n", state3) break - + except ServerError as e: print(f"server error: {e}") continue @@ -207,14 +222,14 @@ class TestServerSGL(unittest.TestCase): print("\n-- new state --\n", state4) break - + except ServerError as e: print(f"server error: {e}") continue state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") print("==========state4============") - print(state4) + print(state4) self.assertTrue(state4.is_solved) @@ -232,8 +247,7 @@ class TestServerSGL(unittest.TestCase): print("\n-- answer_1 --\n", state["answer_1"]) - + if __name__ == '__main__': unittest.main() - diff --git a/experiments/minif2f/model/llm_agent.py b/experiments/minif2f/model/llm_agent.py index 3105d69..af09302 
100644 --- a/experiments/minif2f/model/llm_agent.py +++ b/experiments/minif2f/model/llm_agent.py @@ -4,6 +4,7 @@ from pantograph.search import Agent from pantograph.server import Server, TacticFailure, ServerError from pantograph.expr import Expr, Tactic, GoalState from .gen_tactic import LEAN4_REWRITE, select_tactic +from .options import CORE_OPTIONS import sglang as sgl class LLMAgent(Agent): @@ -12,7 +13,9 @@ class LLMAgent(Agent): """ def __init__(self, server, - use_hammer=True, use_llm=True): + use_hammer=True, + use_llm=True, + feedback_turns=3): super().__init__() self.n_trials = 5 self.server = server @@ -24,17 +27,23 @@ class LLMAgent(Agent): self.use_hammer = use_hammer self.use_llm = use_llm + self.feedback_turns = feedback_turns if use_hammer: self.tactics = [ "aesop", - #"simp", + "simp", #"rfl", #"decide", ] else: self.tactics = [] - def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]: + def next_tactic( + self, + state: GoalState, + goal_id: int, + informal_stmt: str = "", + informal_proof: str = "") -> Optional[Tactic]: key = (state.state_id, goal_id) i = self.goal_tactic_id_map[key] @@ -46,7 +55,13 @@ class LLMAgent(Agent): new_state = None for ii in range(self.n_trials): print(f"===============trail {str(ii)}============") - s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof) + s = select_tactic.run( + server=self.server, + state=state, + goal_id=goal_id, + informal_stmt=informal_stmt, + informal_proof=informal_proof, + feedback_turns=self.feedback_turns) tactic, new_state = s.ret_value for m in s.messages(): print(m["role"], ":", m["content"]) @@ -78,7 +93,7 @@ class TestSearch(unittest.TestCase): def test_solve(self): - server = Server() + server = Server(core_options=CORE_OPTIONS) agent = LLMAgent(server, use_hammer=False) goal_state = server.goal_start("∀ (p q: Prop), p -> p") flag = 
agent.search(server=server, goal_state=goal_state, verbose=True) @@ -86,7 +101,7 @@ class TestSearch(unittest.TestCase): self.assertTrue(flag) def test_solve_big(self): - server = Server() + server = Server(core_options=CORE_OPTIONS) agent = LLMAgent(server, use_hammer=False) goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p") flag = agent.search(server=server, goal_state=goal_state, verbose=True) diff --git a/experiments/minif2f/model/options.py b/experiments/minif2f/model/options.py new file mode 100644 index 0000000..ef5755e --- /dev/null +++ b/experiments/minif2f/model/options.py @@ -0,0 +1,2 @@ +from pantograph.server import DEFAULT_CORE_OPTIONS +CORE_OPTIONS = DEFAULT_CORE_OPTIONS + ["tactic.hygienic=false"] diff --git a/pantograph/search.py b/pantograph/search.py index 00b4284..a166980 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from dataclasses import dataclass from typing import Optional import collections, unittest @@ -44,7 +45,9 @@ class Agent: """ An agent interface for proof search """ + tactic_feedback: Optional[str] = None + @abstractmethod def next_tactic( self, state: GoalState, @@ -54,14 +57,15 @@ class Agent: """ Implement this function to generate the next tactic for a goal """ - return None + @abstractmethod def guidance(self, state: GoalState) -> list[float]: """ Return a list of priorities determining which goal should be searched first. This will not be called on states with one or zero goals. 
""" return [0.0 for _ in state.goals] + @abstractmethod def reset(self): """ Called after search @@ -116,11 +120,13 @@ class Agent: if verbose: print(f"Next tactic: {tactic}") if not tactic: + # resets the feedback + self.tactic_feedback = None # pop the current state and continue to the next search_stack.pop(-1) if not search_stack: if verbose: - print("Tactic list has been exhausted") + print("Search stack has been exhausted") self.reset() return SearchResult(success=False, steps=i_step) continue @@ -147,6 +153,7 @@ class Agent: except TacticFailure as t: if verbose: print(f"Tactic failed: {t}") + self.tactic_feedback = str(t) # try the next tactic. this one failed if verbose: diff --git a/pantograph/server.py b/pantograph/server.py index a94b593..e3770cc 100644 --- a/pantograph/server.py +++ b/pantograph/server.py @@ -19,6 +19,8 @@ class TacticFailure(Exception): class ServerError(Exception): pass +DEFAULT_CORE_OPTIONS=["maxHeartbeats=0", "maxRecDepth=10000"] + class Server: def __init__(self, @@ -28,7 +30,7 @@ class Server: # Options for executing the REPL. # Set `{ "automaticMode" : False }` to handle resumption by yourself. options={}, - core_options=["maxHeartbeats=0"], + core_options=DEFAULT_CORE_OPTIONS, timeout=60, maxread=1000000): """ diff --git a/poetry.lock b/poetry.lock index 59bc761..317d065 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3250,4 +3250,4 @@ test = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "1ce8e928cff885e8c66d9c353e982e91e84ae84c91e96860aa3ca5a885bb0d2e" +content-hash = "b198bb707b86539e6c8edfe2b7377d47387aaaf053bb68b135ccd15361736030" diff --git a/pyproject.toml b/pyproject.toml index 9a6fc33..4a51591 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ tiktoken = "^0.7.0" torch = "2.2.1" wandb = "0.17.0" # vllm = "0.4.1" +termcolor = "^2.4.0" [build-system] requires = ["poetry-core"]