feat: Hammer agent for DSP, diagnostics
This commit is contained in:
parent
80a356c75c
commit
9fd930380d
|
@ -1,13 +1,13 @@
|
|||
import sys, os, json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Union, Any
|
||||
from collections import namedtuple
|
||||
from typing import Union, Any, Tuple
|
||||
from tqdm import tqdm
|
||||
from openai import OpenAI
|
||||
import wandb
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from pantograph import Server
|
||||
from termcolor import colored
|
||||
|
||||
from solve.prompts import (
|
||||
extract_lean_code,
|
||||
|
@ -19,6 +19,7 @@ from solve.prompts import (
|
|||
STOP_TOKENS_SKETCH_V0,
|
||||
get_prompt_sketch_template_4_lean_v0,
|
||||
)
|
||||
from solve.prove import HammerAgent
|
||||
|
||||
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
|
||||
|
||||
|
@ -122,12 +123,12 @@ def draft(
|
|||
|
||||
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
||||
def sketch(
|
||||
eng,
|
||||
data_pt: dict,
|
||||
drafts: list,
|
||||
autoformalize_prob_in_prompt: bool = False,
|
||||
verbose: bool = False,
|
||||
) -> list:
|
||||
eng,
|
||||
data_pt: dict,
|
||||
drafts: list,
|
||||
autoformalize_prob_in_prompt: bool = False,
|
||||
verbose: bool = False,
|
||||
) -> Tuple[list[str], str]:
|
||||
"""
|
||||
Creates (formal fl) sketch (fl proof sketch) for latter use in a formal proof sketch.
|
||||
z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch)
|
||||
|
@ -178,13 +179,12 @@ def prove(
|
|||
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
||||
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
|
||||
state, = server.load_sorry(lean_code)
|
||||
agent = HammerAgent()
|
||||
result = agent.search(server, state, verbose=True)
|
||||
print(colored(f"Result: {result}", "blue"))
|
||||
|
||||
print(state)
|
||||
raise RuntimeError("Not implemented")
|
||||
# -- Prove
|
||||
correct: bool = False
|
||||
# -- Return
|
||||
return correct
|
||||
return
|
||||
|
||||
# -- DSP for Lean
|
||||
|
||||
|
@ -200,10 +200,10 @@ def single_proof_search_dsp_lean(
|
|||
z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts)
|
||||
|
||||
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
||||
correct: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches)
|
||||
result: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches)
|
||||
|
||||
# -- Return
|
||||
return correct
|
||||
return result
|
||||
|
||||
def full_proof_search_dsp_lean(
|
||||
eng: Engine,
|
||||
|
@ -215,6 +215,7 @@ def full_proof_search_dsp_lean(
|
|||
print(f'{len(eval_dataset)=}')
|
||||
# -- Proof search by DSP over all eval data
|
||||
for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
|
||||
print("Problem:", colored(data_pt["nl_problem"][0], "green", attrs=["underline"]))
|
||||
print(f'{data_pt=}')
|
||||
flag = single_proof_search_dsp_lean(eng, server, data_pt)
|
||||
server.gc()
|
||||
|
@ -271,7 +272,8 @@ def main(args):
|
|||
# - Full proof search with DSP
|
||||
print(f'\n\n-- Full proof search with DSP')
|
||||
full_proof_search_dsp_lean(eng, server, path_2_eval_dataset)
|
||||
print(f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a")
|
||||
msg = f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a"
|
||||
print(colored(msg, "magenta"))
|
||||
|
||||
# - End run
|
||||
# wandb.config.update(config)
|
||||
|
@ -296,7 +298,12 @@ if __name__ == "__main__":
|
|||
help="Evaluation dataset path",
|
||||
default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
|
||||
)
|
||||
parser.add_argument("--model", help="Model", default="gpt-4o", choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"])
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
help="Model",
|
||||
default="gpt-4o",
|
||||
choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"],
|
||||
)
|
||||
parser.add_argument("--start", default=0)
|
||||
parser.add_argument("--end", default=sys.maxsize)
|
||||
parser.add_argument("--batchsize", default=10, help="putnam has 348")
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
import collections
|
||||
from typing import Optional
|
||||
from pantograph.search import Agent
|
||||
from pantograph.expr import GoalState, Tactic
|
||||
|
||||
class HammerAgent(Agent):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
|
||||
self.tactics = [
|
||||
"aesop",
|
||||
]
|
||||
|
||||
def next_tactic(
|
||||
self,
|
||||
state: GoalState,
|
||||
goal_id: int,
|
||||
informal_stmt: str,
|
||||
informal_proof: str) -> Optional[Tactic]:
|
||||
key = (state.state_id, goal_id)
|
||||
i = self.goal_tactic_id_map[key]
|
||||
|
||||
if i >= len(self.tactics):
|
||||
return None
|
||||
|
||||
self.goal_tactic_id_map[key] = i + 1
|
||||
return self.tactics[i]
|
|
@ -3250,4 +3250,4 @@ test = ["websockets"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "686e7f1af124ef2404bc6b46677850c581e8e74f3cab51992fac8e8578f88a3a"
|
||||
content-hash = "d992431714365397c4080f70aa7d146d7819976703bce96637f574856283d704"
|
||||
|
|
|
@ -26,6 +26,7 @@ openai = "^1.31.0"
|
|||
tenacity = "8.3.0"
|
||||
tiktoken = "^0.7.0"
|
||||
wandb = "0.17.0"
|
||||
termcolor = "^2.4.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
|
Loading…
Reference in New Issue