feat: Error feedback in DSP

This commit is contained in:
Leni Aniva 2024-10-06 19:14:38 -07:00
parent 1ecfa35e1c
commit 7770c0fb59
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
3 changed files with 102 additions and 30 deletions

View File

@ -1,12 +1,13 @@
import sys, os, json, subprocess import sys, os, json, subprocess
from dataclasses import dataclass from dataclasses import dataclass, asdict, field
from pathlib import Path from pathlib import Path
from typing import Union, Any, Tuple from typing import Union, Any, Tuple, Optional
from tqdm import tqdm from tqdm import tqdm
from openai import OpenAI from openai import OpenAI
import wandb import wandb
from tenacity import retry, stop_after_attempt, wait_exponential from tenacity import retry, stop_after_attempt, wait_exponential
from pantograph import Server from pantograph import Server, ServerError
from pantograph.search import SearchResult
from termcolor import colored from termcolor import colored
from solve.prompts import ( from solve.prompts import (
@ -32,6 +33,20 @@ class SamplingParams:
temperature: float temperature: float
stop: str stop: str
@dataclass(frozen=True)
class SketchParseFailure:
    """
    Records a sketch that could not be turned into a single proof state.

    `prove` returns this instead of a search result when loading the sketch
    raises a `ServerError`, when the sketch does not yield exactly one
    compilation unit, or when the loaded state consists of error messages.
    """
    # Human-readable description of what went wrong
    error: str
    # The sketch text that failed, kept so the failure can be inspected later
    sketch: str
@dataclass(frozen=True)
class DatumResult:
    """
    Result from one DSP data point
    """
    # Identifier of the data point; callers in this file pass `str(datum)`
    name: str
    # Whether any attempt in `proves` succeeded
    success: Optional[bool] = False
    # One entry per attempted sketch: either the search outcome or the reason
    # the sketch could not be loaded
    proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list)
class Engine: class Engine:
def __init__(self): def __init__(self):
pass pass
@ -170,8 +185,8 @@ def prove(
eng: Engine, eng: Engine,
server: Server, server: Server,
fl_prob: str, fl_prob: str,
fl_sketch: list[str], fl_sketch: str,
): ) -> list[SearchResult]:
""" """
Complete formal sketch and check if it proves the theorem. Complete formal sketch and check if it proves the theorem.
@ -182,23 +197,44 @@ def prove(
""" """
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections. # If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
print(colored("Sketch:", "yellow"), fl_sketch) print(colored("Sketch:", "yellow"), fl_sketch)
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch] lean_code = "\n".join(extract_lean_code(fl_sketch))
print(colored("Lean code:", "light_grey"), lean_code) print(colored("Lean code:", "light_grey"), lean_code)
states = server.load_sorry(lean_code)
try:
states = server.load_sorry(lean_code)
except ServerError as e:
msg = f"Encountered exception: {e}"
print(colored(msg, "red"))
return SketchParseFailure(
sketch=fl_sketch,
error=msg,
)
if len(states) != 1: if len(states) != 1:
print(colored("Model must output one compilation unit", "red")) print(colored("Model must output one compilation unit", "red"))
raise NotImplemented return SketchParseFailure(
sketch=fl_sketch,
error="Model must output one compilation unit",
)
state = states[0] state = states[0]
if isinstance(state, list) and len(state) > 0: if isinstance(state, list) and len(state) > 0:
print(colored("Sketch failed:", "red"), "\n".join(state)) # This means `state` contains error messages
# what should we do? msg = "\n".join(state)
raise NotImplemented print(colored("Sketch failed:", "red"), msg)
return SketchParseFailure(
sketch=fl_sketch,
error=f"Sketch failed: {msg}",
)
agent = HammerAgent() agent = HammerAgent()
result = agent.search(server, state) result = agent.search(
server,
state,
max_steps=1000,
max_trials_per_goal=len(agent.tactics) + 1,
)
print(colored(f"Result: {result}", "blue")) print(colored(f"Result: {result}", "blue"))
return result return result
@ -207,37 +243,62 @@ def prove(
def single_proof_search_dsp_lean(
    eng: Engine,
    server_func,
    datum: Datum,
) -> DatumResult:
    """
    Run the draft -> sketch -> prove pipeline on one data point.

    :param eng: Engine used to generate the drafts and the sketch.
    :param server_func: Zero-argument callable returning a fresh `Server`;
        it is invoked once per call, after sketching has finished.
    :param datum: The problem to attempt.
    :return: A `DatumResult` whose `success` is true when at least one
        prove attempt produced a successful `SearchResult`.
    """
    # -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
    y_nl_pred_drafts = draft(eng, datum)

    # -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
    z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts)

    assert len(z_fl_pred_sketches) == 1

    server = server_func()

    # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
    # The loop variable is deliberately not called `sketch`, to avoid
    # shadowing the `sketch()` function used above.
    prove_result = [
        prove(eng, server, x_fl_prob, fl_sketch)
        for fl_sketch in z_fl_pred_sketches
    ]

    return DatumResult(
        name=str(datum),
        # Parse failures carry no `success` attribute, so only genuine
        # search results are consulted.
        success=any(
            x.success for x in prove_result
            if isinstance(x, SearchResult)
        ),
        proves=prove_result,
    )
