Pantograph/examples/lean4_dsp/dsp_for_lean4.py

252 lines
11 KiB
Python
Raw Normal View History

2024-09-26 20:20:32 -07:00
import sys
from collections import namedtuple
2024-07-11 15:49:37 -07:00
import fire
from pathlib import Path
from tqdm import tqdm
from typing import Union, Any
import json
import os
2024-09-26 20:20:32 -07:00
from openai import OpenAI
2024-07-11 15:49:37 -07:00
import wandb
from tenacity import retry, stop_after_attempt, wait_exponential
from solve.dsp_lean_prompts import SYSTEM_PROMPT_DRAFT_V0, prompt_draft_template_lean4_v0, STOP_TOKENS_DRAFT_V0
from solve.dsp_lean_prompts import SYSTEM_PROMPT_SKETCH_V0, prompt_sketch_template_lean4_v0, STOP_TOKENS_SKETCH_V0
2024-09-26 20:20:32 -07:00
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
2024-07-11 15:49:37 -07:00
class Engine:
    """Minimal base class for the engines used by the DSP pipeline.

    The base implementation keeps no state, and calling an instance is a
    no-op that returns ``None``.
    """

    def __init__(self):
        """Create an engine; nothing is initialised at this level."""
        return None

    def __call__(self, *args, **kwargs):
        """Accept arbitrary arguments, ignore them and return ``None``."""
        return None
