feat: Hammer agent for DSP, diagnostics

This commit is contained in:
Leni Aniva 2024-10-04 18:36:52 -07:00
parent 80a356c75c
commit 9fd930380d
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
4 changed files with 55 additions and 18 deletions

View File

@ -1,13 +1,13 @@
import sys, os, json import sys, os, json
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Union, Any from typing import Union, Any, Tuple
from collections import namedtuple
from tqdm import tqdm from tqdm import tqdm
from openai import OpenAI from openai import OpenAI
import wandb import wandb
from tenacity import retry, stop_after_attempt, wait_exponential from tenacity import retry, stop_after_attempt, wait_exponential
from pantograph import Server from pantograph import Server
from termcolor import colored
from solve.prompts import ( from solve.prompts import (
extract_lean_code, extract_lean_code,
@ -19,6 +19,7 @@ from solve.prompts import (
STOP_TOKENS_SKETCH_V0, STOP_TOKENS_SKETCH_V0,
get_prompt_sketch_template_4_lean_v0, get_prompt_sketch_template_4_lean_v0,
) )
from solve.prove import HammerAgent
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n" # prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
@ -122,12 +123,12 @@ def draft(
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128)) @retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def sketch( def sketch(
eng, eng,
data_pt: dict, data_pt: dict,
drafts: list, drafts: list,
autoformalize_prob_in_prompt: bool = False, autoformalize_prob_in_prompt: bool = False,
verbose: bool = False, verbose: bool = False,
) -> list: ) -> Tuple[list[str], str]:
""" """
Creates (formal fl) sketch (fl proof sketch) for latter use in a formal proof sketch. Creates (formal fl) sketch (fl proof sketch) for latter use in a formal proof sketch.
z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch) z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch)
@ -178,13 +179,12 @@ def prove(
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections. # If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch] lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
state, = server.load_sorry(lean_code) state, = server.load_sorry(lean_code)
agent = HammerAgent()
result = agent.search(server, state, verbose=True)
print(colored(f"Result: {result}", "blue"))
print(state)
raise RuntimeError("Not implemented") raise RuntimeError("Not implemented")
# -- Prove return
correct: bool = False
# -- Return
return correct
# -- DSP for Lean # -- DSP for Lean
@ -200,10 +200,10 @@ def single_proof_search_dsp_lean(
z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts) z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts)
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
correct: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches) result: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches)
# -- Return # -- Return
return correct return result
def full_proof_search_dsp_lean( def full_proof_search_dsp_lean(
eng: Engine, eng: Engine,
@ -215,6 +215,7 @@ def full_proof_search_dsp_lean(
print(f'{len(eval_dataset)=}') print(f'{len(eval_dataset)=}')
# -- Proof search by DSP over all eval data # -- Proof search by DSP over all eval data
for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'): for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
print("Problem:", colored(data_pt["nl_problem"][0], "green", attrs=["underline"]))
print(f'{data_pt=}') print(f'{data_pt=}')
flag = single_proof_search_dsp_lean(eng, server, data_pt) flag = single_proof_search_dsp_lean(eng, server, data_pt)
server.gc() server.gc()
@ -271,7 +272,8 @@ def main(args):
# - Full proof search with DSP # - Full proof search with DSP
print(f'\n\n-- Full proof search with DSP') print(f'\n\n-- Full proof search with DSP')
full_proof_search_dsp_lean(eng, server, path_2_eval_dataset) full_proof_search_dsp_lean(eng, server, path_2_eval_dataset)
print(f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a") msg = f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a"
print(colored(msg, "magenta"))
# - End run # - End run
# wandb.config.update(config) # wandb.config.update(config)
@ -296,7 +298,12 @@ if __name__ == "__main__":
help="Evaluation dataset path", help="Evaluation dataset path",
default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json', default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
) )
parser.add_argument("--model", help="Model", default="gpt-4o", choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"]) parser.add_argument(
"--model",
help="Model",
default="gpt-4o",
choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"],
)
parser.add_argument("--start", default=0) parser.add_argument("--start", default=0)
parser.add_argument("--end", default=sys.maxsize) parser.add_argument("--end", default=sys.maxsize)
parser.add_argument("--batchsize", default=10, help="putnam has 348") parser.add_argument("--batchsize", default=10, help="putnam has 348")

View File

@ -0,0 +1,29 @@
import collections
from typing import Optional
from pantograph.search import Agent
from pantograph.expr import GoalState, Tactic
class HammerAgent(Agent):
def __init__(self):
super().__init__()
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
self.tactics = [
"aesop",
]
def next_tactic(
self,
state: GoalState,
goal_id: int,
informal_stmt: str,
informal_proof: str) -> Optional[Tactic]:
key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key]
if i >= len(self.tactics):
return None
self.goal_tactic_id_map[key] = i + 1
return self.tactics[i]

2
poetry.lock generated
View File

@ -3250,4 +3250,4 @@ test = ["websockets"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "686e7f1af124ef2404bc6b46677850c581e8e74f3cab51992fac8e8578f88a3a" content-hash = "d992431714365397c4080f70aa7d146d7819976703bce96637f574856283d704"

View File

@ -26,6 +26,7 @@ openai = "^1.31.0"
tenacity = "8.3.0" tenacity = "8.3.0"
tiktoken = "^0.7.0" tiktoken = "^0.7.0"
wandb = "0.17.0" wandb = "0.17.0"
termcolor = "^2.4.0"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]