Pantograph/experiments/dsp/main.py

472 lines
17 KiB
Python
Raw Normal View History

2024-10-04 18:53:00 -07:00
import sys, os, json, subprocess
2024-10-06 19:14:38 -07:00
from dataclasses import dataclass, asdict, field
from pathlib import Path
2024-10-06 19:14:38 -07:00
from typing import Union, Any, Tuple, Optional
2024-07-11 15:49:37 -07:00
from tqdm import tqdm
2024-09-26 20:20:32 -07:00
from openai import OpenAI
2024-07-11 15:49:37 -07:00
import wandb
from tenacity import retry, stop_after_attempt, wait_exponential
2024-10-06 19:14:38 -07:00
from pantograph import Server, ServerError
from pantograph.search import SearchResult
from termcolor import colored
2024-07-11 15:49:37 -07:00
from solve.prompts import (
extract_lean_code,
SYSTEM_PROMPT_DRAFT_V0,
SYSTEM_PROMPT_SKETCH_V0,
prompt_draft_template_lean4_v0,
prompt_sketch_template_lean4_v0,
STOP_TOKENS_DRAFT_V0,
STOP_TOKENS_SKETCH_V0,
get_prompt_sketch_template_4_lean_v0,
)
from solve.prove import HammerAgent
2024-10-04 21:55:47 -07:00
from solve.data import Datum
2024-07-11 15:49:37 -07:00
2024-09-26 20:20:32 -07:00
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
2024-10-02 11:03:00 -07:00
@dataclass
class SamplingParams:
    """
    Sampling hyper-parameters for one LLM stage (draft or sketch).

    Fields (in constructor order):
        n: number of completions requested per prompt.
        max_tokens: maximum number of generated tokens per completion.
        top_p: nucleus-sampling probability mass, a float such as 0.95
            (previously mis-annotated as ``int``).
        temperature: sampling temperature.
        stop: stop sequences, a list of strings that callers slice
            (e.g. ``stop[:3]``). It was previously mis-annotated as ``str``.
    """
    n: int
    max_tokens: int
    top_p: float
    temperature: float
    stop: list[str]
2024-10-06 19:14:38 -07:00
@dataclass(frozen=True)
class SketchParseFailure:
    """
    Immutable record of a model sketch that was rejected before proof search.

    Constructor order is ``(error, sketch)``.
    """
    # Human-readable reason the sketch was rejected.
    error: str
    # The model output (sketch text) that was rejected.
    sketch: str
@dataclass(frozen=True)
class DatumResult:
    """
    Outcome of running Draft-Sketch-Prove on a single data point.

    ``proves`` holds one entry per sketch that was checked. Each entry is either
    the ``SearchResult`` of a proof search or a ``SketchParseFailure``.
    """
    # Identifier of the data point.
    name: str
    # Whether any attempt proved the theorem; defaults to False.
    success: Optional[bool] = False
    # Per-sketch outcomes; default_factory gives every instance its own list.
    proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list)
2024-07-11 15:49:37 -07:00
class Engine:
    """
    Minimal base class for DSP engines.

    The base class holds no state. Calling an instance ignores its arguments and
    returns ``None``.
    """

    def __init__(self):
        """Nothing to initialise in the base engine."""

    def __call__(self, *args, **kwargs):
        """Placeholder hook; ignores all arguments."""
        return None
