feat: Error feedback in DSP
This commit is contained in:
parent
1ecfa35e1c
commit
7770c0fb59
|
@ -1,12 +1,13 @@
|
||||||
import sys, os, json, subprocess
|
import sys, os, json, subprocess
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, asdict, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, Any, Tuple
|
from typing import Union, Any, Tuple, Optional
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import wandb
|
import wandb
|
||||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||||
from pantograph import Server
|
from pantograph import Server, ServerError
|
||||||
|
from pantograph.search import SearchResult
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from solve.prompts import (
|
from solve.prompts import (
|
||||||
|
@ -32,6 +33,20 @@ class SamplingParams:
|
||||||
temperature: float
|
temperature: float
|
||||||
stop: str
|
stop: str
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SketchParseFailure:
|
||||||
|
error: str
|
||||||
|
sketch: str
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DatumResult:
|
||||||
|
"""
|
||||||
|
Result from one DSP data point
|
||||||
|
"""
|
||||||
|
name: str
|
||||||
|
success: Optional[bool] = False
|
||||||
|
proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list)
|
||||||
|
|
||||||
class Engine:
|
class Engine:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
@ -170,8 +185,8 @@ def prove(
|
||||||
eng: Engine,
|
eng: Engine,
|
||||||
server: Server,
|
server: Server,
|
||||||
fl_prob: str,
|
fl_prob: str,
|
||||||
fl_sketch: list[str],
|
fl_sketch: str,
|
||||||
):
|
) -> list[SearchResult]:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Complete formal sketch and check if it proves the theorem.
|
Complete formal sketch and check if it proves the theorem.
|
||||||
|
@ -182,23 +197,44 @@ def prove(
|
||||||
"""
|
"""
|
||||||
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
||||||
print(colored("Sketch:", "yellow"), fl_sketch)
|
print(colored("Sketch:", "yellow"), fl_sketch)
|
||||||
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
|
lean_code = "\n".join(extract_lean_code(fl_sketch))
|
||||||
print(colored("Lean code:", "light_grey"), lean_code)
|
print(colored("Lean code:", "light_grey"), lean_code)
|
||||||
states = server.load_sorry(lean_code)
|
|
||||||
|
try:
|
||||||
|
states = server.load_sorry(lean_code)
|
||||||
|
except ServerError as e:
|
||||||
|
msg = f"Encountered exception: {e}"
|
||||||
|
print(colored(msg, "red"))
|
||||||
|
return SketchParseFailure(
|
||||||
|
sketch=fl_sketch,
|
||||||
|
error=msg,
|
||||||
|
)
|
||||||
|
|
||||||
if len(states) != 1:
|
if len(states) != 1:
|
||||||
print(colored("Model must output one compilation unit", "red"))
|
print(colored("Model must output one compilation unit", "red"))
|
||||||
raise NotImplemented
|
return SketchParseFailure(
|
||||||
|
sketch=fl_sketch,
|
||||||
|
error="Model must output one compilation unit",
|
||||||
|
)
|
||||||
|
|
||||||
state = states[0]
|
state = states[0]
|
||||||
|
|
||||||
if isinstance(state, list) and len(state) > 0:
|
if isinstance(state, list) and len(state) > 0:
|
||||||
print(colored("Sketch failed:", "red"), "\n".join(state))
|
# This means `state` contains error messages
|
||||||
# what should we do?
|
msg = "\n".join(state)
|
||||||
raise NotImplemented
|
print(colored("Sketch failed:", "red"), msg)
|
||||||
|
return SketchParseFailure(
|
||||||
|
sketch=fl_sketch,
|
||||||
|
error=f"Sketch failed: {msg}",
|
||||||
|
)
|
||||||
|
|
||||||
agent = HammerAgent()
|
agent = HammerAgent()
|
||||||
result = agent.search(server, state)
|
result = agent.search(
|
||||||
|
server,
|
||||||
|
state,
|
||||||
|
max_steps=1000,
|
||||||
|
max_trials_per_goal=len(agent.tactics) + 1,
|
||||||
|
)
|
||||||
print(colored(f"Result: {result}", "blue"))
|
print(colored(f"Result: {result}", "blue"))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
@ -207,37 +243,62 @@ def prove(
|
||||||
|
|
||||||
def single_proof_search_dsp_lean(
|
def single_proof_search_dsp_lean(
|
||||||
eng: Engine,
|
eng: Engine,
|
||||||
server: Server,
|
server_func,
|
||||||
datum: Datum,
|
datum: Datum,
|
||||||
) -> bool:
|
) -> DatumResult:
|
||||||
# -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
|
# -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
|
||||||
y_nl_pred_drafts = draft(eng, datum)
|
y_nl_pred_drafts = draft(eng, datum)
|
||||||
|
|
||||||
# -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
|
# -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
|
||||||
z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts)
|
z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts)
|
||||||
|
|
||||||
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
assert len(z_fl_pred_sketches) == 1
|
||||||
result: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches)
|
|
||||||
|
|
||||||
# -- Return
|
server = server_func()
|
||||||
return result
|
|
||||||
|
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
||||||
|
prove_result = [prove(eng, server, x_fl_prob, sketch) for sketch in z_fl_pred_sketches]
|
||||||
|
|
||||||
|
return DatumResult(
|
||||||
|
name=str(datum),
|
||||||
|
success=any(
|
||||||
|
x.success for x in prove_result
|
||||||
|
if isinstance(x, SearchResult)
|
||||||
|
),
|
||||||
|
proves=prove_result,
|
||||||
|
)
|
||||||
|
|
||||||
def full_proof_search_dsp_lean(
|
def full_proof_search_dsp_lean(
|
||||||
eng: Engine,
|
eng: Engine,
|
||||||
server: Server,
|
server_func,
|
||||||
data: list[Datum],
|
data: list[Datum],
|
||||||
path_output: Path,
|
path_output: Path,
|
||||||
):
|
):
|
||||||
print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))
|
print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))
|
||||||
|
n_success = 0
|
||||||
# -- Proof search by DSP over all eval data
|
# -- Proof search by DSP over all eval data
|
||||||
for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
|
for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
|
||||||
print(f"Problem {i}:", colored(str(datum), "cyan"))
|
|
||||||
result = single_proof_search_dsp_lean(eng, server, datum)
|
|
||||||
file_name = path_output / f"{i:03}.json"
|
file_name = path_output / f"{i:03}.json"
|
||||||
|
key = str(datum)
|
||||||
|
# Detect if file exists
|
||||||
|
if file_name.is_file():
|
||||||
|
obj = json.load(open(file_name, "r"))
|
||||||
|
if obj['name'] != key:
|
||||||
|
print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong"))
|
||||||
|
break
|
||||||
|
|
||||||
|
print(f"Skipped {i}:", colored(key, "green"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Problem {i}:", colored(key, "cyan"))
|
||||||
|
|
||||||
|
result = single_proof_search_dsp_lean(eng, server_func, datum)
|
||||||
with open(file_name, 'w') as f:
|
with open(file_name, 'w') as f:
|
||||||
json.dump({ 'name': str(datum), 'success': result.success, 'steps': result.steps }, f)
|
json.dump(asdict(result), f)
|
||||||
|
if result.success:
|
||||||
|
n_success += 1
|
||||||
#server.gc()
|
#server.gc()
|
||||||
return
|
print(f"Proved {n_success}/{len(data)} problems")
|
||||||
|
|
||||||
|
|
||||||
experiment_dir = Path(__file__).resolve().parent
|
experiment_dir = Path(__file__).resolve().parent
|
||||||
|
@ -279,11 +340,12 @@ def main(args):
|
||||||
|
|
||||||
# Start server
|
# Start server
|
||||||
project_path, lean_path = get_project_and_lean_path()
|
project_path, lean_path = get_project_and_lean_path()
|
||||||
server = Server(
|
def server_func():
|
||||||
imports=["Mathlib", "Aesop"],
|
return Server(
|
||||||
project_path=project_path,
|
imports=["Mathlib", "Aesop"],
|
||||||
lean_path=lean_path,
|
project_path=project_path,
|
||||||
)
|
lean_path=lean_path,
|
||||||
|
)
|
||||||
|
|
||||||
# - Start wandb run
|
# - Start wandb run
|
||||||
# print(f'\n\n-- Setup params')
|
# print(f'\n\n-- Setup params')
|
||||||
|
@ -322,7 +384,7 @@ def main(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
# - Full proof search with DSP
|
# - Full proof search with DSP
|
||||||
full_proof_search_dsp_lean(eng, server, data_eval, path_output)
|
full_proof_search_dsp_lean(eng, server_func, data_eval, path_output)
|
||||||
|
|
||||||
dt = datetime.timedelta(seconds=time.time() - start_time)
|
dt = datetime.timedelta(seconds=time.time() - start_time)
|
||||||
print(colored(f"Time elapsed: {dt}", "magenta"))
|
print(colored(f"Time elapsed: {dt}", "magenta"))
|
||||||
|
|
|
@ -150,6 +150,10 @@ def extract_lean_code(
|
||||||
strip_imports: bool = True) -> list[str]:
|
strip_imports: bool = True) -> list[str]:
|
||||||
lines = sketch.split("\n")
|
lines = sketch.split("\n")
|
||||||
# find backtick markers ```
|
# find backtick markers ```
|
||||||
|
if WALL not in sketch:
|
||||||
|
# No walls found. The whole thing must be code
|
||||||
|
lines = [line for line in lines if not line.startswith("import ")]
|
||||||
|
return ["\n".join(lines)]
|
||||||
lean_codes = []
|
lean_codes = []
|
||||||
curr = []
|
curr = []
|
||||||
is_walled = False
|
is_walled = False
|
||||||
|
@ -196,5 +200,11 @@ class TestPrompts(unittest.TestCase):
|
||||||
codes = extract_lean_code(sketch)
|
codes = extract_lean_code(sketch)
|
||||||
self.assertEqual(len(codes), 1)
|
self.assertEqual(len(codes), 1)
|
||||||
|
|
||||||
|
def test_extract_sketch_no_wall(self):
|
||||||
|
payload = "example : forall (n: Prop), n -> n := sorry"
|
||||||
|
sketch = f"import Mathlib\n\n{payload}"
|
||||||
|
codes = extract_lean_code(sketch)
|
||||||
|
self.assertEqual(codes, ["\n" + payload])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
from pantograph.server import Server
|
from pantograph.server import Server, ServerError
|
||||||
|
|
Loading…
Reference in New Issue