From 7770c0fb5995249635a7f1e3c3ba6b6301566e18 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sun, 6 Oct 2024 19:14:38 -0700 Subject: [PATCH] feat: Error feedback in DSP --- experiments/dsp/main.py | 120 +++++++++++++++++++++++-------- experiments/dsp/solve/prompts.py | 10 +++ pantograph/__init__.py | 2 +- 3 files changed, 102 insertions(+), 30 deletions(-) diff --git a/experiments/dsp/main.py b/experiments/dsp/main.py index 14f8f70..00ae10d 100644 --- a/experiments/dsp/main.py +++ b/experiments/dsp/main.py @@ -1,12 +1,13 @@ import sys, os, json, subprocess -from dataclasses import dataclass +from dataclasses import dataclass, asdict, field from pathlib import Path -from typing import Union, Any, Tuple +from typing import Union, Any, Tuple, Optional from tqdm import tqdm from openai import OpenAI import wandb from tenacity import retry, stop_after_attempt, wait_exponential -from pantograph import Server +from pantograph import Server, ServerError +from pantograph.search import SearchResult from termcolor import colored from solve.prompts import ( @@ -32,6 +33,20 @@ class SamplingParams: temperature: float stop: str +@dataclass(frozen=True) +class SketchParseFailure: + error: str + sketch: str + +@dataclass(frozen=True) +class DatumResult: + """ + Result from one DSP data point + """ + name: str + success: Optional[bool] = False + proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list) + class Engine: def __init__(self): pass @@ -170,8 +185,8 @@ def prove( eng: Engine, server: Server, fl_prob: str, - fl_sketch: list[str], - ): + fl_sketch: str, + ) -> list[SearchResult]: """ Complete formal sketch and check if it proves the theorem. @@ -182,23 +197,44 @@ def prove( """ # If this throws index out of bound errors it means the source doesn't contain walled off Lean sections. print(colored("Sketch:", "yellow"), fl_sketch) - lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch] + lean_code = "\n".join(extract_lean_code(fl_sketch)) print(colored("Lean code:", "light_grey"), lean_code) - states = server.load_sorry(lean_code) + + try: + states = server.load_sorry(lean_code) + except ServerError as e: + msg = f"Encountered exception: {e}" + print(colored(msg, "red")) + return SketchParseFailure( + sketch=fl_sketch, + error=msg, + ) if len(states) != 1: print(colored("Model must output one compilation unit", "red")) - raise NotImplemented + return SketchParseFailure( + sketch=fl_sketch, + error="Model must output one compilation unit", + ) state = states[0] if isinstance(state, list) and len(state) > 0: - print(colored("Sketch failed:", "red"), "\n".join(state)) - # what should we do? - raise NotImplemented + # This means `state` contains error messages + msg = "\n".join(state) + print(colored("Sketch failed:", "red"), msg) + return SketchParseFailure( + sketch=fl_sketch, + error=f"Sketch failed: {msg}", + ) agent = HammerAgent() - result = agent.search(server, state) + result = agent.search( + server, + state, + max_steps=1000, + max_trials_per_goal=len(agent.tactics) + 1, + ) print(colored(f"Result: {result}", "blue")) return result @@ -207,37 +243,62 @@ def prove( def single_proof_search_dsp_lean( eng: Engine, - server: Server, + server_func, datum: Datum, - ) -> bool: + ) -> DatumResult: # -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft) y_nl_pred_drafts = draft(eng, datum) # -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch) z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts) - # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) - result: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches) + assert len(z_fl_pred_sketches) == 1 - # -- Return - return result + server = server_func() + + # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) + prove_result = [prove(eng, server, x_fl_prob, sketch) for sketch in z_fl_pred_sketches] + + return DatumResult( + name=str(datum), + success=any( + x.success for x in prove_result + if isinstance(x, SearchResult) + ), + proves=prove_result, + ) def full_proof_search_dsp_lean( eng: Engine, - server: Server, + server_func, data: list[Datum], path_output: Path, ): print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"])) + n_success = 0 # -- Proof search by DSP over all eval data for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'): - print(f"Problem {i}:", colored(str(datum), "cyan")) - result = single_proof_search_dsp_lean(eng, server, datum) file_name = path_output / f"{i:03}.json" + key = str(datum) + # Detect if file exists + if file_name.is_file(): + obj = json.load(open(file_name, "r")) + if obj['name'] != key: + print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong")) + break + + print(f"Skipped {i}:", colored(key, "green")) + continue + + print(f"Problem {i}:", colored(key, "cyan")) + + result = single_proof_search_dsp_lean(eng, server_func, datum) with open(file_name, 'w') as f: - json.dump({ 'name': str(datum), 'success': result.success, 'steps': result.steps }, f) + json.dump(asdict(result), f) + if result.success: + n_success += 1 #server.gc() - return + print(f"Proved {n_success}/{len(data)} problems") experiment_dir = Path(__file__).resolve().parent @@ -279,11 +340,12 @@ def main(args): # Start server project_path, lean_path = get_project_and_lean_path() - server = Server( - imports=["Mathlib", "Aesop"], - project_path=project_path, - lean_path=lean_path, - ) + def server_func(): + return Server( + imports=["Mathlib", "Aesop"], + project_path=project_path, + lean_path=lean_path, + ) # - Start wandb run # print(f'\n\n-- Setup params') @@ -322,7 +384,7 @@ def main(args): ) # - Full proof search with DSP - full_proof_search_dsp_lean(eng, server, data_eval, path_output) + full_proof_search_dsp_lean(eng, server_func, data_eval, path_output) dt = datetime.timedelta(seconds=time.time() - start_time) print(colored(f"Time elapsed: {dt}", "magenta")) diff --git a/experiments/dsp/solve/prompts.py b/experiments/dsp/solve/prompts.py index c8a11c7..bfd5231 100644 --- a/experiments/dsp/solve/prompts.py +++ b/experiments/dsp/solve/prompts.py @@ -150,6 +150,10 @@ def extract_lean_code( strip_imports: bool = True) -> list[str]: lines = sketch.split("\n") # find backtick markers ``` + if WALL not in sketch: + # No walls found. The whole thing must be code + lines = [line for line in lines if not line.startswith("import ")] + return ["\n".join(lines)] lean_codes = [] curr = [] is_walled = False @@ -196,5 +200,11 @@ class TestPrompts(unittest.TestCase): codes = extract_lean_code(sketch) self.assertEqual(len(codes), 1) + def test_extract_sketch_no_wall(self): + payload = "example : forall (n: Prop), n -> n := sorry" + sketch = f"import Mathlib\n\n{payload}" + codes = extract_lean_code(sketch) + self.assertEqual(codes, ["\n" + payload]) + if __name__ == '__main__': unittest.main() diff --git a/pantograph/__init__.py b/pantograph/__init__.py index 78c1e35..281c2df 100644 --- a/pantograph/__init__.py +++ b/pantograph/__init__.py @@ -1 +1 @@ -from pantograph.server import Server +from pantograph.server import Server, ServerError