class OpenAI_DSP_Engine(Engine):
    """
    DSP engine backed by an ``OpenAI`` chat-completions client.

    Stores the model name, the client (``self.llm``), and the prompts and
    sampling parameters read by the draft and sketch stages.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,  # e.g., Mistral-7B-Instrcut-v0.2 on http://120.77.8.29:12345
        # Draft Params
        draft_system_prompt: str = SYSTEM_PROMPT_DRAFT_V0,  # 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.' (goal do draft)
        draft_prompt_template: str = prompt_draft_template_lean4_v0,
        draft_sampling_params: Optional["SamplingParams"] = None,
        draft_stop_tokens: list[str] = STOP_TOKENS_DRAFT_V0,
        # Sketch Params
        sketch_system_prompt: str = SYSTEM_PROMPT_SKETCH_V0,
        sketch_prompt_template: str = prompt_sketch_template_lean4_v0,
        sketch_sampling_params: Optional["SamplingParams"] = None,
        sketch_stop_tokens: list[str] = STOP_TOKENS_SKETCH_V0,
        # Prove Params
        # ...TODO not sure if needed right now...
        # Misc
        verbose_init: bool = True,
    ):
        """
        Args:
            model: model name; it must contain 'gpt-4-', 'gpt-3.5-' or 'gpt-4o'.
            api_key: passed to the ``OpenAI`` client (None lets the client decide).
            base_url: optional alternative endpoint passed to the ``OpenAI`` client.
            draft_system_prompt / draft_prompt_template: prompts for the draft stage.
            draft_sampling_params: sampling parameters for the draft stage. They are
                dereferenced by ``draft``, so they must be set before drafting.
            draft_stop_tokens / sketch_stop_tokens: stored on the engine. The stop
                sequences actually sent come from the sampling params' ``stop``.
            sketch_system_prompt / sketch_prompt_template / sketch_sampling_params:
                the same settings for the sketch stage.
            verbose_init: if True, print ``base_url`` during construction.

        Raises:
            ValueError: if ``model`` is not in one of the supported families.
        """
        super().__init__()
        if verbose_init:
            print(f'{base_url=}')
        supported_families = ('gpt-4-', 'gpt-3.5-', 'gpt-4o')
        if not any(family in model for family in supported_families):
            raise ValueError(f"Model {model=} not supported.")
        self.model = model
        self.api_key = api_key
        self.llm = OpenAI(api_key=self.api_key, base_url=base_url)
        # Draft params
        self.draft_system_prompt = draft_system_prompt
        self.draft_prompt_template = draft_prompt_template
        self.draft_sampling_params = draft_sampling_params
        # Previously accepted and silently dropped. They are kept for inspection,
        # and the sampling params are not mutated.
        self.draft_stop_tokens = draft_stop_tokens
        # Sketch params
        self.sketch_system_prompt = sketch_system_prompt
        self.sketch_prompt_template = sketch_prompt_template
        self.sketch_sampling_params = sketch_sampling_params
        self.sketch_stop_tokens = sketch_stop_tokens
        # Prove params
        # ...TODO not sure if needed right now...
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def autoformalize_prob(
    eng: Engine,
    datum: Datum,
    verbose: bool = False,
):
    """
    Translate a natural-language problem into a formal (Lean) statement.

    Not implemented yet. Every call returns ``None`` without reading ``eng``,
    ``datum`` or ``verbose``. ``sketch`` calls this when the datum has no
    ``fl_problem``, so in that case it currently receives ``None``.
    """
    return None
2024-07-11 15:49:37 -07:00
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def draft(
    eng: Engine,
    datum: Datum,
    verbose: bool = False,
) -> list:
    """
    Creates (informal nl) draft (nl soln, nl proof sketch) for latter use in a formal proof sketch.
    y_pred_nl ~ draft(eng, x_nl_prob, P_draft)

    Args:
        eng: engine providing ``llm``, ``model``, ``draft_system_prompt``,
            ``draft_prompt_template`` and ``draft_sampling_params``.
        datum: data point; only ``datum.nl_problem_str`` is read.
        verbose: currently unused.

    Returns:
        One draft string per element of ``response.choices``.

    The whole call is retried up to 15 times with exponential back-off.
    """
    params = eng.draft_sampling_params
    # Make prompt from template
    nl_problem: str = datum.nl_problem_str
    prompt = eng.draft_prompt_template.replace('{nl_problem}', nl_problem)
    # One prompt in, many completions out.
    # ref: https://platform.openai.com/docs/api-reference/chat/object
    response: Any = eng.llm.chat.completions.create(
        model=eng.model,
        messages=[
            {"role": "system", "content": eng.draft_system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=params.temperature,
        top_p=params.top_p,
        n=params.n,
        # Previously configured (--max-tokens) but never sent to the API.
        max_tokens=params.max_tokens,
        # Only the first three stop sequences are forwarded.
        stop=params.stop[:3],
    )
    drafts: list[str] = [choice.message.content for choice in response.choices]
    return drafts
2024-10-02 11:03:00 -07:00
2024-07-11 15:49:37 -07:00
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def sketch(
    eng: Engine,
    datum: Datum,
    drafts: list[str],
    autoformalize_prob_in_prompt: bool = False,
    verbose: bool = False,
) -> Tuple[list[str], Optional[str]]:
    """
    Creates (formal fl) sketch (fl proof sketch) for latter use in a formal proof sketch.
    z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch)

    Args:
        eng: engine providing ``llm``, ``model``, ``sketch_system_prompt``,
            ``sketch_prompt_template`` and ``sketch_sampling_params``.
        datum: data point; ``nl_problem_str`` and ``fl_problem`` are read.
        drafts: informal drafts; exactly one is supported for now.
        autoformalize_prob_in_prompt: not implemented; must be False.
        verbose: currently unused.

    Returns:
        ``(sketches, x_fl_problem)``. ``sketches`` has one string per element of
        ``response.choices``. ``x_fl_problem`` is the formal statement placed in
        the prompt: ``datum.fl_problem``, or otherwise the result of
        ``autoformalize_prob``, which may be ``None``.

    Raises:
        AssertionError: if ``len(drafts) != 1``.
        NotImplementedError: if ``autoformalize_prob_in_prompt`` is True.
        The ``retry`` decorator retries on every exception, so callers see these
        only after all attempts are exhausted (wrapped by tenacity).
    """
    assert len(drafts) == 1, "For now only 1 draft."
    if autoformalize_prob_in_prompt:
        # This used to be `not NotImplemented`, a no-op that later crashed with
        # an UnboundLocalError on `prompt`.
        raise NotImplementedError("autoformalize_prob_in_prompt=True is not implemented")
    params = eng.sketch_sampling_params
    # Make prompt from template
    x_nl_problem: str = datum.nl_problem_str
    y_nl_solution: str = drafts[0]
    x_fl_problem: Optional[str] = datum.fl_problem if datum.fl_problem else autoformalize_prob(eng, datum)
    # Give each placeholder its own value. The old code replaced '{fl_problem}'
    # twice: it put the NL problem in the formal slot and dropped both the
    # solution and the formal statement. str.replace is a no-op for placeholders
    # the template does not contain.
    prompt = (
        eng.sketch_prompt_template
        .replace('{nl_problem}', x_nl_problem)
        .replace('{nl_solution}', y_nl_solution)
        .replace('{fl_problem}', x_fl_problem or '')
    )
    # One prompt in, many completions out.
    # ref: https://platform.openai.com/docs/api-reference/chat/object
    response: Any = eng.llm.chat.completions.create(
        model=eng.model,
        messages=[
            {"role": "system", "content": eng.sketch_system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=params.temperature,
        top_p=params.top_p,
        n=params.n,
        # Previously configured (--max-tokens) but never sent to the API.
        max_tokens=params.max_tokens,
        # Stop sequences are currently not forwarded for the sketch stage
        # (was: stop=eng.sketch_sampling_params.stop[:3]).
    )
    sketches: list[str] = [choice.message.content for choice in response.choices]
    return sketches, x_fl_problem
def prove(
    eng: Engine,
    server: Server,
    fl_prob: str,
    fl_sketch: str,
) -> Union[SearchResult, SketchParseFailure]:
    """
    Complete formal sketch and check if it proves the theorem.
    fl_prob --> Lean4 theorem (problem)
    fl_sketch --> Lean4 Form Sketch --> have x have ha

    The Lean sections of ``fl_sketch`` (extracted with ``extract_lean_code``) are
    joined with newlines and loaded through ``server.load_sorry``. The single
    resulting goal state is then searched with a ``HammerAgent``.

    Args:
        eng: unused in this function.
        server: Pantograph ``Server`` used to load the sketch and run the search.
        fl_prob: unused in this function.
        fl_sketch: model output containing the Lean sketch.

    Returns:
        The agent's ``SearchResult``. A ``SketchParseFailure`` is returned instead
        when ``load_sorry`` raises ``ServerError``, when it does not yield exactly
        one compilation unit, or when that unit is a non-empty list of error
        messages. (This was previously annotated as ``list[SearchResult]``, which
        never matched.)
    """
    # If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
    print(colored("Sketch:", "yellow"), fl_sketch)
    lean_code = "\n".join(extract_lean_code(fl_sketch))
    print(colored("Lean code:", "light_grey"), lean_code)

    try:
        states = server.load_sorry(lean_code)
    except ServerError as e:
        msg = f"Encountered exception: {e}"
        print(colored(msg, "red"))
        return SketchParseFailure(
            sketch=fl_sketch,
            error=msg,
        )

    if len(states) != 1:
        msg = "Model must output one compilation unit"
        print(colored(msg, "red"))
        return SketchParseFailure(
            sketch=fl_sketch,
            error=msg,
        )

    state = states[0]
    if isinstance(state, list) and len(state) > 0:
        # A non-empty list in place of a goal state holds error messages.
        msg = "\n".join(state)
        print(colored("Sketch failed:", "red"), msg)
        return SketchParseFailure(
            sketch=fl_sketch,
            error=f"Sketch failed: {msg}",
        )

    agent = HammerAgent()
    result = agent.search(
        server,
        state,
        max_steps=1000,
        max_trials_per_goal=len(agent.tactics) + 1,
    )
    print(colored(f"Result: {result}", "blue"))
    return result
2024-10-02 11:03:00 -07:00
2024-07-11 15:49:37 -07:00
# -- DSP for Lean
def single_proof_search_dsp_lean(
    eng: Engine,
    server_func,
    datum: Datum,
) -> DatumResult:
    """
    Run Draft -> Sketch -> Prove on a single data point.

    Args:
        eng: engine used for the draft and sketch LLM calls.
        server_func: zero-argument callable returning a ``Server``. It is called
            once, after drafting and sketching.
        datum: the data point to solve.

    Returns:
        A ``DatumResult`` named ``str(datum)``. Its ``success`` is True iff at least
        one prove outcome is a ``SearchResult`` with ``success`` set.
    """
    # Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
    nl_drafts = draft(eng, datum)

    # Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
    fl_sketches, fl_problem = sketch(eng, datum, nl_drafts)
    assert len(fl_sketches) == 1

    server = server_func()
    # Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
    outcomes = []
    for candidate in fl_sketches:
        outcomes.append(prove(eng, server, fl_problem, candidate))

    return DatumResult(
        name=str(datum),
        success=any(
            outcome.success
            for outcome in outcomes
            if isinstance(outcome, SearchResult)
        ),
        proves=outcomes,
    )
2024-07-11 15:49:37 -07:00
def full_proof_search_dsp_lean(
    eng: Engine,
    server_func,
    data: list[Datum],
    path_output: Path,
):
    """
    Run DSP over every datum and write one JSON result file per datum.

    The result for ``data[i]`` goes to ``path_output / f"{i:03}.json"``. Data points
    whose result file already exists are skipped, which lets an interrupted run
    resume. If an existing file's ``name`` differs from ``str(datum)``, a warning
    is printed and the function returns early without the summary line.

    Args:
        eng: engine forwarded to ``single_proof_search_dsp_lean``.
        server_func: zero-argument callable returning a ``Server``; forwarded.
        data: data points to evaluate.
        path_output: existing directory for the result files.
    """
    print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))
    n_success = 0
    n_tried = 0
    # -- Proof search by DSP over all eval data
    for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
        file_name = path_output / f"{i:03}.json"
        key = str(datum)
        # Resume support: skip data points that already have a result file.
        if file_name.is_file():
            with open(file_name, "r") as f:
                obj = json.load(f)
            if obj['name'] != key:
                print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong"))
                return
            print(f"Skipped {i}:", colored(key, "green"))
            continue

        n_tried += 1
        print(f"Problem {i}:", colored(key, "cyan"))
        result = single_proof_search_dsp_lean(eng, server_func, datum)
        # Write to a temporary file, then rename. An interrupted or failing dump
        # therefore never leaves a truncated result file that would break resume.
        tmp_name = file_name.with_name(file_name.name + ".tmp")
        with open(tmp_name, 'w') as f:
            json.dump(asdict(result), f)
        tmp_name.replace(file_name)
        if result.success:
            n_success += 1
        # NOTE: single_proof_search_dsp_lean obtains its server from server_func
        # for each datum; no explicit server cleanup happens here.
    print(f"Proved {n_success}/{n_tried} problems")
2024-07-11 15:49:37 -07:00
2024-10-04 18:53:00 -07:00
2024-10-02 11:03:00 -07:00
# Directory containing this script; default data/output paths are relative to it.
experiment_dir = Path(__file__).resolve().parent


def get_project_and_lean_path():
    """
    Locate the bundled Lean project and query its ``LEAN_PATH``.

    Runs ``lake env printenv LEAN_PATH`` with ``experiment_dir / 'lean_src_proj'``
    as the working directory.

    Returns:
        ``(project_dir, lean_path)``. ``lean_path`` is the raw ``bytes`` captured
        from stdout; it is not decoded or stripped, so it may still end with the
        command's newline. NOTE(review): confirm the consumer accepts that.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    project_dir = experiment_dir / 'lean_src_proj'
    command = ['lake', 'env', 'printenv', 'LEAN_PATH']
    lean_path = subprocess.check_output(command, cwd=project_dir)
    return project_dir, lean_path
