refactor: Clarify code in dsp

This commit is contained in:
Leni Aniva 2024-10-02 11:03:00 -07:00
parent e942359666
commit ce2d689b03
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
1 changed files with 70 additions and 48 deletions

View File

@ -1,4 +1,5 @@
import sys, os, json
from dataclasses import dataclass
from pathlib import Path
from typing import Union, Any
from collections import namedtuple
@ -7,12 +8,21 @@ from tqdm import tqdm
from openai import OpenAI
import wandb
from tenacity import retry, stop_after_attempt, wait_exponential
from pantograph import Server
from solve.dsp_lean_prompts import SYSTEM_PROMPT_DRAFT_V0, prompt_draft_template_lean4_v0, STOP_TOKENS_DRAFT_V0
from solve.dsp_lean_prompts import SYSTEM_PROMPT_SKETCH_V0, prompt_sketch_template_lean4_v0, STOP_TOKENS_SKETCH_V0
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
@dataclass
class SamplingParams:
n: int
max_tokens: int
top_p: int
temperature: float
stop: str
class Engine:
def __init__(self):
pass
@ -43,6 +53,10 @@ class OpenAI_DSP_Engine(Engine):
):
super().__init__()
print(f'{api_key=}, {base_url=}') if verbose_init else None
if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model):
raise ValueError(f"Model {model=} not supported.")
self.model = model
self.api_key = api_key
self.llm = OpenAI(api_key=self.api_key, base_url=base_url)
@ -133,7 +147,7 @@ def sketch(
top_p=eng.sketch_sampling_params.top_p,
n=eng.sketch_sampling_params.n,
# stop=eng.sketch_sampling_params.stop[:3],
)
)
# Get all completions for single prompt
completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message
sketches: list[str] = completions
@ -153,6 +167,9 @@ def prove(
fl_sketch --> Lean4 Form Sketch --> have x have ha
"""
print(f"fl_prob={fl_prob}")
print(f"fl_sketch={fl_sketch}")
raise RuntimeError("Not implemented")
# -- Prove
correct: bool = False
# -- Return
@ -162,6 +179,7 @@ def prove(
def single_proof_search_dsp_lean(
eng: Engine,
server: Server,
data_pt: dict,
) -> bool:
# -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
@ -174,28 +192,28 @@ def single_proof_search_dsp_lean(
correct: bool = prove(eng, x_fl_prob, z_fl_pred_sketches)
# -- Return
return
return correct
def full_proof_search_dsp_lean(
eng: Engine,
path_2_eval_dataset: Union[str, Path],
server: Server,
path_2_eval_dataset: Path,
):
# -- Get eval data
path_2_eval_dataset = Path(path_2_eval_dataset).expanduser()
eval_dataset: list[dict] = json.load(open(path_2_eval_dataset, 'r'))
print(f'{len(eval_dataset)=}')
# -- Proof search by DSP over all eval data
data_pt: dict
for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
print(f'{data_pt=}')
single_proof_search_dsp_lean(eng, data_pt)
flag = single_proof_search_dsp_lean(eng, server, data_pt)
return
experiment_dir = Path(__file__).resolve().parent
# -- Main
def main(
path_2_eval_dataset: str = '~/PyPantograph/examples/lean4_dsp/debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
path_2_eval_dataset: str = experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
# model: str = 'deepseek-ai/deepseek-math-7b-instruct',
# model: str = 'gpt2',
# model: str = 'gpt-3.5-turbo',
@ -204,7 +222,7 @@ def main(
end: int = sys.maxsize,
# end: int = 10, # do 10 so enough boxed qs are there
batch_size: int = 10, # putnam has 348
n: int = 1, # num seqs to return for given prompt
n_samples: int = 1, # num seqs to return for given prompt
max_tokens: int = 2048,
top_p: float = 0.95,
temperature: float = 0.8,
@ -213,6 +231,8 @@ def main(
path_2_eval_dataset = Path(path_2_eval_dataset).expanduser()
print(f'{path_2_eval_dataset=}')
server = Server()
# - Start wandb run
# print(f'\n\n-- Setup params')
# CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES")
@ -226,18 +246,20 @@ def main(
# print(f'\n Config: \n{config=}')
# - Run DSP for Lean
if 'gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model:
api_key = os.environ['OPENAI_API_KEY']
SamplingParams = namedtuple('SamplingParams', ['n', 'max_tokens', 'top_p', 'temperature', 'stop'])
draft_sampling_params = SamplingParams(n=n, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_DRAFT_V0)
sketch_sampling_params = SamplingParams(n=n, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_SKETCH_V0)
eng: OpenAI_DSP_Engine = OpenAI_DSP_Engine(model=model, api_key=api_key, verbose_init=True, draft_sampling_params=draft_sampling_params, sketch_sampling_params=sketch_sampling_params)
else:
raise ValueError(f"Model {model=} not supported.")
api_key = os.environ['OPENAI_API_KEY']
draft_sampling_params = SamplingParams(n=n_samples, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_DRAFT_V0)
sketch_sampling_params = SamplingParams(n=n_samples, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_SKETCH_V0)
eng: OpenAI_DSP_Engine = OpenAI_DSP_Engine(
model=model,
api_key=api_key,
verbose_init=True,
draft_sampling_params=draft_sampling_params,
sketch_sampling_params=sketch_sampling_params,
)
# - Full proof search with DSP
print(f'\n\n-- Full proof search with DSP')
full_proof_search_dsp_lean(eng, path_2_eval_dataset)
full_proof_search_dsp_lean(eng, server, path_2_eval_dataset)
# - End run
# wandb.config.update(config)