diff --git a/experiments/dsp/main.py b/experiments/dsp/main.py index 92ddcf0..ff800e2 100644 --- a/experiments/dsp/main.py +++ b/experiments/dsp/main.py @@ -1,13 +1,13 @@ import sys, os, json from dataclasses import dataclass from pathlib import Path -from typing import Union, Any -from collections import namedtuple +from typing import Union, Any, Tuple from tqdm import tqdm from openai import OpenAI import wandb from tenacity import retry, stop_after_attempt, wait_exponential from pantograph import Server +from termcolor import colored from solve.prompts import ( extract_lean_code, @@ -19,6 +19,7 @@ from solve.prompts import ( STOP_TOKENS_SKETCH_V0, get_prompt_sketch_template_4_lean_v0, ) +from solve.prove import HammerAgent # prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n" @@ -122,12 +123,12 @@ def draft( @retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128)) def sketch( - eng, - data_pt: dict, - drafts: list, - autoformalize_prob_in_prompt: bool = False, - verbose: bool = False, - ) -> list: + eng, + data_pt: dict, + drafts: list, + autoformalize_prob_in_prompt: bool = False, + verbose: bool = False, + ) -> Tuple[list[str], str]: """ Creates (formal fl) sketch (fl proof sketch) for latter use in a formal proof sketch. z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch) @@ -178,13 +179,12 @@ def prove( # If this throws index out of bound errors it means the source doesn't contain walled off Lean sections. lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch] state, = server.load_sorry(lean_code) + agent = HammerAgent() + result = agent.search(server, state, verbose=True) + print(colored(f"Result: {result}", "blue")) - print(state) raise RuntimeError("Not implemented") - # -- Prove - correct: bool = False - # -- Return - return correct + return # -- DSP for Lean @@ -200,10 +200,10 @@ def single_proof_search_dsp_lean( z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts) # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) - correct: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches) + result: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches) # -- Return - return correct + return result def full_proof_search_dsp_lean( eng: Engine, @@ -215,6 +215,7 @@ def full_proof_search_dsp_lean( print(f'{len(eval_dataset)=}') # -- Proof search by DSP over all eval data for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'): + print("Problem:", colored(data_pt["nl_problem"][0], "green", attrs=["underline"])) print(f'{data_pt=}') flag = single_proof_search_dsp_lean(eng, server, data_pt) server.gc() @@ -271,7 +272,8 @@ def main(args): # - Full proof search with DSP print(f'\n\n-- Full proof search with DSP') full_proof_search_dsp_lean(eng, server, path_2_eval_dataset) - print(f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a") + msg = f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a" + print(colored(msg, "magenta")) # - End run # wandb.config.update(config) @@ -296,7 +298,12 @@ if __name__ == "__main__": help="Evaluation dataset path", default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json', ) - parser.add_argument("--model", help="Model", default="gpt-4o", choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"]) + parser.add_argument( + "--model", + help="Model", + default="gpt-4o", + choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"], + ) parser.add_argument("--start", default=0) parser.add_argument("--end", default=sys.maxsize) parser.add_argument("--batchsize", default=10, help="putnam has 348") diff --git a/experiments/dsp/solve/prove.py b/experiments/dsp/solve/prove.py new file mode 100644 index 0000000..5ac032a --- /dev/null +++ b/experiments/dsp/solve/prove.py @@ -0,0 +1,29 @@ +import collections +from typing import Optional +from pantograph.search import Agent +from pantograph.expr import GoalState, Tactic + +class HammerAgent(Agent): + + def __init__(self): + super().__init__() + + self.goal_tactic_id_map = collections.defaultdict(lambda : 0) + self.tactics = [ + "aesop", + ] + + def next_tactic( + self, + state: GoalState, + goal_id: int, + informal_stmt: str, + informal_proof: str) -> Optional[Tactic]: + key = (state.state_id, goal_id) + i = self.goal_tactic_id_map[key] + + if i >= len(self.tactics): + return None + + self.goal_tactic_id_map[key] = i + 1 + return self.tactics[i] diff --git a/poetry.lock b/poetry.lock index 179e415..6f5e344 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3250,4 +3250,4 @@ test = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "686e7f1af124ef2404bc6b46677850c581e8e74f3cab51992fac8e8578f88a3a" +content-hash = "d992431714365397c4080f70aa7d146d7819976703bce96637f574856283d704" diff --git a/pyproject.toml b/pyproject.toml index d674e24..0011cbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ openai = "^1.31.0" tenacity = "8.3.0" tiktoken = "^0.7.0" wandb = "0.17.0" +termcolor = "^2.4.0" [build-system] requires = ["poetry-core"]