feat: Error feedback in DSP

This commit is contained in:
Leni Aniva 2024-10-06 19:14:38 -07:00
parent 1ecfa35e1c
commit 7770c0fb59
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
3 changed files with 102 additions and 30 deletions

View File

@ -1,12 +1,13 @@
import sys, os, json, subprocess
from dataclasses import dataclass
from dataclasses import dataclass, asdict, field
from pathlib import Path
from typing import Union, Any, Tuple
from typing import Union, Any, Tuple, Optional
from tqdm import tqdm
from openai import OpenAI
import wandb
from tenacity import retry, stop_after_attempt, wait_exponential
from pantograph import Server
from pantograph import Server, ServerError
from pantograph.search import SearchResult
from termcolor import colored
from solve.prompts import (
@ -32,6 +33,20 @@ class SamplingParams:
temperature: float
stop: str
@dataclass(frozen=True)
class SketchParseFailure:
    """
    Record of a formal sketch that was rejected before any proof search ran.
    """
    # Human-readable description of why the sketch was rejected.
    error: str
    # The sketch text that was rejected.
    sketch: str
@dataclass(frozen=True)
class DatumResult:
    """
    Result from one DSP data point
    """
    # Identifier of the data point; the visible caller passes `str(datum)`.
    name: str
    # Whether the data point was proved. The visible caller always passes a
    # plain bool (the result of `any(...)`), so `None` is not produced there.
    success: Optional[bool] = False
    # One entry per attempted sketch: a `SearchResult` when proof search ran,
    # or a `SketchParseFailure` when the sketch was rejected beforehand.
    proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list)
class Engine:
def __init__(self):
pass
@ -170,8 +185,8 @@ def prove(
eng: Engine,
server: Server,
fl_prob: str,
fl_sketch: list[str],
):
fl_sketch: str,
) -> list[SearchResult]:
"""
Complete formal sketch and check if it proves the theorem.
@ -182,23 +197,44 @@ def prove(
"""
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
print(colored("Sketch:", "yellow"), fl_sketch)
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
lean_code = "\n".join(extract_lean_code(fl_sketch))
print(colored("Lean code:", "light_grey"), lean_code)
try:
states = server.load_sorry(lean_code)
except ServerError as e:
msg = f"Encountered exception: {e}"
print(colored(msg, "red"))
return SketchParseFailure(
sketch=fl_sketch,
error=msg,
)
if len(states) != 1:
print(colored("Model must output one compilation unit", "red"))
raise NotImplemented
return SketchParseFailure(
sketch=fl_sketch,
error="Model must output one compilation unit",
)
state = states[0]
if isinstance(state, list) and len(state) > 0:
print(colored("Sketch failed:", "red"), "\n".join(state))
# what should we do?
raise NotImplemented
# This means `state` contains error messages
msg = "\n".join(state)
print(colored("Sketch failed:", "red"), msg)
return SketchParseFailure(
sketch=fl_sketch,
error=f"Sketch failed: {msg}",
)
agent = HammerAgent()
result = agent.search(server, state)
result = agent.search(
server,
state,
max_steps=1000,
max_trials_per_goal=len(agent.tactics) + 1,
)
print(colored(f"Result: {result}", "blue"))
return result
@ -207,37 +243,62 @@ def prove(
def single_proof_search_dsp_lean(
        eng: Engine,
        server_func,
        datum: Datum,
    ) -> DatumResult:
    """
    Run one Draft-Sketch-Prove pass over a single data point.

    Parameters:
        eng: engine handed to the draft, sketch and prove stages.
        server_func: zero-argument callable returning a fresh server; it is
            called once per data point, after drafting and sketching.
        datum: the data point to work on.

    Returns:
        A `DatumResult` named `str(datum)` whose `proves` list holds the
        outcome of `prove` for every sketch, and whose `success` is true when
        at least one of those outcomes is a successful `SearchResult`.

    Raises:
        AssertionError: if the sketch stage does not yield exactly one sketch.
    """
    # -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
    y_nl_pred_drafts = draft(eng, datum)

    # -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
    z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts)
    assert len(z_fl_pred_sketches) == 1

    server = server_func()

    # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
    # The loop variable is deliberately not called `sketch`, so it cannot be
    # confused with the `sketch()` stage function used above.
    prove_result = [
        prove(eng, server, x_fl_prob, fl_sketch)
        for fl_sketch in z_fl_pred_sketches
    ]
    return DatumResult(
        name=str(datum),
        # Parse failures carry no `success` flag, so only real search results
        # are consulted.
        success=any(
            x.success for x in prove_result
            if isinstance(x, SearchResult)
        ),
        proves=prove_result,
    )
