feat: Error feedback in DSP
This commit is contained in:
parent
1ecfa35e1c
commit
7770c0fb59
|
@ -1,12 +1,13 @@
|
|||
import sys, os, json, subprocess
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from pathlib import Path
|
||||
from typing import Union, Any, Tuple
|
||||
from typing import Union, Any, Tuple, Optional
|
||||
from tqdm import tqdm
|
||||
from openai import OpenAI
|
||||
import wandb
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from pantograph import Server
|
||||
from pantograph import Server, ServerError
|
||||
from pantograph.search import SearchResult
|
||||
from termcolor import colored
|
||||
|
||||
from solve.prompts import (
|
||||
|
@ -32,6 +33,20 @@ class SamplingParams:
|
|||
temperature: float
|
||||
stop: str
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SketchParseFailure:
|
||||
error: str
|
||||
sketch: str
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DatumResult:
|
||||
"""
|
||||
Result from one DSP data point
|
||||
"""
|
||||
name: str
|
||||
success: Optional[bool] = False
|
||||
proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list)
|
||||
|
||||
class Engine:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
@ -170,8 +185,8 @@ def prove(
|
|||
eng: Engine,
|
||||
server: Server,
|
||||
fl_prob: str,
|
||||
fl_sketch: list[str],
|
||||
):
|
||||
fl_sketch: str,
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
|
||||
Complete formal sketch and check if it proves the theorem.
|
||||
|
@ -182,23 +197,44 @@ def prove(
|
|||
"""
|
||||
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
||||
print(colored("Sketch:", "yellow"), fl_sketch)
|
||||
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
|
||||
lean_code = "\n".join(extract_lean_code(fl_sketch))
|
||||
print(colored("Lean code:", "light_grey"), lean_code)
|
||||
states = server.load_sorry(lean_code)
|
||||
|
||||
try:
|
||||
states = server.load_sorry(lean_code)
|
||||
except ServerError as e:
|
||||
msg = f"Encountered exception: {e}"
|
||||
print(colored(msg, "red"))
|
||||
return SketchParseFailure(
|
||||
sketch=fl_sketch,
|
||||
error=msg,
|
||||
)
|
||||
|
||||
if len(states) != 1:
|
||||
print(colored("Model must output one compilation unit", "red"))
|
||||
raise NotImplemented
|
||||
return SketchParseFailure(
|
||||
sketch=fl_sketch,
|
||||
error="Model must output one compilation unit",
|
||||
)
|
||||
|
||||
state = states[0]
|
||||
|
||||
if isinstance(state, list) and len(state) > 0:
|
||||
print(colored("Sketch failed:", "red"), "\n".join(state))
|
||||
# what should we do?
|
||||
raise NotImplemented
|
||||
# This means `state` contains error messages
|
||||
msg = "\n".join(state)
|
||||
print(colored("Sketch failed:", "red"), msg)
|
||||
return SketchParseFailure(
|
||||
sketch=fl_sketch,
|
||||
error=f"Sketch failed: {msg}",
|
||||
)
|
||||
|
||||
agent = HammerAgent()
|
||||
result = agent.search(server, state)
|
||||
result = agent.search(
|
||||
server,
|
||||
state,
|
||||
max_steps=1000,
|
||||
max_trials_per_goal=len(agent.tactics) + 1,
|
||||
)
|
||||
print(colored(f"Result: {result}", "blue"))
|
||||
|
||||
return result
|
||||
|
@ -207,37 +243,62 @@ def prove(
|
|||
|
||||
def single_proof_search_dsp_lean(
|
||||
eng: Engine,
|
||||
server: Server,
|
||||
server_func,
|
||||
datum: Datum,
|
||||
) -> bool:
|
||||
) -> DatumResult:
|
||||
# -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
|
||||
y_nl_pred_drafts = draft(eng, datum)
|
||||
|
||||
# -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
|
||||
z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts)
|
||||
|
||||
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
||||
result: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches)
|
||||
assert len(z_fl_pred_sketches) == 1
|
||||
|
||||
# -- Return
|
||||
return result
|
||||
server = server_func()
|
||||
|
||||
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
||||
prove_result = [prove(eng, server, x_fl_prob, sketch) for sketch in z_fl_pred_sketches]
|
||||
|
||||
return DatumResult(
|
||||
name=str(datum),
|
||||
success=any(
|
||||
x.success for x in prove_result
|
||||
if isinstance(x, SearchResult)
|
||||
),
|
||||
proves=prove_result,
|
||||
)
|
||||
|
||||
def full_proof_search_dsp_lean(
|
||||
eng: Engine,
|
||||
server: Server,
|
||||
server_func,
|
||||
data: list[Datum],
|
||||
path_output: Path,
|
||||
):
|
||||
print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))
|
||||
n_success = 0
|
||||
# -- Proof search by DSP over all eval data
|
||||
for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
|
||||
print(f"Problem {i}:", colored(str(datum), "cyan"))
|
||||
result = single_proof_search_dsp_lean(eng, server, datum)
|
||||
file_name = path_output / f"{i:03}.json"
|
||||
key = str(datum)
|
||||
# Detect if file exists
|
||||
if file_name.is_file():
|
||||
obj = json.load(open(file_name, "r"))
|
||||
if obj['name'] != key:
|
||||
print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong"))
|
||||
break
|
||||
|
||||
print(f"Skipped {i}:", colored(key, "green"))
|
||||
continue
|
||||
|
||||
print(f"Problem {i}:", colored(key, "cyan"))
|
||||
|
||||
result = single_proof_search_dsp_lean(eng, server_func, datum)
|
||||
with open(file_name, 'w') as f:
|
||||
json.dump({ 'name': str(datum), 'success': result.success, 'steps': result.steps }, f)
|
||||
json.dump(asdict(result), f)
|
||||
if result.success:
|
||||
n_success += 1
|
||||
#server.gc()
|
||||
return
|
||||
print(f"Proved {n_success}/{len(data)} problems")
|
||||
|
||||
|
||||
experiment_dir = Path(__file__).resolve().parent
|
||||
|
@ -279,11 +340,12 @@ def main(args):
|
|||
|
||||
# Start server
|
||||
project_path, lean_path = get_project_and_lean_path()
|
||||
server = Server(
|
||||
imports=["Mathlib", "Aesop"],
|
||||
project_path=project_path,
|
||||
lean_path=lean_path,
|
||||
)
|
||||
def server_func():
|
||||
return Server(
|
||||
imports=["Mathlib", "Aesop"],
|
||||
project_path=project_path,
|
||||
lean_path=lean_path,
|
||||
)
|
||||
|
||||
# - Start wandb run
|
||||
# print(f'\n\n-- Setup params')
|
||||
|
@ -322,7 +384,7 @@ def main(args):
|
|||
)
|
||||
|
||||
# - Full proof search with DSP
|
||||
full_proof_search_dsp_lean(eng, server, data_eval, path_output)
|
||||
full_proof_search_dsp_lean(eng, server_func, data_eval, path_output)
|
||||
|
||||
dt = datetime.timedelta(seconds=time.time() - start_time)
|
||||
print(colored(f"Time elapsed: {dt}", "magenta"))
|
||||
|
|
|
@ -150,6 +150,10 @@ def extract_lean_code(
|
|||
strip_imports: bool = True) -> list[str]:
|
||||
lines = sketch.split("\n")
|
||||
# find backtick markers ```
|
||||
if WALL not in sketch:
|
||||
# No walls found. The whole thing must be code
|
||||
lines = [line for line in lines if not line.startswith("import ")]
|
||||
return ["\n".join(lines)]
|
||||
lean_codes = []
|
||||
curr = []
|
||||
is_walled = False
|
||||
|
@ -196,5 +200,11 @@ class TestPrompts(unittest.TestCase):
|
|||
codes = extract_lean_code(sketch)
|
||||
self.assertEqual(len(codes), 1)
|
||||
|
||||
def test_extract_sketch_no_wall(self):
|
||||
payload = "example : forall (n: Prop), n -> n := sorry"
|
||||
sketch = f"import Mathlib\n\n{payload}"
|
||||
codes = extract_lean_code(sketch)
|
||||
self.assertEqual(codes, ["\n" + payload])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -1 +1 @@
|
|||
from pantograph.server import Server
|
||||
from pantograph.server import Server, ServerError
|
||||
|
|
Loading…
Reference in New Issue