def full_proof_search_dsp_lean(
    eng: Engine,
    server_func,
    data: list[Datum],
    path_output: Path,
):
    """
    Run DSP proof search over every data point, writing one JSON file each.

    Results are written to `path_output / "{i:03}.json"`. If that file already
    exists the data point is skipped, which lets an interrupted run resume.
    If an existing file names a different datum, the loop stops, since the
    output directory probably belongs to another dataset.

    :param eng: Engine forwarded to `single_proof_search_dsp_lean`.
    :param server_func: Zero-argument callable returning a `Server`,
        forwarded to `single_proof_search_dsp_lean`.
    :param data: Data points to attempt, in order.
    :param path_output: Directory holding the per-problem JSON results.
    """
    print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))
    n_success = 0
    # -- Proof search by DSP over all eval data
    for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
        file_name = path_output / f"{i:03}.json"
        key = str(datum)
        # Detect if file exists
        if file_name.is_file():
            # Use a context manager so the handle is closed promptly
            with open(file_name, "r") as f:
                obj = json.load(f)
            if obj['name'] != key:
                print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong", "red"))
                break
            # Count cached successes too, otherwise the final tally
            # undercounts whenever a run is resumed.
            if obj.get('success'):
                n_success += 1
            print(f"Skipped {i}:", colored(key, "green"))
            continue

        print(f"Problem {i}:", colored(key, "cyan"))

        result = single_proof_search_dsp_lean(eng, server_func, datum)
        with open(file_name, 'w') as f:
            json.dump(asdict(result), f)
        if result.success:
            n_success += 1

    print(f"Proved {n_success}/{len(data)} problems")
experiment_dir = Path(__file__).resolve().parent experiment_dir = Path(__file__).resolve().parent
@ -279,11 +340,12 @@ def main(args):
# Start server # Start server
project_path, lean_path = get_project_and_lean_path() project_path, lean_path = get_project_and_lean_path()
server = Server( def server_func():
imports=["Mathlib", "Aesop"], return Server(
project_path=project_path, imports=["Mathlib", "Aesop"],
lean_path=lean_path, project_path=project_path,
) lean_path=lean_path,
)
# - Start wandb run # - Start wandb run
# print(f'\n\n-- Setup params') # print(f'\n\n-- Setup params')
@ -322,7 +384,7 @@ def main(args):
) )
# - Full proof search with DSP # - Full proof search with DSP
full_proof_search_dsp_lean(eng, server, data_eval, path_output) full_proof_search_dsp_lean(eng, server_func, data_eval, path_output)
dt = datetime.timedelta(seconds=time.time() - start_time) dt = datetime.timedelta(seconds=time.time() - start_time)
print(colored(f"Time elapsed: {dt}", "magenta")) print(colored(f"Time elapsed: {dt}", "magenta"))

View File

@ -150,6 +150,10 @@ def extract_lean_code(
strip_imports: bool = True) -> list[str]: strip_imports: bool = True) -> list[str]:
lines = sketch.split("\n") lines = sketch.split("\n")
# find backtick markers ``` # find backtick markers ```
if WALL not in sketch:
# No walls found. The whole thing must be code
lines = [line for line in lines if not line.startswith("import ")]
return ["\n".join(lines)]
lean_codes = [] lean_codes = []
curr = [] curr = []
is_walled = False is_walled = False
@ -196,5 +200,11 @@ class TestPrompts(unittest.TestCase):
codes = extract_lean_code(sketch) codes = extract_lean_code(sketch)
self.assertEqual(len(codes), 1) self.assertEqual(len(codes), 1)
def test_extract_sketch_no_wall(self):
    """A sketch without code walls is treated as code, minus import lines."""
    payload = "example : forall (n: Prop), n -> n := sorry"
    sketch = f"import Mathlib\n\n{payload}"
    codes = extract_lean_code(sketch)
    # Dropping the import line leaves the blank separator line in place,
    # hence the leading newline in the expected value.
    self.assertEqual(codes, ["\n" + payload])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1 +1 @@
from pantograph.server import Server from pantograph.server import Server, ServerError