2024-10-04 21:55:47 -07:00
def load_data(args) -> list[Datum]:
    """
    Load the evaluation dataset named by ``args.dataset``.

    A ``.json`` file must contain a list of objects. A ``.jsonl`` file holds one
    object per line; blank lines are skipped. Each object is parsed with
    ``Datum.load(obj, data_format=args.format)``, and entries for which that
    returns a falsy value are dropped.

    Args:
        args: namespace with ``dataset`` (a path; ``~`` is expanded) and ``format``.

    Returns:
        The list of loaded data points.

    Raises:
        ValueError: if the dataset suffix is neither ``.json`` nor ``.jsonl``.
    """
    p = Path(args.dataset).expanduser()
    if p.suffix == ".json":
        with open(p, 'r', encoding='utf-8') as f:
            objs = json.load(f)
    elif p.suffix == ".jsonl":
        with open(p, 'r', encoding='utf-8') as f:
            # Blank lines (e.g. a trailing empty line) would crash json.loads.
            objs = [json.loads(line) for line in f if line.strip()]
    else:
        raise ValueError(f"Unknown data suffix: {p.suffix}")
    data = [Datum.load(obj, data_format=args.format) for obj in objs]
    return [datum for datum in data if datum]
2024-07-11 15:49:37 -07:00
# -- Main
2024-10-03 12:03:33 -07:00
def main(args):
    """
    Entry point for ``eval`` mode: run DSP over the dataset and report elapsed time.

    Steps: load the data, create the output directory, resolve the Lean project,
    build the OpenAI engine (``OPENAI_API_KEY`` must be set in the environment),
    then run the full proof search.
    """
    import datetime
    import time
    t0 = time.time()

    # Data and output location
    data_eval = load_data(args)
    path_output = Path(args.output)
    path_output.mkdir(exist_ok=True, parents=True)

    # Lean project; every call to server_func builds a new Server for it.
    project_path, lean_path = get_project_and_lean_path()

    def server_func():
        return Server(
            imports=["Mathlib", "Aesop"],
            project_path=project_path,
            lean_path=lean_path,
        )

    # TODO: wandb run tracking is disabled. The previous setup was:
    # print(f'\n\n-- Setup params')
    # CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES")
    # current_tmux_session = os.environ.get("TMUX", "").split(",")[-1]
    # today = datetime.datetime.now().strftime("%Y-m%m-d%d-t%Hh_%Mm_%Ss")
    # config = {'today': today, "CUDA_VISIBLE_DEVICES": CUDA_VISIBLE_DEVICES, "current_tmux_session": current_tmux_session, "model": model, "path_2_eval_dataset": path_2_eval_dataset}
    # project: str = 'pypantograph'
    # run_name = f"{project}: ({config})"
    # run = wandb.init(mode=mode, project=project, name=run_name, save_code=True, config=config)
    # print(f"{run.url=}")
    # print(f'\n Config: \n{config=}')

    api_key = os.environ['OPENAI_API_KEY']

    def sampling_for(stop_tokens):
        # Draft and sketch share every CLI sampling option except the stop tokens.
        return SamplingParams(
            n=args.n_samples,
            max_tokens=args.max_tokens,
            top_p=args.top_p,
            temperature=args.temperature,
            stop=stop_tokens,
        )

    eng: OpenAI_DSP_Engine = OpenAI_DSP_Engine(
        model=args.model,
        api_key=api_key,
        verbose_init=True,
        draft_sampling_params=sampling_for(STOP_TOKENS_DRAFT_V0),
        sketch_sampling_params=sampling_for(STOP_TOKENS_SKETCH_V0),
    )

    # Full proof search with DSP
    full_proof_search_dsp_lean(eng, server_func, data_eval, path_output)
    elapsed = datetime.timedelta(seconds=time.time() - t0)
    print(colored(f"Time elapsed: {elapsed}", "magenta"))

    # TODO: finish the wandb run once tracking is re-enabled:
    # wandb.config.update(config)
    # print(f"{wandb.config=}")
    # run.finish()