class OpenAI_DSP_Engine(Engine):
    """Engine bundling an OpenAI chat client with the Draft/Sketch settings.

    The instance stores the model name, an ``OpenAI`` client (``self.llm``)
    and, for each of the draft and sketch stages, the system prompt, the
    prompt template, the sampling parameters and the stop tokens.
    """

    def __init__(
        self,
        model: str,
        api_key: Union[str, None] = None,
        base_url: Union[str, None] = None,  # e.g., Mistral-7B-Instruct-v0.2 on http://120.77.8.29:12345
        # Draft Params
        draft_system_prompt: str = SYSTEM_PROMPT_DRAFT_V0,  # 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.' (goal: draft)
        draft_prompt_template: str = prompt_draft_template_lean4_v0,
        draft_sampling_params=None,
        draft_stop_tokens: list[str] = STOP_TOKENS_DRAFT_V0,
        # Sketch Params
        sketch_system_prompt: str = SYSTEM_PROMPT_SKETCH_V0,
        sketch_prompt_template: str = prompt_sketch_template_lean4_v0,
        sketch_sampling_params=None,
        sketch_stop_tokens: list[str] = STOP_TOKENS_SKETCH_V0,
        # Prove Params
        # ...TODO not sure if needed right now...
        # Misc
        verbose_init: bool = True,
    ):
        """Store the configuration and build the OpenAI client.

        Args:
            model: Model name passed to the chat-completions endpoint.
            api_key: API key handed to the ``OpenAI`` client.
            base_url: Optional alternative endpoint handed to the client.
            draft_system_prompt: System message used for the draft stage.
            draft_prompt_template: User-prompt template for the draft stage.
            draft_sampling_params: Object read by ``draft`` for its
                ``temperature``, ``top_p``, ``n`` and ``stop`` attributes.
                The default ``None`` must be replaced before drafting.
            draft_stop_tokens: Stop tokens for the draft stage.
            sketch_system_prompt: System message used for the sketch stage.
            sketch_prompt_template: User-prompt template for the sketch stage.
            sketch_sampling_params: Object read by ``sketch`` for its
                ``temperature``, ``top_p`` and ``n`` attributes.
            sketch_stop_tokens: Stop tokens for the sketch stage.
            verbose_init: When true, print a short summary of the connection
                settings.
        """
        super().__init__()
        if verbose_init:
            # Never echo the secret itself: stdout frequently ends up in
            # shared logs. Only report whether a key was supplied.
            api_key_status = 'None' if api_key is None else '<redacted>'
            print(f'api_key={api_key_status}, {base_url=}')
        self.model = model
        self.api_key = api_key
        self.llm = OpenAI(api_key=self.api_key, base_url=base_url)
        # Draft params
        self.draft_system_prompt = draft_system_prompt
        self.draft_prompt_template = draft_prompt_template
        self.draft_sampling_params = draft_sampling_params
        # Kept as a separate attribute so the argument is no longer silently
        # dropped; the sampling-params object itself is left untouched.
        self.draft_stop_tokens = draft_stop_tokens
        # Sketch params
        self.sketch_system_prompt = sketch_system_prompt
        self.sketch_prompt_template = sketch_prompt_template
        self.sketch_sampling_params = sketch_sampling_params
        self.sketch_stop_tokens = sketch_stop_tokens
        # Prove params
        # ...TODO not sure if needed right now...
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def autoformalize_prob(
    eng,
    data_pt: dict,
    verbose: bool = False,
):
    """Placeholder for autoformalizing a natural-language problem.

    Not implemented yet: all arguments are ignored and ``None`` is returned.
    """
    return None
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def draft(
    eng,
    data_pt: dict,
    verbose: bool = False,
) -> list:
    """
    Sample informal (natural-language) drafts for one problem.

    y_pred_nl ~ draft(eng, x_nl_prob, P_draft)

    The first entry of ``data_pt['nl_problem']`` is substituted for the
    ``{nl_problem}`` placeholder of the engine's draft template, and the
    result is sent as the user message of a single chat-completion request.
    The content of every returned choice is collected into the result list.
    ``verbose`` is currently unused.
    """
    # Build the user prompt from the template.
    problem_text: str = data_pt['nl_problem'][0]
    user_prompt = eng.draft_prompt_template.replace('{nl_problem}', problem_text)
    # One prompt in, possibly many completions out.
    # ref: https://platform.openai.com/docs/api-reference/chat/object
    chat_messages = [
        {"role": "system", "content": eng.draft_system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    response: Any = eng.llm.chat.completions.create(
        model=eng.model,
        messages=chat_messages,
        temperature=eng.draft_sampling_params.temperature,
        top_p=eng.draft_sampling_params.top_p,
        n=eng.draft_sampling_params.n,
        # Only the first three stop sequences are forwarded.
        stop=eng.draft_sampling_params.stop[:3],
    )
    # response.choices[i].message.content holds the i-th completion.
    return [choice.message.content for choice in response.choices]
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def sketch(
    eng,
    data_pt: dict,
    drafts: list,
    autoformalize_prob_in_prompt: bool = False,
    verbose: bool = False,
) -> tuple[list[str], Any]:
    """
    Sample formal proof sketches from an informal draft.

    z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch)

    Args:
        eng: Engine providing ``llm``, ``model``, ``sketch_system_prompt``,
            ``sketch_prompt_template`` and ``sketch_sampling_params``.
        data_pt: Data point; ``data_pt['nl_problem'][0]`` is the informal
            problem. When the key ``'fl_problem'`` is present its first entry
            is used as the formal statement, otherwise
            ``autoformalize_prob`` is called.
        drafts: Informal drafts; exactly one is supported for now.
        autoformalize_prob_in_prompt: Not implemented; passing a true value
            raises ``NotImplementedError``.
        verbose: Currently unused.

    Returns:
        A pair ``(sketches, x_fl_problem)``: the content of every completion
        returned by the model, and the formal problem statement that was used
        (possibly ``None`` when none could be obtained).
    """
    assert len(drafts) == 1, "For now only 1 draft."
    x_nl_problem: str = data_pt['nl_problem'][0]
    y_nl_solution: str = drafts[0]
    if autoformalize_prob_in_prompt:
        # Previously this branch was a no-op that left `prompt` and
        # `x_fl_problem` unbound, failing later with a NameError.
        raise NotImplementedError(
            "autoformalize_prob_in_prompt=True is not supported yet."
        )
    x_fl_problem = data_pt['fl_problem'][0] if 'fl_problem' in data_pt else autoformalize_prob(eng, data_pt)
    # Make prompt from template. Each placeholder gets its own value;
    # `str.replace` leaves the text untouched for placeholders the template
    # does not contain.
    prompt = (
        eng.sketch_prompt_template
        .replace('{nl_problem}', x_nl_problem)
        .replace('{nl_solution}', y_nl_solution)
    )
    if x_fl_problem is not None:
        prompt = prompt.replace('{fl_problem}', x_fl_problem)
    # Get all **completions** to single prompt, one (in) -> many (out),
    # ref: https://platform.openai.com/docs/api-reference/chat/object
    response: Any = eng.llm.chat.completions.create(
        model=eng.model,
        messages=[
            {"role": "system", "content": eng.sketch_system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=eng.sketch_sampling_params.temperature,
        top_p=eng.sketch_sampling_params.top_p,
        n=eng.sketch_sampling_params.n,
        # Stop tokens are intentionally not forwarded for sketches (disabled upstream):
        # stop=eng.sketch_sampling_params.stop[:3],
    )
    # response.choices[i].message.content holds the i-th completion.
    sketches: list[str] = [completion.message.content for completion in response.choices]
    return sketches, x_fl_problem
def prove(
    eng,
    fl_prob: str,
    fl_sketch: list[str],
):
    """
    Complete a formal sketch and report whether it proves the theorem.

    fl_prob   --> Lean4 theorem (problem)
    fl_sketch --> Lean4 formal sketch --> have x have ha

    The proving step is not implemented yet: the arguments are ignored and
    the result is always ``False``.
    """
    # -- Prove (placeholder) and return
    return False
# -- DSP for Lean
def single_proof_search_dsp_lean(
    eng: Engine,
    data_pt: dict,
) -> bool:
    """Run one Draft -> Sketch -> Prove pass for a single data point.

    Args:
        eng: Engine handed to ``draft``, ``sketch`` and ``prove``.
        data_pt: The data point (problem) to process.

    Returns:
        The value reported by ``prove`` for the sampled sketches.
    """
    # -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
    y_nl_pred_drafts = draft(eng, data_pt)
    # -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
    z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts)
    # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
    correct: bool = prove(eng, x_fl_prob, z_fl_pred_sketches)
    # -- Return the verdict (a bare `return` used to discard it)
    return correct
def full_proof_search_dsp_lean(
    eng: Engine,
    path_2_eval_dataset: Union[str, Path],
):
    """Run the DSP proof search over every data point of a JSON dataset.

    Args:
        eng: Engine forwarded to ``single_proof_search_dsp_lean``.
        path_2_eval_dataset: Path to a JSON file; ``~`` is expanded. The
            loaded value is iterated and each item is passed on as one data
            point.

    Returns:
        ``None``; per-problem results are not collected.
    """
    # -- Get eval data
    path_2_eval_dataset = Path(path_2_eval_dataset).expanduser()
    # The context manager closes the file handle even when parsing fails.
    with open(path_2_eval_dataset, 'r', encoding='utf-8') as dataset_file:
        eval_dataset: list[dict] = json.load(dataset_file)
    print(f'{len(eval_dataset)=}')
    # -- Proof search by DSP over all eval data
    for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
        print(f'{data_pt=}')
        single_proof_search_dsp_lean(eng, data_pt)
    return
# -- Main
def main(
    path_2_eval_dataset: str = '~/PyPantograph/examples/lean4_dsp/debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
    # model: str = 'deepseek-ai/deepseek-math-7b-instruct',
    # model: str = 'gpt2',
    # model: str = 'gpt-3.5-turbo',
    model: str = 'gpt-4o',
    start: int = 0,
    end: int = sys.maxsize,
    # end: int = 10,  # do 10 so enough boxed qs are there
    batch_size: int = 10,  # putnam has 348
    n: int = 1,  # num seqs to return for given prompt
    max_tokens: int = 2048,
    top_p: float = 0.95,
    temperature: float = 0.8,
    mode: str = "dryrun",
):
    """Build an OpenAI-backed DSP engine and run it over an eval dataset.

    Args:
        path_2_eval_dataset: JSON dataset to evaluate; ``~`` is expanded.
        model: Model name; it must contain ``'gpt-4-'``, ``'gpt-3.5-'`` or
            ``'gpt-4o'``.
        start, end, batch_size, mode: Accepted for the command line but not
            used by the current body (the wandb setup that consumed ``mode``
            is commented out).
        n: Number of completions requested per prompt.
        max_tokens, top_p, temperature: Sampling settings stored in the
            draft and sketch sampling parameters.

    Raises:
        ValueError: If ``model`` is not one of the supported families.
        FileNotFoundError: If the key file is missing and the
            ``OPENAI_API_KEY`` environment variable is not set.
    """
    path_2_eval_dataset = Path(path_2_eval_dataset).expanduser()
    print(f'{path_2_eval_dataset=}')
    # - Start wandb run
    # print(f'\n\n-- Setup params')
    # CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES")
    # current_tmux_session = os.environ.get("TMUX", "").split(",")[-1]
    # today = datetime.datetime.now().strftime("%Y-m%m-d%d-t%Hh_%Mm_%Ss")
    # config = {'today': today, "CUDA_VISIBLE_DEVICES": CUDA_VISIBLE_DEVICES, "current_tmux_session": current_tmux_session, "model": model, "path_2_eval_dataset": path_2_eval_dataset}
    # project: str = 'pypantograph'
    # run_name = f"{project}: ({config})"
    # run = wandb.init(mode=mode, project=project, name=run_name, save_code=True, config=config)
    # print(f"{run.url=}")
    # print(f'\n Config: \n{config=}')
    # - Run DSP for Lean
    if 'gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model:
        key_path = Path('~/keys/openai_api_key_brandos_koyejolab.txt').expanduser()
        if key_path.is_file():
            # read_text opens and closes the file itself (no leaked handle).
            api_key = key_path.read_text().strip()
        else:
            # Fall back to the conventional environment variable so the
            # script is usable without this particular key file.
            api_key = os.environ.get('OPENAI_API_KEY')
            if not api_key:
                raise FileNotFoundError(
                    f"No API key: {key_path} does not exist and OPENAI_API_KEY is not set."
                )
        SamplingParams = namedtuple('SamplingParams', ['n', 'max_tokens', 'top_p', 'temperature', 'stop'])
        draft_sampling_params = SamplingParams(n=n, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_DRAFT_V0)
        sketch_sampling_params = SamplingParams(n=n, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_SKETCH_V0)
        eng: OpenAI_DSP_Engine = OpenAI_DSP_Engine(model=model, api_key=api_key, verbose_init=True, draft_sampling_params=draft_sampling_params, sketch_sampling_params=sketch_sampling_params)
    else:
        raise ValueError(f"Model {model=} not supported.")
    # - Full proof search with DSP
    print('\n\n-- Full proof search with DSP')
    full_proof_search_dsp_lean(eng, path_2_eval_dataset)
    # - End run
    # wandb.config.update(config)
    # print(f"{wandb.config=}")
    # run.finish()
2024-07-11 15:49:37 -07:00
if __name__ == "__main__":
    import time
    # Monotonic clock: unaffected by wall-clock adjustments during a long run.
    start_time = time.perf_counter()
    fire.Fire(main)
    # Measure once so the seconds/minutes/hours figures describe the same interval.
    elapsed = time.perf_counter() - start_time
    print(f"Time taken: {elapsed:.2f} seconds, or {elapsed / 60:.2f} minutes, or {elapsed / 3600:.2f} hours.\a")