diff --git a/experiments/dsp/README.md b/experiments/dsp/README.md index e8e17ab..7a587df 100644 --- a/experiments/dsp/README.md +++ b/experiments/dsp/README.md @@ -16,6 +16,14 @@ Then run `main.py` python3 experiments/dsp/main.py -h ``` +The main command for running DSP is `eval`. Due to the multitude of data format +out there, use the `--format` flag to specify the data format. For example, +running DSP on minif2f is: + +``` sh +python3 experiments/dsp/main.py eval --dataset ../minif2f/valid.jsonl --format minif2f +``` + ## Related work ### Tony's AF diff --git a/experiments/dsp/main.py b/experiments/dsp/main.py index cf188df..caee37e 100644 --- a/experiments/dsp/main.py +++ b/experiments/dsp/main.py @@ -20,6 +20,7 @@ from solve.prompts import ( get_prompt_sketch_template_4_lean_v0, ) from solve.prove import HammerAgent +from solve.data import Datum # prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n" @@ -40,25 +41,25 @@ class Engine: class OpenAI_DSP_Engine(Engine): def __init__( - self, - model: str, - api_key: str = None, - base_url: str = None, # e.g., Mistral-7B-Instrcut-v0.2 on http://120.77.8.29:12345 - # Draft Params - draft_system_prompt: str = SYSTEM_PROMPT_DRAFT_V0, # 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.' (goal do draft) - draft_prompt_template: str = prompt_draft_template_lean4_v0, - draft_sampling_params = None, - draft_stop_tokens: list[str] = STOP_TOKENS_DRAFT_V0, - # Sketch Params - sketch_system_prompt: str = SYSTEM_PROMPT_SKETCH_V0, - sketch_prompt_template: str = prompt_sketch_template_lean4_v0, - sketch_sampling_params = None, - sketch_stop_tokens: list[str] = STOP_TOKENS_SKETCH_V0, - # Prove Params - # ...TODO not sure if needed right now... - # Misc - verbose_init: bool = True, - ): + self, + model: str, + api_key: str = None, + base_url: str = None, # e.g., Mistral-7B-Instrcut-v0.2 on http://120.77.8.29:12345 + # Draft Params + draft_system_prompt: str = SYSTEM_PROMPT_DRAFT_V0, # 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.' (goal do draft) + draft_prompt_template: str = prompt_draft_template_lean4_v0, + draft_sampling_params = None, + draft_stop_tokens: list[str] = STOP_TOKENS_DRAFT_V0, + # Sketch Params + sketch_system_prompt: str = SYSTEM_PROMPT_SKETCH_V0, + sketch_prompt_template: str = prompt_sketch_template_lean4_v0, + sketch_sampling_params = None, + sketch_stop_tokens: list[str] = STOP_TOKENS_SKETCH_V0, + # Prove Params + # ...TODO not sure if needed right now... + # Misc + verbose_init: bool = True, + ): super().__init__() print(f'{base_url=}') if verbose_init else None @@ -83,25 +84,25 @@ class OpenAI_DSP_Engine(Engine): @retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128)) def autoformalize_prob( - eng, - data_pt: dict, - verbose: bool = False, -): + eng: Engine, + datum: Datum, + verbose: bool = False, + ): """ Autoformalize natural language problem to formal language problem. """ - ... + pass @retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128)) def draft( - eng, - data_pt: dict, - verbose: bool = False, + eng: Engine, + datum: Datum, + verbose: bool = False, ) -> list: """ Creates (informal nl) draft (nl soln, nl proof sketch) for latter use in a formal proof sketch. y_pred_nl ~ draft(eng, x_nl_prob, P_draft) """ # Make prompt from template - nl_problem: str = data_pt['nl_problem'][0] + nl_problem: str = datum.nl_problem_str prompt = eng.draft_prompt_template.replace('{nl_problem}', nl_problem) # Get all **completions** to single prompt, one (in) -> many (out) # ref: https://platform.openai.com/docs/api-reference/chat/object @@ -117,15 +118,18 @@ def draft( stop=eng.draft_sampling_params.stop[:3], ) # Get all completions for single prompt - completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message + completions: list[str] = [ + completion.message.content + for completion in response.choices + ] # response.choices[i].message drafts: list[str] = completions return drafts @retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128)) def sketch( - eng, - data_pt: dict, - drafts: list, + eng: Engine, + datum: Datum, + drafts: list[str], autoformalize_prob_in_prompt: bool = False, verbose: bool = False, ) -> Tuple[list[str], str]: @@ -135,14 +139,14 @@ def sketch( """ assert len(drafts) == 1, f"For now only 1 draft." # Make prompt from template - x_nl_problem: str = data_pt['nl_problem'][0] + x_nl_problem: str = datum.nl_problem_str y_nl_solution: str = drafts[0] x_fl_problem = None if autoformalize_prob_in_prompt: # prompt = eng.sketch_prompt_template.replace('{nl_problem}', x_nl_problem).replace('{nl_solution}', y_nl_solution) not NotImplemented else: - x_fl_problem = data_pt['fl_problem'][0] if 'fl_problem' in data_pt else autoformalize_prob(eng, data_pt) + x_fl_problem = datum.fl_problem if datum.fl_problem else autoformalize_prob(eng, datum) prompt = eng.sketch_prompt_template.replace('{fl_problem}', x_nl_problem).replace('{fl_problem}', y_nl_solution) # Get all **completions** to single prompt, one (in) -> many (out), ref: https://platform.openai.com/docs/api-reference/chat/object response: Any = eng.llm.chat.completions.create( @@ -163,11 +167,11 @@ def sketch( return sketches, x_fl_problem def prove( - eng: Engine, - server: Server, - fl_prob: str, - fl_sketch: list[str], -): + eng: Engine, + server: Server, + fl_prob: str, + fl_sketch: list[str], + ): """ Complete formal sketch and check if it proves the theorem. @@ -177,10 +181,11 @@ def prove( """ # If this throws index out of bound errors it means the source doesn't contain walled off Lean sections. + print(colored("Sketch:", "yellow"), fl_sketch) lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch] state, = server.load_sorry(lean_code) agent = HammerAgent() - result = agent.search(server, state, verbose=True) + result = agent.search(server, state) print(colored(f"Result: {result}", "blue")) raise RuntimeError("Not implemented") @@ -191,13 +196,13 @@ def prove( def single_proof_search_dsp_lean( eng: Engine, server: Server, - data_pt: dict, + datum: Datum, ) -> bool: # -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft) - y_nl_pred_drafts = draft(eng, data_pt) + y_nl_pred_drafts = draft(eng, datum) # -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch) - z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts) + z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts) # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) result: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches) @@ -206,19 +211,16 @@ def single_proof_search_dsp_lean( return result def full_proof_search_dsp_lean( - eng: Engine, - server: Server, - path_2_eval_dataset: Path, -): - # -- Get eval data - eval_dataset: list[dict] = json.load(open(path_2_eval_dataset, 'r')) - print(f'{len(eval_dataset)=}') + eng: Engine, + server: Server, + data: list[Datum], + ): + print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"])) # -- Proof search by DSP over all eval data - for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'): - print("Problem:", colored(data_pt["nl_problem"][0], "green", attrs=["underline"])) - print(f'{data_pt=}') - flag = single_proof_search_dsp_lean(eng, server, data_pt) - server.gc() + for datum in tqdm(data, total=len(data), desc='DSP proof loop per data point in benchmark.'): + print("Problem:", colored(datum.id, "cyan")) + flag = single_proof_search_dsp_lean(eng, server, datum) + #server.gc() return @@ -229,14 +231,33 @@ def get_project_and_lean_path(): p = subprocess.check_output(['lake', 'env', 'printenv', 'LEAN_PATH'], cwd=cwd) return cwd, p +def load_data(args) -> list[Datum]: + p = Path(args.dataset).expanduser() + data = None + if p.suffix == ".json": + data = [ + Datum.load(obj, data_format=args.format) + for obj in json.load(open(p, 'r')) + ] + elif p.suffix == ".jsonl": + with open(p, 'r') as f: + data = [ + Datum.load(json.loads(line), data_format=args.format) + for line in list(f) + ] + else: + raise ValueError(f"Unknown data suffix: {p.suffix}") + data = [datum for datum in data if datum] + return data + # -- Main def main(args): import time start_time = time.time() - path_2_eval_dataset = Path(args.eval_dataset).expanduser() - print(f'{path_2_eval_dataset=}') + data_eval = load_data(args) + # Start server project_path, lean_path = get_project_and_lean_path() server = Server( imports=["Mathlib", "Aesop"], @@ -281,8 +302,7 @@ def main(args): ) # - Full proof search with DSP - print(f'\n\n-- Full proof search with DSP') - full_proof_search_dsp_lean(eng, server, path_2_eval_dataset) + full_proof_search_dsp_lean(eng, server, data_eval) msg = f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a" print(colored(msg, "magenta")) @@ -305,7 +325,7 @@ if __name__ == "__main__": choices=['eval', 'prompts'], ) parser.add_argument( - "--eval-dataset", + "--dataset", help="Evaluation dataset path", default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json', ) @@ -315,6 +335,12 @@ if __name__ == "__main__": default="gpt-4o", choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"], ) + parser.add_argument( + "--format", + help="Data format", + default="default", + choices=["default", "minif2f"], + ) parser.add_argument("--start", default=0) parser.add_argument("--end", default=sys.maxsize) parser.add_argument("--batchsize", default=10, help="putnam has 348") diff --git a/experiments/dsp/solve/data.py b/experiments/dsp/solve/data.py new file mode 100644 index 0000000..64389c6 --- /dev/null +++ b/experiments/dsp/solve/data.py @@ -0,0 +1,70 @@ +import json +from typing import Union, Optional +from dataclasses import dataclass + +@dataclass +class Datum: + """ + Represents one theorem proving datapoint. + """ + + id: Optional[str] = None + + # Problem and solution in natural language + nl_problem: Optional[Union[str, list[str]]] = None + nl_solution: Optional[Union[str, list[str]]] = None + + # Problem in formal language + fl_problem: Optional[str] = None + + def __str__(self): + if self.id: + return self.id + return str(self.nl_problem) + + @property + def nl_problem_str(self) -> Optional[str]: + if not self.nl_problem: + return None + if isinstance(self.nl_problem, list): + return "\n".join(self.nl_problem) + return self.nl_problem + + @staticmethod + def load_default(obj: dict): + """ + Loads data in the "default" format + """ + fl_problem = obj.get("fl_problem") + if isinstance(fl_problem, list): + fl_problem = "\n".join(fl_problem) + return Datum( + nl_problem=obj.get("nl_problem"), + nl_solution=obj.get("nl_solution"), + fl_problem=fl_problem, + ) + + @staticmethod + def load_minif2f(obj: dict): + """ + Loads minif2f data + """ + fl_problem = obj["formal_statement"].strip() + if fl_problem.startswith("--"): + return None + return Datum( + id=obj["id"], + fl_problem=fl_problem, + #header=obj["header"], + nl_problem=obj["informal_stmt"], + nl_solution=obj["informal_proof"], + ) + + @staticmethod + def load(obj: dict, data_format: str): + if data_format == "default": + return Datum.load_default(obj) + elif data_format == "minif2f": + return Datum.load_minif2f(obj) + else: + raise ValueError(f"Invalid data format {data_format}")