feat: Extract Lean code sections from sketches
This commit is contained in:
parent
f1e996baae
commit
80a356c75c
|
@ -10,6 +10,7 @@ from tenacity import retry, stop_after_attempt, wait_exponential
|
||||||
from pantograph import Server
|
from pantograph import Server
|
||||||
|
|
||||||
from solve.prompts import (
|
from solve.prompts import (
|
||||||
|
extract_lean_code,
|
||||||
SYSTEM_PROMPT_DRAFT_V0,
|
SYSTEM_PROMPT_DRAFT_V0,
|
||||||
SYSTEM_PROMPT_SKETCH_V0,
|
SYSTEM_PROMPT_SKETCH_V0,
|
||||||
prompt_draft_template_lean4_v0,
|
prompt_draft_template_lean4_v0,
|
||||||
|
@ -58,7 +59,7 @@ class OpenAI_DSP_Engine(Engine):
|
||||||
verbose_init: bool = True,
|
verbose_init: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
print(f'{api_key=}, {base_url=}') if verbose_init else None
|
print(f'{base_url=}') if verbose_init else None
|
||||||
|
|
||||||
|
|
||||||
if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model):
|
if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model):
|
||||||
|
@ -161,7 +162,8 @@ def sketch(
|
||||||
return sketches, x_fl_problem
|
return sketches, x_fl_problem
|
||||||
|
|
||||||
def prove(
|
def prove(
|
||||||
eng,
|
eng: Engine,
|
||||||
|
server: Server,
|
||||||
fl_prob: str,
|
fl_prob: str,
|
||||||
fl_sketch: list[str],
|
fl_sketch: list[str],
|
||||||
):
|
):
|
||||||
|
@ -173,8 +175,11 @@ def prove(
|
||||||
fl_sketch --> Lean4 Form Sketch --> have x have ha
|
fl_sketch --> Lean4 Form Sketch --> have x have ha
|
||||||
|
|
||||||
"""
|
"""
|
||||||
print(f"fl_prob={fl_prob}")
|
# If this raises an IndexError, the sketch does not contain a fenced ```lean section.
|
||||||
print(f"fl_sketch={fl_sketch}")
|
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
|
||||||
|
state, = server.load_sorry(lean_code)
|
||||||
|
|
||||||
|
print(state)
|
||||||
raise RuntimeError("Not implemented")
|
raise RuntimeError("Not implemented")
|
||||||
# -- Prove
|
# -- Prove
|
||||||
correct: bool = False
|
correct: bool = False
|
||||||
|
@ -195,7 +200,7 @@ def single_proof_search_dsp_lean(
|
||||||
z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts)
|
z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts)
|
||||||
|
|
||||||
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
||||||
correct: bool = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
correct: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches)
|
||||||
|
|
||||||
# -- Return
|
# -- Return
|
||||||
return correct
|
return correct
|
||||||
|
@ -212,6 +217,7 @@ def full_proof_search_dsp_lean(
|
||||||
for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
|
for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
|
||||||
print(f'{data_pt=}')
|
print(f'{data_pt=}')
|
||||||
flag = single_proof_search_dsp_lean(eng, server, data_pt)
|
flag = single_proof_search_dsp_lean(eng, server, data_pt)
|
||||||
|
server.gc()
|
||||||
return
|
return
|
||||||
|
|
||||||
experiment_dir = Path(__file__).resolve().parent
|
experiment_dir = Path(__file__).resolve().parent
|
||||||
|
|
|
@ -7,8 +7,7 @@ core part of data for prompt for dsp:
|
||||||
"src_header_fl_problem": ..., #src_header_x*_fl
|
"src_header_fl_problem": ..., #src_header_x*_fl
|
||||||
"fl_header_sketch": ..., # hz_fl suggested header
|
"fl_header_sketch": ..., # hz_fl suggested header
|
||||||
"""
|
"""
|
||||||
import json
|
import json, sys, unittest
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -19,6 +18,8 @@ experiment_dir = Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
default_path_2_examples = 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json'
|
default_path_2_examples = 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json'
|
||||||
|
|
||||||
|
TOKEN_PLACEHOLDER = "<TODO_PROOF_OR_HAMMER>"
|
||||||
|
|
||||||
# -- Prompt draft (P_draft) for Lean 4
|
# -- Prompt draft (P_draft) for Lean 4
|
||||||
"""
|
"""
|
||||||
Draft an informal solution similar to the one below.
|
Draft an informal solution similar to the one below.
|
||||||
|
@ -98,8 +99,8 @@ Formal:\n
|
||||||
SYSTEM_PROMPT_SKETCH_V0 = 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.'
|
SYSTEM_PROMPT_SKETCH_V0 = 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.'
|
||||||
STOP_TOKENS_SKETCH_V0: list[str] = ['Informal:', '(*### Problem', '###Solution', 'Formal:']
|
STOP_TOKENS_SKETCH_V0: list[str] = ['Informal:', '(*### Problem', '###Solution', 'Formal:']
|
||||||
prompt_sketch_template_lean4_v0 = ("Translate the informal solution into a sketch in the "
|
prompt_sketch_template_lean4_v0 = ("Translate the informal solution into a sketch in the "
|
||||||
"formal Lean 4 proof. Add <TODO_PROOF_OR_HAMMER> in the formal sketch whenever possible. "
|
f"formal Lean 4 proof. Add {TOKEN_PLACEHOLDER} in the formal sketch whenever possible. "
|
||||||
"<TODO_PROOF_OR_HAMMER> will be used to call a automated theorem prover or tactic in Lean 4. "
|
f"{TOKEN_PLACEHOLDER} will be used to call a automated theorem prover or tactic in Lean 4. "
|
||||||
"Here are some examples:\n"
|
"Here are some examples:\n"
|
||||||
)
|
)
|
||||||
def get_prompt_sketch_template_4_lean_v0(
|
def get_prompt_sketch_template_4_lean_v0(
|
||||||
|
@ -139,3 +140,54 @@ def get_prompt_sketch_template_4_lean_v0(
|
||||||
print(prompt_sketch_template_4_lean) if verbose else None
|
print(prompt_sketch_template_4_lean) if verbose else None
|
||||||
return prompt_sketch_template_4_lean
|
return prompt_sketch_template_4_lean
|
||||||
prompt_sketch_template_lean4_v0 = get_prompt_sketch_template_4_lean_v0()
|
prompt_sketch_template_lean4_v0 = get_prompt_sketch_template_4_lean_v0()
|
||||||
|
|
||||||
|
WALL = "```"
|
||||||
|
|
||||||
|
|
||||||
|
def extract_lean_code(
        sketch: str,
        placeholder: str = TOKEN_PLACEHOLDER,
        strip_imports: bool = True) -> list[str]:
    """
    Extract the contents of every fenced ```lean section found in `sketch`.

    A Lean section opens on a line that, ignoring trailing whitespace, is
    exactly WALL followed by `lean`, and closes on the next line that,
    ignoring trailing whitespace, is exactly WALL. Any other line starting
    with WALL opens a non-Lean section whose contents are skipped. A section
    that is never closed is dropped.

    Each extracted section is post-processed:
    - every occurrence of `placeholder` is replaced by `sorry`;
    - every `ℕ` is replaced by `Nat`
      (NOTE(review): presumably because the notation is unavailable once the
      imports are stripped -- confirm);
    - if `strip_imports` is true, lines starting with `import ` are removed.

    Args:
        sketch: Text (typically a model completion) that may contain fenced
            code sections.
        placeholder: Token to rewrite to `sorry`. An empty string disables the
            rewrite.
        strip_imports: Whether to drop `import ...` lines from the Lean code.

    Returns:
        One string per closed Lean section, in order of appearance, each
        terminated by a newline. Empty if there is no closed Lean section.
    """
    lean_fence = f"{WALL}lean"
    lean_codes: list[str] = []
    curr: list[str] = []
    is_walled = False
    is_walled_lean = False
    for line in sketch.split("\n"):
        if not is_walled:
            # Outside any fence only an opening fence is of interest.
            if line.rstrip() == lean_fence:
                is_walled = True
                is_walled_lean = True
                # Each Lean section starts from a clean buffer.
                curr = []
            elif line.startswith(WALL):
                is_walled = True
                is_walled_lean = False
            continue

        if line.rstrip() == WALL:
            # Closing fence: emit the section if it was a Lean one.
            if is_walled_lean:
                code = "\n".join(curr) + "\n"
                code = code.replace("ℕ", "Nat")
                # str.replace("", x) would insert x between every character.
                if placeholder:
                    code = code.replace(placeholder, "sorry")
                lean_codes.append(code)
            # Reset so later sections do not inherit these lines.
            curr = []
            is_walled = False
            is_walled_lean = False
            continue

        if not is_walled_lean:
            # Contents of non-Lean fenced sections must not leak into the output.
            continue
        if strip_imports and line.startswith("import "):
            continue
        curr.append(line)

    return lean_codes
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrompts(unittest.TestCase):
    """Unit tests for the prompt helpers defined in this module."""

    # A completion consisting of a single fenced Lean section with two
    # imports and two proof placeholders.
    SKETCH = (
        "```lean\n"
        "import Mathlib.Data.Nat.Basic\n"
        "import Aesop\n"
        "\n"
        "theorem n_plus_zero : ∀ n : ℕ, n + 0 = n := by\n"
        " -- Consider any natural number n. We need to show that n + 0 = n.\n"
        " -- Use the fact that adding zero to any natural number does not change its value.\n"
        " have h_nat_add_zero: ∀ n : ℕ, n + 0 = n := <TODO_PROOF_OR_HAMMER>\n"
        " -- Combine facts to close goal\n"
        " <TODO_PROOF_OR_HAMMER>\n"
        "```"
    )

    # Expected Lean code once the import lines have been removed. The leading
    # newline is the blank line that separated the imports from the theorem.
    EXPECTED_WITHOUT_IMPORTS = (
        "\n"
        "theorem n_plus_zero : ∀ n : Nat, n + 0 = n := by\n"
        " -- Consider any natural number n. We need to show that n + 0 = n.\n"
        " -- Use the fact that adding zero to any natural number does not change its value.\n"
        " have h_nat_add_zero: ∀ n : Nat, n + 0 = n := sorry\n"
        " -- Combine facts to close goal\n"
        " sorry\n"
    )

    def test_extract_lean_code(self):
        # The default is strip_imports=True, so import lines must be absent.
        codes = extract_lean_code(self.SKETCH)
        self.assertEqual(codes, [self.EXPECTED_WITHOUT_IMPORTS])

    def test_extract_lean_code_keep_imports(self):
        # With strip_imports=False the imports are kept verbatim.
        codes = extract_lean_code(self.SKETCH, strip_imports=False)
        self.assertEqual(codes, [
            "import Mathlib.Data.Nat.Basic\nimport Aesop\n"
            + self.EXPECTED_WITHOUT_IMPORTS
        ])
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue