From a30225069abce31db350e0829a4f9308bc19872e Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Thu, 3 Oct 2024 12:53:07 -0700 Subject: [PATCH 1/7] refactor: All MiniF2F into its own directory --- experiments/minif2f/MiniF2F/.gitignore | 3 + experiments/minif2f/MiniF2F/MiniF2F.lean | 2 + .../minif2f/MiniF2F/lake-manifest.json | 75 +++++++++++++++++++ experiments/minif2f/MiniF2F/lakefile.lean | 12 +++ experiments/minif2f/MiniF2F/lean-toolchain | 1 + experiments/minif2f/README.md | 9 +++ experiments/minif2f/main.py | 12 +-- experiments/minif2f/model/__init__.py | 0 .../minif2f/model}/gen_tactic.py | 0 .../minif2f/model/llm_agent.py | 2 +- poetry.lock | 2 +- pyproject.toml | 8 +- 12 files changed, 115 insertions(+), 11 deletions(-) create mode 100644 experiments/minif2f/MiniF2F/.gitignore create mode 100644 experiments/minif2f/MiniF2F/MiniF2F.lean create mode 100644 experiments/minif2f/MiniF2F/lake-manifest.json create mode 100644 experiments/minif2f/MiniF2F/lakefile.lean create mode 120000 experiments/minif2f/MiniF2F/lean-toolchain create mode 100644 experiments/minif2f/model/__init__.py rename {pantograph => experiments/minif2f/model}/gen_tactic.py (100%) rename pantograph/search_llm.py => experiments/minif2f/model/llm_agent.py (98%) diff --git a/experiments/minif2f/MiniF2F/.gitignore b/experiments/minif2f/MiniF2F/.gitignore new file mode 100644 index 0000000..1824d0c --- /dev/null +++ b/experiments/minif2f/MiniF2F/.gitignore @@ -0,0 +1,3 @@ +/build +/lakefile.olean +/lake-packages/* diff --git a/experiments/minif2f/MiniF2F/MiniF2F.lean b/experiments/minif2f/MiniF2F/MiniF2F.lean new file mode 100644 index 0000000..46005ca --- /dev/null +++ b/experiments/minif2f/MiniF2F/MiniF2F.lean @@ -0,0 +1,2 @@ +import Aesop +import Mathlib diff --git a/experiments/minif2f/MiniF2F/lake-manifest.json b/experiments/minif2f/MiniF2F/lake-manifest.json new file mode 100644 index 0000000..b0a5048 --- /dev/null +++ b/experiments/minif2f/MiniF2F/lake-manifest.json @@ -0,0 +1,75 @@ +{"version": "1.1.0", + "packagesDir": ".lake/packages", + "packages": + [{"url": "https://github.com/leanprover-community/batteries", + "type": "git", + "subDir": null, + "scope": "", + "rev": "2ead90d24b4fac3a05c9c4294daa39bd8686fb98", + "name": "batteries", + "manifestFile": "lake-manifest.json", + "inputRev": "main", + "inherited": true, + "configFile": "lakefile.lean"}, + {"url": "https://github.com/leanprover-community/aesop.git", + "type": "git", + "subDir": null, + "scope": "", + "rev": "a64fe24aa94e21404940e9217363a9a1ed9a33a6", + "name": "aesop", + "manifestFile": "lake-manifest.json", + "inputRev": "v4.10.0-rc1", + "inherited": false, + "configFile": "lakefile.toml"}, + {"url": "https://github.com/leanprover-community/quote4", + "type": "git", + "subDir": null, + "scope": "leanprover-community", + "rev": "a7bfa63f5dddbcab2d4e0569c4cac74b2585e2c6", + "name": "Qq", + "manifestFile": "lake-manifest.json", + "inputRev": "master", + "inherited": true, + "configFile": "lakefile.lean"}, + {"url": "https://github.com/leanprover-community/ProofWidgets4", + "type": "git", + "subDir": null, + "scope": "leanprover-community", + "rev": "d1b33202c3a29a079f292de65ea438648123b635", + "name": "proofwidgets", + "manifestFile": "lake-manifest.json", + "inputRev": "v0.0.39", + "inherited": true, + "configFile": "lakefile.lean"}, + {"url": "https://github.com/leanprover/lean4-cli", + "type": "git", + "subDir": null, + "scope": "", + "rev": "a11566029bd9ec4f68a65394e8c3ff1af74c1a29", + "name": "Cli", + "manifestFile": "lake-manifest.json", + "inputRev": "main", + "inherited": true, + "configFile": "lakefile.lean"}, + {"url": "https://github.com/leanprover-community/import-graph", + "type": "git", + "subDir": null, + "scope": "leanprover-community", + "rev": "d366a602cc4a325a6f9db3a3991dfa6d6cf409c5", + "name": "importGraph", + "manifestFile": "lake-manifest.json", + "inputRev": "main", + "inherited": true, + "configFile": "lakefile.toml"}, + {"url": "https://github.com/leanprover-community/mathlib4.git", + "type": "git", + "subDir": null, + "scope": "", + "rev": "f5c3f06aa7f6d6c221786d2890c345a00e6341f8", + "name": "mathlib", + "manifestFile": "lake-manifest.json", + "inputRev": "v4.10.0-rc1", + "inherited": false, + "configFile": "lakefile.lean"}], + "name": "MiniF2F", + "lakeDir": ".lake"} diff --git a/experiments/minif2f/MiniF2F/lakefile.lean b/experiments/minif2f/MiniF2F/lakefile.lean new file mode 100644 index 0000000..f5ec5eb --- /dev/null +++ b/experiments/minif2f/MiniF2F/lakefile.lean @@ -0,0 +1,12 @@ +import Lake +open Lake DSL + +require aesop from git + "https://github.com/leanprover-community/aesop.git" @ "v4.10.0-rc1" +require mathlib from git + "https://github.com/leanprover-community/mathlib4.git" @ "v4.10.0-rc1" + +package MiniF2F + +@[default_target] +lean_lib MiniF2F diff --git a/experiments/minif2f/MiniF2F/lean-toolchain b/experiments/minif2f/MiniF2F/lean-toolchain new file mode 120000 index 0000000..b494e97 --- /dev/null +++ b/experiments/minif2f/MiniF2F/lean-toolchain @@ -0,0 +1 @@ +../../../src/lean-toolchain \ No newline at end of file diff --git a/experiments/minif2f/README.md b/experiments/minif2f/README.md index 0d93665..ad72c20 100644 --- a/experiments/minif2f/README.md +++ b/experiments/minif2f/README.md @@ -5,3 +5,12 @@ This is an experiment on running a LLM prover on miniF2F data. Run with ```sh python3 experiments/minif2f/main.py [--dry-run] ``` + +## Developing + +Run unit tests with + +``` sh +python3 -m model.{llm_agent,gen_tactic} +``` + diff --git a/experiments/minif2f/main.py b/experiments/minif2f/main.py index b4c41d4..16e56f1 100755 --- a/experiments/minif2f/main.py +++ b/experiments/minif2f/main.py @@ -5,15 +5,17 @@ from typing import Optional from pathlib import Path from pantograph.server import Server, ServerError from pantograph.search import SearchResult -from pantograph.search_llm import LLMAgent +from model.llm_agent import LLMAgent + +PATH_EXPERIMENT = Path(__file__).parent.resolve() def get_project_and_lean_path(): - cwd = Path(__file__).parent.resolve() / 'Example' + cwd = PATH_EXPERIMENT / 'MiniF2F' p = subprocess.check_output(['lake', 'env', 'printenv', 'LEAN_PATH'], cwd=cwd) return cwd, p def read_test_data(use_valid: bool): - jsonl_path = Path(__file__).parent / ('valid.jsonl' if use_valid else 'test.jsonl') + jsonl_path = PATH_EXPERIMENT / ('valid.jsonl' if use_valid else 'test.jsonl') with open(jsonl_path, 'r') as f: return [json.loads(l) for l in list(f)] @@ -44,7 +46,7 @@ def output_file_name(datum, use_hammer: bool, use_llm: bool): folder += '-hammer' if use_llm: folder += '-llm' - folder = Path(__file__).parent / folder + folder = PATH_EXPERIMENT / folder folder.mkdir(exist_ok=True, parents=True) return folder / f"{name}.json" @@ -65,7 +67,7 @@ def run_eval(args): if file_name.is_file(): print(f"Skipping {datum['id']}") continue - server = Server(imports=["Example"], project_path=project_path, lean_path=lean_path, options=["maxHeartbeats=0"]) + server = Server(imports=["MiniF2F"], project_path=project_path, lean_path=lean_path, options=["maxHeartbeats=0"]) agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm) result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal) if result is None: diff --git a/experiments/minif2f/model/__init__.py b/experiments/minif2f/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pantograph/gen_tactic.py b/experiments/minif2f/model/gen_tactic.py similarity index 100% rename from pantograph/gen_tactic.py rename to experiments/minif2f/model/gen_tactic.py diff --git a/pantograph/search_llm.py b/experiments/minif2f/model/llm_agent.py similarity index 98% rename from pantograph/search_llm.py rename to experiments/minif2f/model/llm_agent.py index c98ad23..3105d69 100644 --- a/pantograph/search_llm.py +++ b/experiments/minif2f/model/llm_agent.py @@ -3,7 +3,7 @@ import collections, unittest from pantograph.search import Agent from pantograph.server import Server, TacticFailure, ServerError from pantograph.expr import Expr, Tactic, GoalState -from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic +from .gen_tactic import LEAN4_REWRITE, select_tactic import sglang as sgl class LLMAgent(Agent): diff --git a/poetry.lock b/poetry.lock index 179e415..59bc761 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3250,4 +3250,4 @@ test = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "686e7f1af124ef2404bc6b46677850c581e8e74f3cab51992fac8e8578f88a3a" +content-hash = "1ce8e928cff885e8c66d9c353e982e91e84ae84c91e96860aa3ca5a885bb0d2e" diff --git a/pyproject.toml b/pyproject.toml index d674e24..9a6fc33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,12 +7,8 @@ license = "GPL-3" readme = "README.md" [tool.poetry.dependencies] -# vllm = "0.4.1" -numpy = "^1.26.4" pexpect = "^4.9.0" python = "^3.10" -sglang = "^0.1.16" -torch = "2.2.1" [tool.poetry.build] generate-setup-file = false @@ -22,10 +18,14 @@ script = "build.py" # Experiment related dependencies here to not clutter the main project dependencies. fire = "0.6.0" notebook = "^7.2.1" +numpy = "^1.26.4" openai = "^1.31.0" +sglang = "^0.1.16" tenacity = "8.3.0" tiktoken = "^0.7.0" +torch = "2.2.1" wandb = "0.17.0" +# vllm = "0.4.1" [build-system] requires = ["poetry-core"] From b44036310515a41f8545a2815b20d1e6a9656f20 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Thu, 3 Oct 2024 12:58:39 -0700 Subject: [PATCH 2/7] fix: Skip the commented out test cases --- experiments/minif2f/main.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/experiments/minif2f/main.py b/experiments/minif2f/main.py index 16e56f1..4186348 100755 --- a/experiments/minif2f/main.py +++ b/experiments/minif2f/main.py @@ -25,7 +25,12 @@ def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goa informal_stmt = entry["informal_stmt"] informal_proof = entry["informal_proof"] - goal_state, = server.load_sorry(command) + goal_states = server.load_sorry(command) + + if len(goal_states) == 0: + return None + + goal_state, = goal_states try: return agent.search( server=server, @@ -67,7 +72,7 @@ def run_eval(args): if file_name.is_file(): print(f"Skipping {datum['id']}") continue - server = Server(imports=["MiniF2F"], project_path=project_path, lean_path=lean_path, options=["maxHeartbeats=0"]) + server = Server(imports=["MiniF2F"], project_path=project_path, lean_path=lean_path, core_options=["maxHeartbeats=0"]) agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm) result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal) if result is None: From 20f3011eb43f3961427a2055fcaa08b5099dda4f Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Thu, 3 Oct 2024 15:45:14 -0700 Subject: [PATCH 3/7] doc: Improve error message --- experiments/minif2f/.gitignore | 1 + experiments/minif2f/README.md | 7 +++++-- experiments/minif2f/main.py | 3 ++- pantograph/server.py | 18 ++++++++++++------ 4 files changed, 20 insertions(+), 9 deletions(-) create mode 100644 experiments/minif2f/.gitignore diff --git a/experiments/minif2f/.gitignore b/experiments/minif2f/.gitignore new file mode 100644 index 0000000..55ac6ac --- /dev/null +++ b/experiments/minif2f/.gitignore @@ -0,0 +1 @@ +/output* diff --git a/experiments/minif2f/README.md b/experiments/minif2f/README.md index ad72c20..560eb4d 100644 --- a/experiments/minif2f/README.md +++ b/experiments/minif2f/README.md @@ -1,11 +1,14 @@ # MiniF2F -This is an experiment on running a LLM prover on miniF2F data. Run with +This is an experiment on running a LLM prover on miniF2F data. Build the project +`MiniF2F` with `lake build`, and run with ```sh -python3 experiments/minif2f/main.py [--dry-run] +python3 experiments/minif2f/main.py [--dry-run] [--use-llm] ``` +Read the help message carefully. + ## Developing Run unit tests with diff --git a/experiments/minif2f/main.py b/experiments/minif2f/main.py index 4186348..ea1fd50 100755 --- a/experiments/minif2f/main.py +++ b/experiments/minif2f/main.py @@ -72,9 +72,10 @@ def run_eval(args): if file_name.is_file(): print(f"Skipping {datum['id']}") continue - server = Server(imports=["MiniF2F"], project_path=project_path, lean_path=lean_path, core_options=["maxHeartbeats=0"]) + server = Server(imports=["MiniF2F"], project_path=project_path, lean_path=lean_path) agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm) result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal) + #server.gc() if result is None: with open(placeholder_file_name, 'w') as f: json.dump({ 'id': datum['id'] }, f) diff --git a/pantograph/server.py b/pantograph/server.py index e0f6870..a94b593 100644 --- a/pantograph/server.py +++ b/pantograph/server.py @@ -28,8 +28,8 @@ class Server: # Options for executing the REPL. # Set `{ "automaticMode" : False }` to handle resumption by yourself. options={}, - core_options=[], - timeout=20, + core_options=["maxHeartbeats=0"], + timeout=60, maxread=1000000): """ timeout: Amount of time to wait for execution @@ -86,7 +86,10 @@ class Server: self.proc.sendline(f"{cmd} {s}") try: line = self.proc.readline() - return json.loads(line) + try: + return json.loads(line) + except Exception as e: + raise ServerError(f"Cannot decode: {line}") from e except pexpect.exceptions.TIMEOUT as exc: raise exc @@ -96,9 +99,12 @@ class Server: Must be called periodically. """ - if self.to_remove_goal_states: - self.run('goal.delete', {'stateIds': self.to_remove_goal_states}) - self.to_remove_goal_states.clear() + if not self.to_remove_goal_states: + return + result = self.run('goal.delete', {'stateIds': self.to_remove_goal_states}) + self.to_remove_goal_states.clear() + if "error" in result: + raise ServerError(result["desc"]) def expr_type(self, expr: Expr) -> Expr: """ From 2fae5e97f1126bf103df4e25c5ee349c69c60ac4 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Fri, 4 Oct 2024 17:55:32 -0700 Subject: [PATCH 4/7] feat: Concise prompts and unhygienic mode --- experiments/minif2f/main.py | 27 +++++++++---- experiments/minif2f/model/gen_tactic.py | 54 ++++++++++++++++--------- experiments/minif2f/model/llm_agent.py | 27 ++++++++++--- experiments/minif2f/model/options.py | 2 + pantograph/search.py | 11 ++++- pantograph/server.py | 4 +- poetry.lock | 2 +- pyproject.toml | 1 + 8 files changed, 91 insertions(+), 37 deletions(-) create mode 100644 experiments/minif2f/model/options.py diff --git a/experiments/minif2f/main.py b/experiments/minif2f/main.py index ea1fd50..296c661 100755 --- a/experiments/minif2f/main.py +++ b/experiments/minif2f/main.py @@ -3,9 +3,10 @@ import subprocess, json, argparse from typing import Optional from pathlib import Path -from pantograph.server import Server, ServerError +from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS from pantograph.search import SearchResult from model.llm_agent import LLMAgent +from model.options import CORE_OPTIONS PATH_EXPERIMENT = Path(__file__).parent.resolve() @@ -72,8 +73,18 @@ def run_eval(args): if file_name.is_file(): print(f"Skipping {datum['id']}") continue - server = Server(imports=["MiniF2F"], project_path=project_path, lean_path=lean_path) - agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm) + server = Server( + imports=["Mathlib", "Aesop"], + project_path=project_path, + lean_path=lean_path, + core_options=CORE_OPTIONS, + ) + agent = LLMAgent( + server, + use_hammer=args.use_hammer, + use_llm=args.use_llm, + feedback_turns=args.feedback_turns, + ) result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal) #server.gc() if result is None: @@ -87,8 +98,9 @@ def run_eval(args): if __name__ == '__main__': parser = argparse.ArgumentParser( - prog='MiniF2F Search', - description='Executes LLM on MiniF2F Search') + prog='MiniF2F Search', + description='Executes LLM on MiniF2F Search', + ) parser.add_argument('--use-hammer', action='store_true') parser.add_argument( '--dry-run', @@ -96,8 +108,9 @@ if __name__ == '__main__': help="List the data used, but don't run") parser.add_argument('--validation', action='store_true') parser.add_argument('--use-llm', action='store_true') - parser.add_argument('-s', '--max-steps', default=50) - parser.add_argument('-t', '--max-trials-per-goal', default=2) + parser.add_argument('--max-steps', default=50) + parser.add_argument('--max-trials-per-goal', default=2) + parser.add_argument('--feedback-turns', default=2) args = parser.parse_args() if args.dry_run: diff --git a/experiments/minif2f/model/gen_tactic.py b/experiments/minif2f/model/gen_tactic.py index a1c7ee8..1689d28 100644 --- a/experiments/minif2f/model/gen_tactic.py +++ b/experiments/minif2f/model/gen_tactic.py @@ -1,7 +1,12 @@ +""" +Tactic generation functions for the LLM agent +""" from pantograph.server import Server, ServerError, TacticFailure from pantograph.expr import Variable, Goal, TacticCalc -import unittest import sglang as sgl +from termcolor import colored +import unittest +from .options import CORE_OPTIONS LEAN4_INTRO = '''/-- A sequence `u` of real numbers converges to `l` if `∀ ε > 0, ∃ N, ∀ n ≥ N, |u_n - l| ≤ ε`. This condition will be spelled `seq_limit u l`. -/ @@ -78,6 +83,8 @@ example (n : Nat) (h : n = 0) (t : Tuple α n) : Tuple α 0 := by exact t ''' +PREFIX_CURRENT_GOAL = "The current goal: " + @sgl.function def multi_turn_question(s, question_1, question_2): s += sgl.system("You are a helpful assistant.") @@ -88,34 +95,42 @@ def multi_turn_question(s, question_1, question_2): @sgl.function -def select_tactic(s, server, state, goal_id,informal_stmt="", informal_proof="", feedback_turns = 5): - +def select_tactic( + s, server, state, goal_id, + informal_stmt: str = "", informal_proof: str = "", + feedback_turns: int = 5): + s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.") s += sgl.user(LEAN4_REWRITE) - s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])") - s += sgl.assistant("```intros a b h```") - s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])") - s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")') + #s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])") + #s += sgl.assistant("```intros a b h```") + #s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])") + #s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")') + s += sgl.user(f"{PREFIX_CURRENT_GOAL}p : Prop\n⊢ ∀ (q: Prop), Or p q -> Or q p") + s += sgl.assistant('```\nintro q\n```') + s += sgl.user(f"{PREFIX_CURRENT_GOAL}a b c : Nat\n⊢ a + b + c = a + c + b") + s += sgl.assistant('```\nrw [Nat.add_assoc, Nat.add_comm b, ← Nat.add_assoc]\n```') if informal_stmt and informal_proof: - s += sgl.user("informal theorem statement: "+ informal_stmt) + s += sgl.user("informal theorem statement: " + informal_stmt) s += sgl.user("informal proof: " + informal_proof) - s += sgl.user("The current proof state: " + str(state) + "") + s += sgl.user(f"{PREFIX_CURRENT_GOAL}{state.goals[goal_id]}") for i in range(feedback_turns): with s.copy() as tmp: tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) # print("==tmp===") # print(tmp["tactic"]) - tactic = extract_code_from_llm_output(tmp["tactic"]) - s += sgl.assistant("```"+tactic+"```") + tactic = extract_code_from_llm_output(tmp["tactic"]).strip() + s += sgl.assistant(f"```\n{tactic}\n```") success, new_state = apply_tactic(server, state, goal_id, tactic) # print("===execute===") # print(success, new_state ) if not success: + print(colored("[Tactic]", "red"), tactic) with s.user(): - s += "This answer got Lean compile error:\n" + str(new_state) + "\n" + s += f"This answer got a Lean compile error:\n{new_state}\n" s += "Please try again by taking the Lean compiler feedback." - else: + print(colored("[Tactic]", "green"), tactic) return tactic, new_state return None, None @@ -127,7 +142,7 @@ def apply_tactic(server, state, goal_id, tactic): except TacticFailure as e: return False, e return True, new_state - + def extract_code_from_llm_output(reply): i = reply.find("```lean") if i != -1: @@ -149,7 +164,7 @@ class TestServerSGL(unittest.TestCase): n_trails = 5 sgl.set_default_backend(sgl.OpenAI("gpt-4")) - server = Server() + server = Server(core_options=CORE_OPTIONS) state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b") print("==========state0============") print(state0) @@ -187,7 +202,7 @@ class TestServerSGL(unittest.TestCase): print("\n-- new state --\n", state3) break - + except ServerError as e: print(f"server error: {e}") continue @@ -207,14 +222,14 @@ class TestServerSGL(unittest.TestCase): print("\n-- new state --\n", state4) break - + except ServerError as e: print(f"server error: {e}") continue state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") print("==========state4============") - print(state4) + print(state4) self.assertTrue(state4.is_solved) @@ -232,8 +247,7 @@ class TestServerSGL(unittest.TestCase): print("\n-- answer_1 --\n", state["answer_1"]) - + if __name__ == '__main__': unittest.main() - diff --git a/experiments/minif2f/model/llm_agent.py b/experiments/minif2f/model/llm_agent.py index 3105d69..af09302 100644 --- a/experiments/minif2f/model/llm_agent.py +++ b/experiments/minif2f/model/llm_agent.py @@ -4,6 +4,7 @@ from pantograph.search import Agent from pantograph.server import Server, TacticFailure, ServerError from pantograph.expr import Expr, Tactic, GoalState from .gen_tactic import LEAN4_REWRITE, select_tactic +from .options import CORE_OPTIONS import sglang as sgl class LLMAgent(Agent): @@ -12,7 +13,9 @@ class LLMAgent(Agent): """ def __init__(self, server, - use_hammer=True, use_llm=True): + use_hammer=True, + use_llm=True, + feedback_turns=3): super().__init__() self.n_trials = 5 self.server = server @@ -24,17 +27,23 @@ class LLMAgent(Agent): self.use_hammer = use_hammer self.use_llm = use_llm + self.feedback_turns = feedback_turns if use_hammer: self.tactics = [ "aesop", - #"simp", + "simp", #"rfl", #"decide", ] else: self.tactics = [] - def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]: + def next_tactic( + self, + state: GoalState, + goal_id: int, + informal_stmt: str = "", + informal_proof: str = "") -> Optional[Tactic]: key = (state.state_id, goal_id) i = self.goal_tactic_id_map[key] @@ -46,7 +55,13 @@ class LLMAgent(Agent): new_state = None for ii in range(self.n_trials): print(f"===============trail {str(ii)}============") - s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof) + s = select_tactic.run( + server=self.server, + state=state, + goal_id=goal_id, + informal_stmt=informal_stmt, + informal_proof=informal_proof, + feedback_turns=self.feedback_turns) tactic, new_state = s.ret_value for m in s.messages(): print(m["role"], ":", m["content"]) @@ -78,7 +93,7 @@ class TestSearch(unittest.TestCase): def test_solve(self): - server = Server() + server = Server(core_options=CORE_OPTIONS) agent = LLMAgent(server, use_hammer=False) goal_state = server.goal_start("∀ (p q: Prop), p -> p") flag = agent.search(server=server, goal_state=goal_state, verbose=True) @@ -86,7 +101,7 @@ class TestSearch(unittest.TestCase): self.assertTrue(flag) def test_solve_big(self): - server = Server() + server = Server(core_options=CORE_OPTIONS) agent = LLMAgent(server, use_hammer=False) goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p") flag = agent.search(server=server, goal_state=goal_state, verbose=True) diff --git a/experiments/minif2f/model/options.py b/experiments/minif2f/model/options.py new file mode 100644 index 0000000..ef5755e --- /dev/null +++ b/experiments/minif2f/model/options.py @@ -0,0 +1,2 @@ +from pantograph.server import DEFAULT_CORE_OPTIONS +CORE_OPTIONS = DEFAULT_CORE_OPTIONS + ["tactic.hygienic=false"] diff --git a/pantograph/search.py b/pantograph/search.py index 00b4284..a166980 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from dataclasses import dataclass from typing import Optional import collections, unittest @@ -44,7 +45,9 @@ class Agent: """ An agent interface for proof search """ + tactic_feedback: Optional[str] = None + @abstractmethod def next_tactic( self, state: GoalState, @@ -54,14 +57,15 @@ class Agent: """ Implement this function to generate the next tactic for a goal """ - return None + @abstractmethod def guidance(self, state: GoalState) -> list[float]: """ Return a list of priorities determining which goal should be searched first. This will not be called on states with one or zero goals. """ return [0.0 for _ in state.goals] + @abstractmethod def reset(self): """ Called after search @@ -116,11 +120,13 @@ class Agent: if verbose: print(f"Next tactic: {tactic}") if not tactic: + # resets the feedback + self.tactic_feedback = None # pop the current state and continue to the next search_stack.pop(-1) if not search_stack: if verbose: - print("Tactic list has been exhausted") + print("Search stack has been exhausted") self.reset() return SearchResult(success=False, steps=i_step) continue @@ -147,6 +153,7 @@ class Agent: except TacticFailure as t: if verbose: print(f"Tactic failed: {t}") + self.tactic_feedback = str(t) # try the next tactic. this one failed if verbose: diff --git a/pantograph/server.py b/pantograph/server.py index a94b593..e3770cc 100644 --- a/pantograph/server.py +++ b/pantograph/server.py @@ -19,6 +19,8 @@ class TacticFailure(Exception): class ServerError(Exception): pass +DEFAULT_CORE_OPTIONS=["maxHeartbeats=0", "maxRecDepth=10000"] + class Server: def __init__(self, @@ -28,7 +30,7 @@ class Server: # Options for executing the REPL. # Set `{ "automaticMode" : False }` to handle resumption by yourself. options={}, - core_options=["maxHeartbeats=0"], + core_options=DEFAULT_CORE_OPTIONS, timeout=60, maxread=1000000): """ diff --git a/poetry.lock b/poetry.lock index 59bc761..317d065 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3250,4 +3250,4 @@ test = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "1ce8e928cff885e8c66d9c353e982e91e84ae84c91e96860aa3ca5a885bb0d2e" +content-hash = "b198bb707b86539e6c8edfe2b7377d47387aaaf053bb68b135ccd15361736030" diff --git a/pyproject.toml b/pyproject.toml index 9a6fc33..4a51591 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ tiktoken = "^0.7.0" torch = "2.2.1" wandb = "0.17.0" # vllm = "0.4.1" +termcolor = "^2.4.0" [build-system] requires = ["poetry-core"] From 542784caa2cdcd63cf6306acbd8731211e09dd45 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Fri, 4 Oct 2024 18:01:48 -0700 Subject: [PATCH 5/7] fix: Trailing comma in reply, remove simp fallback --- experiments/minif2f/main.py | 4 +++- experiments/minif2f/model/gen_tactic.py | 8 +++++++- experiments/minif2f/model/llm_agent.py | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/experiments/minif2f/main.py b/experiments/minif2f/main.py index 296c661..559443f 100755 --- a/experiments/minif2f/main.py +++ b/experiments/minif2f/main.py @@ -3,6 +3,7 @@ import subprocess, json, argparse from typing import Optional from pathlib import Path +from termcolor import colored from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS from pantograph.search import SearchResult from model.llm_agent import LLMAgent @@ -71,8 +72,9 @@ def run_eval(args): file_name = output_file_name(datum, args.use_hammer, args.use_llm) placeholder_file_name = file_name.with_suffix('.placeholder') if file_name.is_file(): - print(f"Skipping {datum['id']}") + print(colored(f"Skipping {datum['id']}", "green")) continue + print(colored(f"Evaluating on {datum['id']} ...", "blue")) server = Server( imports=["Mathlib", "Aesop"], project_path=project_path, diff --git a/experiments/minif2f/model/gen_tactic.py b/experiments/minif2f/model/gen_tactic.py index 1689d28..8c115bc 100644 --- a/experiments/minif2f/model/gen_tactic.py +++ b/experiments/minif2f/model/gen_tactic.py @@ -119,7 +119,7 @@ def select_tactic( tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) # print("==tmp===") # print(tmp["tactic"]) - tactic = extract_code_from_llm_output(tmp["tactic"]).strip() + tactic = postprocess_reply(extract_code_from_llm_output(tmp["tactic"])) s += sgl.assistant(f"```\n{tactic}\n```") success, new_state = apply_tactic(server, state, goal_id, tactic) # print("===execute===") @@ -158,6 +158,12 @@ def extract_code_from_llm_output(reply): return reply return reply +def postprocess_reply(reply): + reply = reply.strip() + if reply and reply[-1] == ",": + reply = reply[:-1] + return reply + class TestServerSGL(unittest.TestCase): def test_conv_calc_sgl(self): diff --git a/experiments/minif2f/model/llm_agent.py b/experiments/minif2f/model/llm_agent.py index af09302..d662a8b 100644 --- a/experiments/minif2f/model/llm_agent.py +++ b/experiments/minif2f/model/llm_agent.py @@ -31,7 +31,7 @@ class LLMAgent(Agent): if use_hammer: self.tactics = [ "aesop", - "simp", + #"simp", #"rfl", #"decide", ] From 5b176795b2d018c754e8eac95da65daca7ce25dc Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Fri, 4 Oct 2024 18:04:10 -0700 Subject: [PATCH 6/7] doc: Diagnostics info at result --- experiments/minif2f/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/experiments/minif2f/main.py b/experiments/minif2f/main.py index 559443f..a81ae2b 100755 --- a/experiments/minif2f/main.py +++ b/experiments/minif2f/main.py @@ -88,6 +88,7 @@ def run_eval(args): feedback_turns=args.feedback_turns, ) result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal) + print(colored(f"Result on {datum['id']}: {result}", "blue")) #server.gc() if result is None: with open(placeholder_file_name, 'w') as f: From 82d9f9200e4ad8ba3bd055225e952609c8983f0f Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Fri, 4 Oct 2024 18:45:13 -0700 Subject: [PATCH 7/7] refactor: Pass in `informal_{stmt,proof}` directly --- experiments/minif2f/main.py | 14 +++++++++----- experiments/minif2f/model/gen_tactic.py | 3 ++- experiments/minif2f/model/llm_agent.py | 10 ++++++---- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/experiments/minif2f/main.py b/experiments/minif2f/main.py index a81ae2b..1ddf467 100755 --- a/experiments/minif2f/main.py +++ b/experiments/minif2f/main.py @@ -24,8 +24,8 @@ def read_test_data(use_valid: bool): def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]: command = entry["formal_statement"] print(command) - informal_stmt = entry["informal_stmt"] - informal_proof = entry["informal_proof"] + agent.informal_stmt = entry["informal_stmt"] + agent.informal_proof = entry["informal_proof"] goal_states = server.load_sorry(command) @@ -37,8 +37,6 @@ def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goa return agent.search( server=server, goal_state=goal_state, - informal_stmt=informal_stmt, - informal_proof=informal_proof, verbose=True, max_steps=max_steps, max_trials_per_goal=max_trials_per_goal @@ -87,7 +85,13 @@ def run_eval(args): use_llm=args.use_llm, feedback_turns=args.feedback_turns, ) - result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal) + result = try_test_data( + server, + agent, + datum, + max_steps=args.max_steps, + max_trials_per_goal=args.max_trials_per_goal, + ) print(colored(f"Result on {datum['id']}: {result}", "blue")) #server.gc() if result is None: diff --git a/experiments/minif2f/model/gen_tactic.py b/experiments/minif2f/model/gen_tactic.py index 8c115bc..d0ed476 100644 --- a/experiments/minif2f/model/gen_tactic.py +++ b/experiments/minif2f/model/gen_tactic.py @@ -97,7 +97,8 @@ def multi_turn_question(s, question_1, question_2): @sgl.function def select_tactic( s, server, state, goal_id, - informal_stmt: str = "", informal_proof: str = "", + informal_stmt: str = "", + informal_proof: str = "", feedback_turns: int = 5): s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.") diff --git a/experiments/minif2f/model/llm_agent.py b/experiments/minif2f/model/llm_agent.py index d662a8b..9c069a8 100644 --- a/experiments/minif2f/model/llm_agent.py +++ b/experiments/minif2f/model/llm_agent.py @@ -38,12 +38,14 @@ class LLMAgent(Agent): else: self.tactics = [] + self.informal_stmt = "" + self.informal_proof = "" + def next_tactic( self, state: GoalState, goal_id: int, - informal_stmt: str = "", - informal_proof: str = "") -> Optional[Tactic]: + ) -> Optional[Tactic]: key = (state.state_id, goal_id) i = self.goal_tactic_id_map[key] @@ -59,8 +61,8 @@ class LLMAgent(Agent): server=self.server, state=state, goal_id=goal_id, - informal_stmt=informal_stmt, - informal_proof=informal_proof, + informal_stmt=self.informal_stmt, + informal_proof=self.informal_proof, feedback_turns=self.feedback_turns) tactic, new_state = s.ret_value for m in s.messages():