feat: Extract Lean code sections from sketches

This commit is contained in:
Leni Aniva 2024-10-03 12:26:42 -07:00
parent f1e996baae
commit 80a356c75c
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
2 changed files with 67 additions and 9 deletions

View File

@ -10,6 +10,7 @@ from tenacity import retry, stop_after_attempt, wait_exponential
from pantograph import Server from pantograph import Server
from solve.prompts import ( from solve.prompts import (
extract_lean_code,
SYSTEM_PROMPT_DRAFT_V0, SYSTEM_PROMPT_DRAFT_V0,
SYSTEM_PROMPT_SKETCH_V0, SYSTEM_PROMPT_SKETCH_V0,
prompt_draft_template_lean4_v0, prompt_draft_template_lean4_v0,
@ -58,7 +59,7 @@ class OpenAI_DSP_Engine(Engine):
verbose_init: bool = True, verbose_init: bool = True,
): ):
super().__init__() super().__init__()
print(f'{api_key=}, {base_url=}') if verbose_init else None print(f'{base_url=}') if verbose_init else None
if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model): if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model):
@ -161,7 +162,8 @@ def sketch(
return sketches, x_fl_problem return sketches, x_fl_problem
def prove( def prove(
eng, eng: Engine,
server: Server,
fl_prob: str, fl_prob: str,
fl_sketch: list[str], fl_sketch: list[str],
): ):
@ -173,8 +175,11 @@ def prove(
fl_sketch --> Lean4 Form Sketch --> have x have ha fl_sketch --> Lean4 Form Sketch --> have x have ha
""" """
print(f"fl_prob={fl_prob}") # If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
print(f"fl_sketch={fl_sketch}") lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
state, = server.load_sorry(lean_code)
print(state)
raise RuntimeError("Not implemented") raise RuntimeError("Not implemented")
# -- Prove # -- Prove
correct: bool = False correct: bool = False
@ -195,7 +200,7 @@ def single_proof_search_dsp_lean(
z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts) z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts)
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
correct: bool = prove(eng, x_fl_prob, z_fl_pred_sketches) correct: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches)
# -- Return # -- Return
return correct return correct
@ -212,6 +217,7 @@ def full_proof_search_dsp_lean(
for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'): for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
print(f'{data_pt=}') print(f'{data_pt=}')
flag = single_proof_search_dsp_lean(eng, server, data_pt) flag = single_proof_search_dsp_lean(eng, server, data_pt)
server.gc()
return return
experiment_dir = Path(__file__).resolve().parent experiment_dir = Path(__file__).resolve().parent

View File

@ -7,8 +7,7 @@ core part of data for prompt for dsp:
"src_header_fl_problem": ..., #src_header_x*_fl "src_header_fl_problem": ..., #src_header_x*_fl
"fl_header_sketch": ..., # hz_fl suggested header "fl_header_sketch": ..., # hz_fl suggested header
""" """
import json import json, sys, unittest
import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -19,6 +18,8 @@ experiment_dir = Path(__file__).resolve().parent.parent
default_path_2_examples = 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json' default_path_2_examples = 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json'
TOKEN_PLACEHOLDER = "<TODO_PROOF_OR_HAMMER>"
# -- Prompt draft (P_draft) for Lean 4 # -- Prompt draft (P_draft) for Lean 4
""" """
Draft an informal solution similar to the one below. Draft an informal solution similar to the one below.
@ -98,8 +99,8 @@ Formal:\n
SYSTEM_PROMPT_SKETCH_V0 = 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.' SYSTEM_PROMPT_SKETCH_V0 = 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.'
STOP_TOKENS_SKETCH_V0: list[str] = ['Informal:', '(*### Problem', '###Solution', 'Formal:'] STOP_TOKENS_SKETCH_V0: list[str] = ['Informal:', '(*### Problem', '###Solution', 'Formal:']
prompt_sketch_template_lean4_v0 = ("Translate the informal solution into a sketch in the " prompt_sketch_template_lean4_v0 = ("Translate the informal solution into a sketch in the "
"formal Lean 4 proof. Add <TODO_PROOF_OR_HAMMER> in the formal sketch whenever possible. " f"formal Lean 4 proof. Add {TOKEN_PLACEHOLDER} in the formal sketch whenever possible. "
"<TODO_PROOF_OR_HAMMER> will be used to call a automated theorem prover or tactic in Lean 4. " f"{TOKEN_PLACEHOLDER} will be used to call a automated theorem prover or tactic in Lean 4. "
"Here are some examples:\n" "Here are some examples:\n"
) )
def get_prompt_sketch_template_4_lean_v0( def get_prompt_sketch_template_4_lean_v0(
@ -139,3 +140,54 @@ def get_prompt_sketch_template_4_lean_v0(
print(prompt_sketch_template_4_lean) if verbose else None print(prompt_sketch_template_4_lean) if verbose else None
return prompt_sketch_template_4_lean return prompt_sketch_template_4_lean
prompt_sketch_template_lean4_v0 = get_prompt_sketch_template_4_lean_v0() prompt_sketch_template_lean4_v0 = get_prompt_sketch_template_4_lean_v0()
# Markdown code-fence marker used to delimit code sections in a sketch.
WALL = "```"


def extract_lean_code(
        sketch: str,
        placeholder: str = TOKEN_PLACEHOLDER,
        strip_imports: bool = True) -> list[str]:
    """
    Extract the Lean code sections from a sketch.

    A Lean section opens with a line that is exactly ``` followed by `lean`
    (trailing whitespace ignored) and closes with a line that is exactly ```.
    Fenced sections opened by any other line starting with ``` are skipped.
    A section that is never closed is not returned.

    Each extracted section is post-processed:
    - the character U+2115 (double-struck N) is replaced by `Nat`,
    - every occurrence of `placeholder` is replaced by `sorry`,
    - lines starting with `import ` are dropped when `strip_imports` is true.

    :param sketch: Text which may contain fenced Lean code sections.
    :param placeholder: Token to be replaced by `sorry`.
    :param strip_imports: Whether to drop `import ...` lines from the code.
    :return: One string per closed Lean section, each ending with a newline.
    """
    lean_codes: list[str] = []
    curr: list[str] = []
    is_walled = False
    is_walled_lean = False
    for line in sketch.split("\n"):
        stripped = line.rstrip()
        if not is_walled:
            if line.startswith(WALL):
                # Opening fence: only an exact "```lean" marks a Lean section.
                is_walled = True
                is_walled_lean = stripped == f"{WALL}lean"
                # Start every section with a clean buffer so that earlier
                # sections never leak into this one.
                curr = []
            continue

        if stripped == WALL:
            # Closing fence
            if is_walled_lean:
                code = "\n".join(curr) + "\n"
                # "\u2115" is the double-struck N. Replacing an empty string
                # here would insert `Nat` between every character.
                code = code.replace("\u2115", "Nat").replace(placeholder, "sorry")
                lean_codes.append(code)
            is_walled = False
            is_walled_lean = False
            continue

        if not is_walled_lean:
            # Contents of non-Lean fenced sections are discarded.
            continue
        if strip_imports and line.startswith("import "):
            continue
        curr.append(line)

    return lean_codes
class TestPrompts(unittest.TestCase):
    """Unit tests for the prompt helpers defined in this module."""

    def test_extract_lean_code(self):
        """A fenced Lean section is extracted and normalised."""
        imports = "import Mathlib.Data.Nat.Basic\nimport Aesop\n"
        # "\u2115" is the double-struck N which must be rewritten to `Nat`.
        sketch = (
            "```lean\n"
            + imports
            + "\ntheorem n_plus_zero : ∀ n : \u2115, n + 0 = n := by\n"
            " -- Consider any natural number n. We need to show that n + 0 = n.\n"
            " -- Use the fact that adding zero to any natural number does not change its value.\n"
            " have h_nat_add_zero: ∀ n : \u2115, n + 0 = n := <TODO_PROOF_OR_HAMMER>\n"
            " -- Combine facts to close goal\n"
            " <TODO_PROOF_OR_HAMMER>\n"
            "```"
        )
        body = (
            "\ntheorem n_plus_zero : ∀ n : Nat, n + 0 = n := by\n"
            " -- Consider any natural number n. We need to show that n + 0 = n.\n"
            " -- Use the fact that adding zero to any natural number does not change its value.\n"
            " have h_nat_add_zero: ∀ n : Nat, n + 0 = n := sorry\n"
            " -- Combine facts to close goal\n"
            " sorry\n"
        )
        # Imports are stripped by default.
        self.assertEqual(extract_lean_code(sketch), [body])
        # Imports are kept when stripping is disabled.
        self.assertEqual(
            extract_lean_code(sketch, strip_imports=False),
            [imports + body],
        )


if __name__ == '__main__':
    unittest.main()