feat: Search on minif2f
This commit is contained in:
parent
d94e3086c1
commit
3b76080495
|
@ -16,6 +16,14 @@ Then run `main.py`
|
||||||
python3 experiments/dsp/main.py -h
|
python3 experiments/dsp/main.py -h
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The main command for running DSP is `eval`. Due to the multitude of data formats
|
||||||
|
out there, use the `--format` flag to specify the data format. For example,
|
||||||
|
running DSP on minif2f is:
|
||||||
|
|
||||||
|
``` sh
|
||||||
|
python3 experiments/dsp/main.py eval --dataset ../minif2f/valid.jsonl --format minif2f
|
||||||
|
```
|
||||||
|
|
||||||
## Related work
|
## Related work
|
||||||
|
|
||||||
### Tony's AF
|
### Tony's AF
|
||||||
|
|
|
@ -20,6 +20,7 @@ from solve.prompts import (
|
||||||
get_prompt_sketch_template_4_lean_v0,
|
get_prompt_sketch_template_4_lean_v0,
|
||||||
)
|
)
|
||||||
from solve.prove import HammerAgent
|
from solve.prove import HammerAgent
|
||||||
|
from solve.data import Datum
|
||||||
|
|
||||||
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
|
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
|
||||||
|
|
||||||
|
@ -83,17 +84,17 @@ class OpenAI_DSP_Engine(Engine):
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
||||||
def autoformalize_prob(
|
def autoformalize_prob(
|
||||||
eng,
|
eng: Engine,
|
||||||
data_pt: dict,
|
datum: Datum,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
):
|
):
|
||||||
""" Autoformalize natural language problem to formal language problem. """
|
""" Autoformalize natural language problem to formal language problem. """
|
||||||
...
|
pass
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
||||||
def draft(
|
def draft(
|
||||||
eng,
|
eng: Engine,
|
||||||
data_pt: dict,
|
datum: Datum,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> list:
|
) -> list:
|
||||||
"""
|
"""
|
||||||
|
@ -101,7 +102,7 @@ def draft(
|
||||||
y_pred_nl ~ draft(eng, x_nl_prob, P_draft)
|
y_pred_nl ~ draft(eng, x_nl_prob, P_draft)
|
||||||
"""
|
"""
|
||||||
# Make prompt from template
|
# Make prompt from template
|
||||||
nl_problem: str = data_pt['nl_problem'][0]
|
nl_problem: str = datum.nl_problem_str
|
||||||
prompt = eng.draft_prompt_template.replace('{nl_problem}', nl_problem)
|
prompt = eng.draft_prompt_template.replace('{nl_problem}', nl_problem)
|
||||||
# Get all **completions** to single prompt, one (in) -> many (out)
|
# Get all **completions** to single prompt, one (in) -> many (out)
|
||||||
# ref: https://platform.openai.com/docs/api-reference/chat/object
|
# ref: https://platform.openai.com/docs/api-reference/chat/object
|
||||||
|
@ -117,15 +118,18 @@ def draft(
|
||||||
stop=eng.draft_sampling_params.stop[:3],
|
stop=eng.draft_sampling_params.stop[:3],
|
||||||
)
|
)
|
||||||
# Get all completions for single prompt
|
# Get all completions for single prompt
|
||||||
completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message
|
completions: list[str] = [
|
||||||
|
completion.message.content
|
||||||
|
for completion in response.choices
|
||||||
|
] # response.choices[i].message
|
||||||
drafts: list[str] = completions
|
drafts: list[str] = completions
|
||||||
return drafts
|
return drafts
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
||||||
def sketch(
|
def sketch(
|
||||||
eng,
|
eng: Engine,
|
||||||
data_pt: dict,
|
datum: Datum,
|
||||||
drafts: list,
|
drafts: list[str],
|
||||||
autoformalize_prob_in_prompt: bool = False,
|
autoformalize_prob_in_prompt: bool = False,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> Tuple[list[str], str]:
|
) -> Tuple[list[str], str]:
|
||||||
|
@ -135,14 +139,14 @@ def sketch(
|
||||||
"""
|
"""
|
||||||
assert len(drafts) == 1, f"For now only 1 draft."
|
assert len(drafts) == 1, f"For now only 1 draft."
|
||||||
# Make prompt from template
|
# Make prompt from template
|
||||||
x_nl_problem: str = data_pt['nl_problem'][0]
|
x_nl_problem: str = datum.nl_problem_str
|
||||||
y_nl_solution: str = drafts[0]
|
y_nl_solution: str = drafts[0]
|
||||||
x_fl_problem = None
|
x_fl_problem = None
|
||||||
if autoformalize_prob_in_prompt:
|
if autoformalize_prob_in_prompt:
|
||||||
# prompt = eng.sketch_prompt_template.replace('{nl_problem}', x_nl_problem).replace('{nl_solution}', y_nl_solution)
|
# prompt = eng.sketch_prompt_template.replace('{nl_problem}', x_nl_problem).replace('{nl_solution}', y_nl_solution)
|
||||||
not NotImplemented
|
not NotImplemented
|
||||||
else:
|
else:
|
||||||
x_fl_problem = data_pt['fl_problem'][0] if 'fl_problem' in data_pt else autoformalize_prob(eng, data_pt)
|
x_fl_problem = datum.fl_problem if datum.fl_problem else autoformalize_prob(eng, datum)
|
||||||
prompt = eng.sketch_prompt_template.replace('{fl_problem}', x_nl_problem).replace('{fl_problem}', y_nl_solution)
|
prompt = eng.sketch_prompt_template.replace('{fl_problem}', x_nl_problem).replace('{fl_problem}', y_nl_solution)
|
||||||
# Get all **completions** to single prompt, one (in) -> many (out), ref: https://platform.openai.com/docs/api-reference/chat/object
|
# Get all **completions** to single prompt, one (in) -> many (out), ref: https://platform.openai.com/docs/api-reference/chat/object
|
||||||
response: Any = eng.llm.chat.completions.create(
|
response: Any = eng.llm.chat.completions.create(
|
||||||
|
@ -177,10 +181,11 @@ def prove(
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
||||||
|
print(colored("Sketch:", "yellow"), fl_sketch)
|
||||||
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
|
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
|
||||||
state, = server.load_sorry(lean_code)
|
state, = server.load_sorry(lean_code)
|
||||||
agent = HammerAgent()
|
agent = HammerAgent()
|
||||||
result = agent.search(server, state, verbose=True)
|
result = agent.search(server, state)
|
||||||
print(colored(f"Result: {result}", "blue"))
|
print(colored(f"Result: {result}", "blue"))
|
||||||
|
|
||||||
raise RuntimeError("Not implemented")
|
raise RuntimeError("Not implemented")
|
||||||
|
@ -191,13 +196,13 @@ def prove(
|
||||||
def single_proof_search_dsp_lean(
|
def single_proof_search_dsp_lean(
|
||||||
eng: Engine,
|
eng: Engine,
|
||||||
server: Server,
|
server: Server,
|
||||||
data_pt: dict,
|
datum: Datum,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
# -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
|
# -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
|
||||||
y_nl_pred_drafts = draft(eng, data_pt)
|
y_nl_pred_drafts = draft(eng, datum)
|
||||||
|
|
||||||
# -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
|
# -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
|
||||||
z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts)
|
z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts)
|
||||||
|
|
||||||
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
||||||
result: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches)
|
result: bool = prove(eng, server, x_fl_prob, z_fl_pred_sketches)
|
||||||
|
@ -208,17 +213,14 @@ def single_proof_search_dsp_lean(
|
||||||
def full_proof_search_dsp_lean(
|
def full_proof_search_dsp_lean(
|
||||||
eng: Engine,
|
eng: Engine,
|
||||||
server: Server,
|
server: Server,
|
||||||
path_2_eval_dataset: Path,
|
data: list[Datum],
|
||||||
):
|
):
|
||||||
# -- Get eval data
|
print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))
|
||||||
eval_dataset: list[dict] = json.load(open(path_2_eval_dataset, 'r'))
|
|
||||||
print(f'{len(eval_dataset)=}')
|
|
||||||
# -- Proof search by DSP over all eval data
|
# -- Proof search by DSP over all eval data
|
||||||
for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
|
for datum in tqdm(data, total=len(data), desc='DSP proof loop per data point in benchmark.'):
|
||||||
print("Problem:", colored(data_pt["nl_problem"][0], "green", attrs=["underline"]))
|
print("Problem:", colored(datum.id, "cyan"))
|
||||||
print(f'{data_pt=}')
|
flag = single_proof_search_dsp_lean(eng, server, datum)
|
||||||
flag = single_proof_search_dsp_lean(eng, server, data_pt)
|
#server.gc()
|
||||||
server.gc()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@ -229,14 +231,33 @@ def get_project_and_lean_path():
|
||||||
p = subprocess.check_output(['lake', 'env', 'printenv', 'LEAN_PATH'], cwd=cwd)
|
p = subprocess.check_output(['lake', 'env', 'printenv', 'LEAN_PATH'], cwd=cwd)
|
||||||
return cwd, p
|
return cwd, p
|
||||||
|
|
||||||
|
def load_data(args) -> list[Datum]:
    """
    Load the dataset selected on the command line into `Datum` objects.

    Reads `args.dataset` (a path; `~` is expanded) and decodes it according
    to its suffix: `.json` is parsed as one JSON document that is iterated
    over, `.jsonl` as one JSON object per line. Each decoded object is
    converted with `Datum.load` using `args.format`.

    Returns the converted entries, leaving out any for which `Datum.load`
    returned a falsy value.

    Raises `ValueError` if the suffix is neither `.json` nor `.jsonl`.
    """
    p = Path(args.dataset).expanduser()
    if p.suffix == ".json":
        # Context manager so the file handle is closed once parsing is done
        with open(p, 'r') as f:
            objs = json.load(f)
    elif p.suffix == ".jsonl":
        with open(p, 'r') as f:
            # Blank lines are not valid JSON documents; skip them instead of
            # failing on e.g. an empty line at the end of the file
            objs = [json.loads(line) for line in f if line.strip()]
    else:
        raise ValueError(f"Unknown data suffix: {p.suffix}")

    data = [Datum.load(obj, data_format=args.format) for obj in objs]
    # Drop entries the format loader declined to produce (falsy results)
    return [datum for datum in data if datum]
|
||||||
|
|
||||||
# -- Main
|
# -- Main
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
path_2_eval_dataset = Path(args.eval_dataset).expanduser()
|
data_eval = load_data(args)
|
||||||
print(f'{path_2_eval_dataset=}')
|
|
||||||
|
|
||||||
|
# Start server
|
||||||
project_path, lean_path = get_project_and_lean_path()
|
project_path, lean_path = get_project_and_lean_path()
|
||||||
server = Server(
|
server = Server(
|
||||||
imports=["Mathlib", "Aesop"],
|
imports=["Mathlib", "Aesop"],
|
||||||
|
@ -281,8 +302,7 @@ def main(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
# - Full proof search with DSP
|
# - Full proof search with DSP
|
||||||
print(f'\n\n-- Full proof search with DSP')
|
full_proof_search_dsp_lean(eng, server, data_eval)
|
||||||
full_proof_search_dsp_lean(eng, server, path_2_eval_dataset)
|
|
||||||
msg = f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a"
|
msg = f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a"
|
||||||
print(colored(msg, "magenta"))
|
print(colored(msg, "magenta"))
|
||||||
|
|
||||||
|
@ -305,7 +325,7 @@ if __name__ == "__main__":
|
||||||
choices=['eval', 'prompts'],
|
choices=['eval', 'prompts'],
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--eval-dataset",
|
"--dataset",
|
||||||
help="Evaluation dataset path",
|
help="Evaluation dataset path",
|
||||||
default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
|
default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
|
||||||
)
|
)
|
||||||
|
@ -315,6 +335,12 @@ if __name__ == "__main__":
|
||||||
default="gpt-4o",
|
default="gpt-4o",
|
||||||
choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"],
|
choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"],
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--format",
|
||||||
|
help="Data format",
|
||||||
|
default="default",
|
||||||
|
choices=["default", "minif2f"],
|
||||||
|
)
|
||||||
parser.add_argument("--start", default=0)
|
parser.add_argument("--start", default=0)
|
||||||
parser.add_argument("--end", default=sys.maxsize)
|
parser.add_argument("--end", default=sys.maxsize)
|
||||||
parser.add_argument("--batchsize", default=10, help="putnam has 348")
|
parser.add_argument("--batchsize", default=10, help="putnam has 348")
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
import json
|
||||||
|
from typing import Union, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
@dataclass
class Datum:
    """
    Represents one theorem proving datapoint.

    All fields are optional; which ones are populated depends on the data
    format the point was loaded from (see `load`).
    """

    # Identifier of the datapoint, when the source record provides one
    id: Optional[str] = None

    # Problem and solution in natural language
    nl_problem: Optional[Union[str, list[str]]] = None
    nl_solution: Optional[Union[str, list[str]]] = None

    # Problem in formal language
    fl_problem: Optional[str] = None

    def __str__(self):
        """
        Return the id when it is set, otherwise the stringified natural
        language problem.
        """
        if self.id:
            return self.id
        return str(self.nl_problem)

    @property
    def nl_problem_str(self) -> Optional[str]:
        """
        The natural language problem as a single string.

        A list of strings is joined with newlines. Returns `None` when the
        problem is missing or empty.
        """
        if not self.nl_problem:
            return None
        if isinstance(self.nl_problem, list):
            return "\n".join(self.nl_problem)
        return self.nl_problem

    @staticmethod
    def load_default(obj: dict) -> "Datum":
        """
        Loads data in the "default" format

        Reads the optional keys `id`, `nl_problem`, `nl_solution` and
        `fl_problem`. A list-valued `fl_problem` is joined with newlines.
        """
        fl_problem = obj.get("fl_problem")
        if isinstance(fl_problem, list):
            fl_problem = "\n".join(fl_problem)
        # Keep the identifier when the record has one, so the datapoint can
        # be reported by id just like minif2f data; coerce to `str` to match
        # the declared field type.
        raw_id = obj.get("id")
        return Datum(
            id=None if raw_id is None else str(raw_id),
            nl_problem=obj.get("nl_problem"),
            nl_solution=obj.get("nl_solution"),
            fl_problem=fl_problem,
        )

    @staticmethod
    def load_minif2f(obj: dict) -> Optional["Datum"]:
        """
        Loads minif2f data

        Requires the keys `id`, `formal_statement`, `informal_stmt` and
        `informal_proof`. Returns `None` when the stripped formal statement
        starts with `--`.
        """
        fl_problem = obj["formal_statement"].strip()
        if fl_problem.startswith("--"):
            # Statement begins with `--`: no datapoint is produced for it
            return None
        return Datum(
            id=obj["id"],
            fl_problem=fl_problem,
            #header=obj["header"],
            nl_problem=obj["informal_stmt"],
            nl_solution=obj["informal_proof"],
        )

    @staticmethod
    def load(obj: dict, data_format: str) -> Optional["Datum"]:
        """
        Convert one decoded record into a `Datum` according to `data_format`.

        Supported formats are "default" and "minif2f". The result may be
        `None` when the format loader rejects the record.

        Raises `ValueError` for any other format name.
        """
        if data_format == "default":
            return Datum.load_default(obj)
        elif data_format == "minif2f":
            return Datum.load_minif2f(obj)
        else:
            raise ValueError(f"Invalid data format {data_format}")
|
Loading…
Reference in New Issue