feat: Extract Lean code sections from sketches
This commit is contained in:
parent
f1e996baae
commit
80a356c75c
|
@ -10,6 +10,7 @@ from tenacity import retry, stop_after_attempt, wait_exponential
|
|||
from pantograph import Server
|
||||
|
||||
from solve.prompts import (
|
||||
extract_lean_code,
|
||||
SYSTEM_PROMPT_DRAFT_V0,
|
||||
SYSTEM_PROMPT_SKETCH_V0,
|
||||
prompt_draft_template_lean4_v0,
|
||||
|
@ -58,7 +59,7 @@ class OpenAI_DSP_Engine(Engine):
|
|||
verbose_init: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
print(f'{api_key=}, {base_url=}') if verbose_init else None
|
||||
print(f'{base_url=}') if verbose_init else None
|
||||
|
||||
|
||||
if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model):
|
||||
|
@ -161,7 +162,8 @@ def sketch(
|
|||
return sketches, x_fl_problem
|
||||
|
||||
def prove(
|
||||
eng,
|
||||
eng: Engine,
|
||||
server: Server,
|
||||
fl_prob: str,
|
||||
fl_sketch: list[str],
|
||||
):
|
||||
|
@ -173,8 +175,11 @@ def prove(
|
|||
fl_sketch --> Lean4 Form Sketch --> have x have ha
|
||||
|
||||
"""
|
||||
print(f"fl_prob={fl_prob}")
|
||||
print(f"fl_sketch={fl_sketch}")
|
||||
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
||||
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
|
||||
state, = server.load_sorry(lean_code)
|
||||
|
||||
print(state)
|
||||
raise RuntimeError("Not implemented")
|
||||
# -- Prove
|
||||
correct: bool = False
|
||||
|
@ -195,7 +200,7 @@ def single_proof_search_dsp_lean(
|
|||
z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts)
|
||||
|
||||
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
||||
correct: bool = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
||||
correct: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches)
|
||||
|
||||
# -- Return
|
||||
return correct
|
||||
|
@ -212,6 +217,7 @@ def full_proof_search_dsp_lean(
|
|||
for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
|
||||
print(f'{data_pt=}')
|
||||
flag = single_proof_search_dsp_lean(eng, server, data_pt)
|
||||
server.gc()
|
||||
return
|
||||
|
||||
experiment_dir = Path(__file__).resolve().parent
|
||||
|
|
|
@ -7,8 +7,7 @@ core part of data for prompt for dsp:
|
|||
"src_header_fl_problem": ..., #src_header_x*_fl
|
||||
"fl_header_sketch": ..., # hz_fl suggested header
|
||||
"""
|
||||
import json
|
||||
import sys
|
||||
import json, sys, unittest
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
@ -19,6 +18,8 @@ experiment_dir = Path(__file__).resolve().parent.parent
|
|||
|
||||
default_path_2_examples = 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json'
|
||||
|
||||
TOKEN_PLACEHOLDER = "<TODO_PROOF_OR_HAMMER>"
|
||||
|
||||
# -- Prompt draft (P_draft) for Lean 4
|
||||
"""
|
||||
Draft an informal solution similar to the one below.
|
||||
|
@ -98,8 +99,8 @@ Formal:\n
|
|||
SYSTEM_PROMPT_SKETCH_V0 = 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.'
|
||||
STOP_TOKENS_SKETCH_V0: list[str] = ['Informal:', '(*### Problem', '###Solution', 'Formal:']
|
||||
prompt_sketch_template_lean4_v0 = ("Translate the informal solution into a sketch in the "
|
||||
"formal Lean 4 proof. Add <TODO_PROOF_OR_HAMMER> in the formal sketch whenever possible. "
|
||||
"<TODO_PROOF_OR_HAMMER> will be used to call a automated theorem prover or tactic in Lean 4. "
|
||||
f"formal Lean 4 proof. Add {TOKEN_PLACEHOLDER} in the formal sketch whenever possible. "
|
||||
f"{TOKEN_PLACEHOLDER} will be used to call a automated theorem prover or tactic in Lean 4. "
|
||||
"Here are some examples:\n"
|
||||
)
|
||||
def get_prompt_sketch_template_4_lean_v0(
|
||||
|
@ -139,3 +140,54 @@ def get_prompt_sketch_template_4_lean_v0(
|
|||
print(prompt_sketch_template_4_lean) if verbose else None
|
||||
return prompt_sketch_template_4_lean
|
||||
prompt_sketch_template_lean4_v0 = get_prompt_sketch_template_4_lean_v0()
|
||||
|
||||
WALL = "```"
|
||||
|
||||
|
||||
def extract_lean_code(
|
||||
sketch: str,
|
||||
placeholder: str = TOKEN_PLACEHOLDER,
|
||||
strip_imports: bool = True) -> list[str]:
|
||||
lines = sketch.split("\n")
|
||||
# find backtick markers ```
|
||||
lean_codes = []
|
||||
curr = []
|
||||
is_walled = False
|
||||
is_walled_lean = False
|
||||
for line in lines:
|
||||
if not is_walled:
|
||||
if line.rstrip() == f"{WALL}lean":
|
||||
is_walled = True
|
||||
is_walled_lean = True
|
||||
elif line.startswith(WALL):
|
||||
is_walled = True
|
||||
is_walled_lean = False
|
||||
continue
|
||||
|
||||
if line.rstrip() == WALL:
|
||||
if is_walled_lean:
|
||||
code = "\n".join(curr) + "\n"
|
||||
code = code.replace("ℕ", "Nat").replace(placeholder, "sorry")
|
||||
lean_codes.append(code)
|
||||
is_walled = False
|
||||
is_walled_lean = False
|
||||
continue
|
||||
|
||||
if strip_imports and line.startswith("import "):
|
||||
continue
|
||||
curr.append(line)
|
||||
|
||||
return lean_codes
|
||||
|
||||
|
||||
class TestPrompts(unittest.TestCase):
|
||||
|
||||
def test_extract_lean_code(self):
|
||||
sketch = "```lean\nimport Mathlib.Data.Nat.Basic\nimport Aesop\n\ntheorem n_plus_zero : ∀ n : ℕ, n + 0 = n := by\n -- Consider any natural number n. We need to show that n + 0 = n.\n -- Use the fact that adding zero to any natural number does not change its value.\n have h_nat_add_zero: ∀ n : ℕ, n + 0 = n := <TODO_PROOF_OR_HAMMER>\n -- Combine facts to close goal\n <TODO_PROOF_OR_HAMMER>\n```"
|
||||
codes = extract_lean_code(sketch)
|
||||
self.assertEqual(codes, [
|
||||
"import Mathlib.Data.Nat.Basic\nimport Aesop\n\ntheorem n_plus_zero : ∀ n : Nat, n + 0 = n := by\n -- Consider any natural number n. We need to show that n + 0 = n.\n -- Use the fact that adding zero to any natural number does not change its value.\n have h_nat_add_zero: ∀ n : Nat, n + 0 = n := sorry\n -- Combine facts to close goal\n sorry\n"
|
||||
])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue