diff --git a/experiments/dsp/main.py b/experiments/dsp/main.py index 1e64629..92ddcf0 100644 --- a/experiments/dsp/main.py +++ b/experiments/dsp/main.py @@ -10,6 +10,7 @@ from tenacity import retry, stop_after_attempt, wait_exponential from pantograph import Server from solve.prompts import ( + extract_lean_code, SYSTEM_PROMPT_DRAFT_V0, SYSTEM_PROMPT_SKETCH_V0, prompt_draft_template_lean4_v0, @@ -58,7 +59,7 @@ class OpenAI_DSP_Engine(Engine): verbose_init: bool = True, ): super().__init__() - print(f'{api_key=}, {base_url=}') if verbose_init else None + print(f'{base_url=}') if verbose_init else None if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model): @@ -161,7 +162,8 @@ def sketch( return sketches, x_fl_problem def prove( - eng, + eng: Engine, + server: Server, fl_prob: str, fl_sketch: list[str], ): @@ -173,8 +175,11 @@ def prove( fl_sketch --> Lean4 Form Sketch --> have x have ha """ - print(f"fl_prob={fl_prob}") - print(f"fl_sketch={fl_sketch}") + # If this throws index out of bound errors it means the source doesn't contain walled off Lean sections. + lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch] + state, = server.load_sorry(lean_code) + + print(state) raise RuntimeError("Not implemented") # -- Prove correct: bool = False @@ -195,7 +200,7 @@ def single_proof_search_dsp_lean( z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts) # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) - correct: bool = prove(eng, x_fl_prob, z_fl_pred_sketches) + correct: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches) # -- Return return correct @@ -212,6 +217,7 @@ def full_proof_search_dsp_lean( for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'): print(f'{data_pt=}') flag = single_proof_search_dsp_lean(eng, server, data_pt) + server.gc() return experiment_dir = Path(__file__).resolve().parent diff --git a/experiments/dsp/solve/prompts.py b/experiments/dsp/solve/prompts.py index 0ed6499..6c177ac 100644 --- a/experiments/dsp/solve/prompts.py +++ b/experiments/dsp/solve/prompts.py @@ -7,8 +7,7 @@ core part of data for prompt for dsp: "src_header_fl_problem": ..., #src_header_x*_fl "fl_header_sketch": ..., # hz_fl suggested header """ -import json -import sys +import json, sys, unittest from pathlib import Path from typing import Optional @@ -19,6 +18,8 @@ experiment_dir = Path(__file__).resolve().parent.parent default_path_2_examples = 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json' +TOKEN_PLACEHOLDER = "" + # -- Prompt draft (P_draft) for Lean 4 """ Draft an informal solution similar to the one below. @@ -98,8 +99,8 @@ Formal:\n SYSTEM_PROMPT_SKETCH_V0 = 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.' STOP_TOKENS_SKETCH_V0: list[str] = ['Informal:', '(*### Problem', '###Solution', 'Formal:'] prompt_sketch_template_lean4_v0 = ("Translate the informal solution into a sketch in the " -"formal Lean 4 proof. Add in the formal sketch whenever possible. " -" will be used to call a automated theorem prover or tactic in Lean 4. " +f"formal Lean 4 proof. Add {TOKEN_PLACEHOLDER} in the formal sketch whenever possible. " +f"{TOKEN_PLACEHOLDER} will be used to call a automated theorem prover or tactic in Lean 4. " "Here are some examples:\n" ) def get_prompt_sketch_template_4_lean_v0( @@ -139,3 +140,54 @@ def get_prompt_sketch_template_4_lean_v0( print(prompt_sketch_template_4_lean) if verbose else None return prompt_sketch_template_4_lean prompt_sketch_template_lean4_v0 = get_prompt_sketch_template_4_lean_v0() + +WALL = "```" + + +def extract_lean_code( + sketch: str, + placeholder: str = TOKEN_PLACEHOLDER, + strip_imports: bool = True) -> list[str]: + lines = sketch.split("\n") + # find backtick markers ``` + lean_codes = [] + curr = [] + is_walled = False + is_walled_lean = False + for line in lines: + if not is_walled: + if line.rstrip() == f"{WALL}lean": + is_walled = True + is_walled_lean = True + elif line.startswith(WALL): + is_walled = True + is_walled_lean = False + continue + + if line.rstrip() == WALL: + if is_walled_lean: + code = "\n".join(curr) + "\n" + code = code.replace("ℕ", "Nat").replace(placeholder, "sorry") + lean_codes.append(code) + is_walled = False + is_walled_lean = False + continue + + if strip_imports and line.startswith("import "): + continue + curr.append(line) + + return lean_codes + + +class TestPrompts(unittest.TestCase): + + def test_extract_lean_code(self): + sketch = "```lean\nimport Mathlib.Data.Nat.Basic\nimport Aesop\n\ntheorem n_plus_zero : ∀ n : ℕ, n + 0 = n := by\n -- Consider any natural number n. We need to show that n + 0 = n.\n -- Use the fact that adding zero to any natural number does not change its value.\n have h_nat_add_zero: ∀ n : ℕ, n + 0 = n := \n -- Combine facts to close goal\n \n```" + codes = extract_lean_code(sketch) + self.assertEqual(codes, [ + "import Mathlib.Data.Nat.Basic\nimport Aesop\n\ntheorem n_plus_zero : ∀ n : Nat, n + 0 = n := by\n -- Consider any natural number n. We need to show that n + 0 = n.\n -- Use the fact that adding zero to any natural number does not change its value.\n have h_nat_add_zero: ∀ n : Nat, n + 0 = n := sorry\n -- Combine facts to close goal\n sorry\n" + ]) + +if __name__ == '__main__': + unittest.main()