def full_proof_search_dsp_lean(
        eng: Engine,
        server_func,
        data: list[Datum],
        path_output: Path,
    ):
    """
    Run DSP proof search over every data point, writing one JSON file each.

    Results go to `path_output / f"{i:03}.json"`. A data point whose file
    already exists is skipped, which lets an interrupted run be resumed. If
    an existing file was written for a different datum name, the loop stops,
    since the output directory probably belongs to another dataset.

    Parameters:
        eng: engine forwarded to `single_proof_search_dsp_lean`.
        server_func: zero-argument server factory, forwarded unchanged.
        data: the data points to process, in order.
        path_output: directory receiving the per-datum JSON files.
    """
    print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))

    n_success = 0
    # -- Proof search by DSP over all eval data
    for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
        file_name = path_output / f"{i:03}.json"
        key = str(datum)

        # Detect if file exists
        if file_name.is_file():
            # Read through a context manager so the handle is always closed.
            with open(file_name, "r") as f:
                obj = json.load(f)
            if obj['name'] != key:
                print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong", "red"))
                break
            # Count successes stored by an earlier run; otherwise the final
            # tally (reported against len(data)) under-counts after a resume.
            if obj.get('success'):
                n_success += 1
            print(f"Skipped {i}:", colored(key, "green"))
            continue

        print(f"Problem {i}:", colored(key, "cyan"))
        result = single_proof_search_dsp_lean(eng, server_func, datum)
        # NOTE(review): assumes everything inside `result.proves` is
        # JSON-serialisable once passed through `asdict` - confirm for
        # `SearchResult`.
        with open(file_name, 'w') as f:
            json.dump(asdict(result), f)
        if result.success:
            n_success += 1

    print(f"Proved {n_success}/{len(data)} problems")
# Absolute path of the directory that contains this script file.
experiment_dir = Path(__file__).resolve().parent
@ -279,7 +340,8 @@ def main(args):
# Start server
project_path, lean_path = get_project_and_lean_path()
server = Server(
def server_func():
return Server(
imports=["Mathlib", "Aesop"],
project_path=project_path,
lean_path=lean_path,
@ -322,7 +384,7 @@ def main(args):
)
# - Full proof search with DSP
full_proof_search_dsp_lean(eng, server, data_eval, path_output)
full_proof_search_dsp_lean(eng, server_func, data_eval, path_output)
dt = datetime.timedelta(seconds=time.time() - start_time)
print(colored(f"Time elapsed: {dt}", "magenta"))

View File

@ -150,6 +150,10 @@ def extract_lean_code(
strip_imports: bool = True) -> list[str]:
lines = sketch.split("\n")
# find backtick markers ```
if WALL not in sketch:
# No walls found. The whole thing must be code
lines = [line for line in lines if not line.startswith("import ")]
return ["\n".join(lines)]
lean_codes = []
curr = []
is_walled = False
@ -196,5 +200,11 @@ class TestPrompts(unittest.TestCase):
codes = extract_lean_code(sketch)
self.assertEqual(len(codes), 1)
def test_extract_sketch_no_wall(self):
    """
    A sketch without code fences is returned as a single code unit with
    its `import` line removed.
    """
    payload = "example : forall (n: Prop), n -> n := sorry"
    extracted = extract_lean_code(f"import Mathlib\n\n{payload}")
    # The import line is dropped; the blank line after it is kept, which is
    # why the expected text starts with a newline.
    expected = ["\n" + payload]
    self.assertEqual(extracted, expected)
# Run this module's unit tests when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

View File

@ -1 +1 @@
from pantograph.server import Server
from pantograph.server import Server, ServerError