2024-07-11 15:49:37 -07:00
def stat(args):
    """
    Summarise an ``eval`` run by counting proved problems in its result files.

    For each datum loaded from ``args.dataset``, this reads
    ``Path(args.output) / f"{i:03}.json"``. Data points without a result file
    (e.g. from an unfinished run) are counted as missing instead of crashing. If a
    file's ``name`` differs from ``str(datum)``, a warning is printed and the
    function returns without a summary.
    """
    path_output = Path(args.output)
    data = load_data(args)
    n_success = 0
    n_tried = 0
    n_missing = 0
    for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
        file_name = path_output / f"{i:03}.json"
        key = str(datum)
        # Detect if file exists: data points not evaluated yet have no result file.
        if not file_name.is_file():
            n_missing += 1
            continue
        with open(file_name, "r") as f:
            obj = json.load(f)
        if obj['name'] != key:
            print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong"))
            return
        n_tried += 1
        if obj['success']:
            n_success += 1
    print(f"Proved {n_success}/{n_tried} problems")
    if n_missing:
        print(f"Missing results for {n_missing} problems")
2024-07-11 15:49:37 -07:00
if __name__ == "__main__":
2024-10-03 12:03:33 -07:00
import argparse
parser = argparse.ArgumentParser(
prog='DSP',
2024-10-03 12:03:33 -07:00
description="Draft-Sketch-Prove on Lean code",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
'mode',
help="Function",
choices=['eval', 'prompts', 'stat'],
)
parser.add_argument(
2024-10-04 21:55:47 -07:00
"--dataset",
help="Evaluation dataset path",
default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
)
parser.add_argument(
"--output",
help="Result directory",
default=experiment_dir / 'result',
)
parser.add_argument(
"--model",
help="Model",
default="gpt-4o",
choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"],
)
2024-10-04 21:55:47 -07:00
parser.add_argument(
"--format",
help="Data format",
default="default",
choices=["default", "minif2f"],
)
parser.add_argument("--start", default=0)
parser.add_argument("--end", default=sys.maxsize)
2024-10-03 12:03:33 -07:00
parser.add_argument("--batchsize", default=10, help="putnam has 348")
parser.add_argument("--n-samples", default=1, help="num seqs to return for given prompt")
parser.add_argument("--max-tokens", default=2048, help="Maximum number of tokens in one sample")
parser.add_argument("--top-p", default=0.95, help="Sampling top p")
parser.add_argument("--temperature", default=0.8, help="Sampling temperature")
parser.add_argument("--verbose", action='store_true')
args = parser.parse_args()
if args.mode == "eval":
2024-10-03 12:03:33 -07:00
main(args)
elif args.mode == 'stat':
stat(args)
elif args.mode == "prompts":
prompt = get_prompt_sketch_template_4_lean_v0(verbose=args.verbose)
print(prompt)
else:
raise ValueError(f"Unknown mode: {args.mode}")