feat: Hammer agent for DSP, diagnostics
This commit is contained in:
parent
80a356c75c
commit
9fd930380d
|
@ -1,13 +1,13 @@
|
||||||
import sys, os, json
|
import sys, os, json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, Any
|
from typing import Union, Any, Tuple
|
||||||
from collections import namedtuple
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import wandb
|
import wandb
|
||||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||||
from pantograph import Server
|
from pantograph import Server
|
||||||
|
from termcolor import colored
|
||||||
|
|
||||||
from solve.prompts import (
|
from solve.prompts import (
|
||||||
extract_lean_code,
|
extract_lean_code,
|
||||||
|
@ -19,6 +19,7 @@ from solve.prompts import (
|
||||||
STOP_TOKENS_SKETCH_V0,
|
STOP_TOKENS_SKETCH_V0,
|
||||||
get_prompt_sketch_template_4_lean_v0,
|
get_prompt_sketch_template_4_lean_v0,
|
||||||
)
|
)
|
||||||
|
from solve.prove import HammerAgent
|
||||||
|
|
||||||
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
|
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
|
||||||
|
|
||||||
|
@ -127,7 +128,7 @@ def sketch(
|
||||||
drafts: list,
|
drafts: list,
|
||||||
autoformalize_prob_in_prompt: bool = False,
|
autoformalize_prob_in_prompt: bool = False,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> list:
|
) -> Tuple[list[str], str]:
|
||||||
"""
|
"""
|
||||||
Creates (formal fl) sketch (fl proof sketch) for latter use in a formal proof sketch.
|
Creates (formal fl) sketch (fl proof sketch) for latter use in a formal proof sketch.
|
||||||
z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch)
|
z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch)
|
||||||
|
@ -178,13 +179,12 @@ def prove(
|
||||||
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
||||||
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
|
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
|
||||||
state, = server.load_sorry(lean_code)
|
state, = server.load_sorry(lean_code)
|
||||||
|
agent = HammerAgent()
|
||||||
|
result = agent.search(server, state, verbose=True)
|
||||||
|
print(colored(f"Result: {result}", "blue"))
|
||||||
|
|
||||||
print(state)
|
|
||||||
raise RuntimeError("Not implemented")
|
raise RuntimeError("Not implemented")
|
||||||
# -- Prove
|
return
|
||||||
correct: bool = False
|
|
||||||
# -- Return
|
|
||||||
return correct
|
|
||||||
|
|
||||||
# -- DSP for Lean
|
# -- DSP for Lean
|
||||||
|
|
||||||
|
@ -200,10 +200,10 @@ def single_proof_search_dsp_lean(
|
||||||
z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts)
|
z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts)
|
||||||
|
|
||||||
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
||||||
correct: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches)
|
result: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches)
|
||||||
|
|
||||||
# -- Return
|
# -- Return
|
||||||
return correct
|
return result
|
||||||
|
|
||||||
def full_proof_search_dsp_lean(
|
def full_proof_search_dsp_lean(
|
||||||
eng: Engine,
|
eng: Engine,
|
||||||
|
@ -215,6 +215,7 @@ def full_proof_search_dsp_lean(
|
||||||
print(f'{len(eval_dataset)=}')
|
print(f'{len(eval_dataset)=}')
|
||||||
# -- Proof search by DSP over all eval data
|
# -- Proof search by DSP over all eval data
|
||||||
for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
|
for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
|
||||||
|
print("Problem:", colored(data_pt["nl_problem"][0], "green", attrs=["underline"]))
|
||||||
print(f'{data_pt=}')
|
print(f'{data_pt=}')
|
||||||
flag = single_proof_search_dsp_lean(eng, server, data_pt)
|
flag = single_proof_search_dsp_lean(eng, server, data_pt)
|
||||||
server.gc()
|
server.gc()
|
||||||
|
@ -271,7 +272,8 @@ def main(args):
|
||||||
# - Full proof search with DSP
|
# - Full proof search with DSP
|
||||||
print(f'\n\n-- Full proof search with DSP')
|
print(f'\n\n-- Full proof search with DSP')
|
||||||
full_proof_search_dsp_lean(eng, server, path_2_eval_dataset)
|
full_proof_search_dsp_lean(eng, server, path_2_eval_dataset)
|
||||||
print(f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a")
|
msg = f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a"
|
||||||
|
print(colored(msg, "magenta"))
|
||||||
|
|
||||||
# - End run
|
# - End run
|
||||||
# wandb.config.update(config)
|
# wandb.config.update(config)
|
||||||
|
@ -296,7 +298,12 @@ if __name__ == "__main__":
|
||||||
help="Evaluation dataset path",
|
help="Evaluation dataset path",
|
||||||
default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
|
default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
|
||||||
)
|
)
|
||||||
parser.add_argument("--model", help="Model", default="gpt-4o", choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"])
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
help="Model",
|
||||||
|
default="gpt-4o",
|
||||||
|
choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"],
|
||||||
|
)
|
||||||
parser.add_argument("--start", default=0)
|
parser.add_argument("--start", default=0)
|
||||||
parser.add_argument("--end", default=sys.maxsize)
|
parser.add_argument("--end", default=sys.maxsize)
|
||||||
parser.add_argument("--batchsize", default=10, help="putnam has 348")
|
parser.add_argument("--batchsize", default=10, help="putnam has 348")
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
import collections
|
||||||
|
from typing import Optional
|
||||||
|
from pantograph.search import Agent
|
||||||
|
from pantograph.expr import GoalState, Tactic
|
||||||
|
|
||||||
|
class HammerAgent(Agent):
    """Agent that proposes tactics from a fixed list, one at a time per goal.

    For every ``(state.state_id, goal_id)`` pair the agent remembers how many
    entries of ``self.tactics`` it has already handed out, and offers the next
    one on each call until the list is exhausted.
    """

    def __init__(self):
        super().__init__()

        # Cursor into `self.tactics`, keyed by (state id, goal id).
        # Keys that have not been seen yet start at position 0.
        self.goal_tactic_id_map = collections.defaultdict(int)
        # Candidate tactics, tried in list order.
        self.tactics = ["aesop"]

    def next_tactic(
            self,
            state: GoalState,
            goal_id: int,
            informal_stmt: str,
            informal_proof: str) -> Optional[Tactic]:
        """Return the next untried tactic for the given goal, or ``None``.

        Args:
            state: Goal state; only its ``state_id`` is read here.
            goal_id: Identifier of the goal inside ``state``.
            informal_stmt: Not used by this agent.
            informal_proof: Not used by this agent.

        Returns:
            The next entry of ``self.tactics`` for this ``(state_id, goal_id)``
            pair, or ``None`` once every entry has been returned for it.
        """
        key = (state.state_id, goal_id)
        cursor = self.goal_tactic_id_map[key]

        if cursor < len(self.tactics):
            # Advance the cursor so the following call yields the next tactic.
            self.goal_tactic_id_map[key] = cursor + 1
            return self.tactics[cursor]

        return None
|
|
@ -3250,4 +3250,4 @@ test = ["websockets"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "686e7f1af124ef2404bc6b46677850c581e8e74f3cab51992fac8e8578f88a3a"
|
content-hash = "d992431714365397c4080f70aa7d146d7819976703bce96637f574856283d704"
|
||||||
|
|
|
@ -26,6 +26,7 @@ openai = "^1.31.0"
|
||||||
tenacity = "8.3.0"
|
tenacity = "8.3.0"
|
||||||
tiktoken = "^0.7.0"
|
tiktoken = "^0.7.0"
|
||||||
wandb = "0.17.0"
|
wandb = "0.17.0"
|
||||||
|
termcolor = "^2.4.0"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
|
Loading…
Reference in New Issue