refactor: Clarify code in dsp

This commit is contained in:
Leni Aniva 2024-10-02 11:03:00 -07:00
parent e942359666
commit ce2d689b03
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
1 changed file with 70 additions and 48 deletions

View File

@ -1,4 +1,5 @@
import sys, os, json import sys, os, json
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Union, Any from typing import Union, Any
from collections import namedtuple from collections import namedtuple
@ -7,12 +8,21 @@ from tqdm import tqdm
from openai import OpenAI from openai import OpenAI
import wandb import wandb
from tenacity import retry, stop_after_attempt, wait_exponential from tenacity import retry, stop_after_attempt, wait_exponential
from pantograph import Server
from solve.dsp_lean_prompts import SYSTEM_PROMPT_DRAFT_V0, prompt_draft_template_lean4_v0, STOP_TOKENS_DRAFT_V0 from solve.dsp_lean_prompts import SYSTEM_PROMPT_DRAFT_V0, prompt_draft_template_lean4_v0, STOP_TOKENS_DRAFT_V0
from solve.dsp_lean_prompts import SYSTEM_PROMPT_SKETCH_V0, prompt_sketch_template_lean4_v0, STOP_TOKENS_SKETCH_V0 from solve.dsp_lean_prompts import SYSTEM_PROMPT_SKETCH_V0, prompt_sketch_template_lean4_v0, STOP_TOKENS_SKETCH_V0
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n" # prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. 
Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
@dataclass
class SamplingParams:
    """
    Sampling settings for one DSP stage (draft or sketch).

    One instance is built per stage in `main` and handed to the engine.
    """
    # Number of completions to return for a given prompt.
    n: int
    # Upper bound on generated tokens per completion.
    max_tokens: int
    # Nucleus-sampling mass; `main` passes a float (default 0.95), so this is
    # a float, not an int.
    top_p: float
    # Sampling temperature.
    temperature: float
    # Stop sequence(s). `main` passes the STOP_TOKENS_* constants here.
    # NOTE(review): those look like lists of strings (the sketch code slices
    # them with `[:3]`) - confirm against solve.dsp_lean_prompts.
    stop: Union[str, list[str]]
class Engine: class Engine:
def __init__(self): def __init__(self):
pass pass
@ -22,10 +32,10 @@ class Engine:
class OpenAI_DSP_Engine(Engine): class OpenAI_DSP_Engine(Engine):
def __init__( def __init__(
self, self,
model: str, model: str,
api_key: str = None, api_key: str = None,
base_url: str = None, # e.g., Mistral-7B-Instrcut-v0.2 on http://120.77.8.29:12345 base_url: str = None, # e.g., Mistral-7B-Instrcut-v0.2 on http://120.77.8.29:12345
# Draft Params # Draft Params
draft_system_prompt: str = SYSTEM_PROMPT_DRAFT_V0, # 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.' (goal do draft) draft_system_prompt: str = SYSTEM_PROMPT_DRAFT_V0, # 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.' (goal do draft)
draft_prompt_template: str = prompt_draft_template_lean4_v0, draft_prompt_template: str = prompt_draft_template_lean4_v0,
@ -43,9 +53,13 @@ class OpenAI_DSP_Engine(Engine):
): ):
super().__init__() super().__init__()
print(f'{api_key=}, {base_url=}') if verbose_init else None print(f'{api_key=}, {base_url=}') if verbose_init else None
if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model):
raise ValueError(f"Model {model=} not supported.")
self.model = model self.model = model
self.api_key = api_key self.api_key = api_key
self.llm = OpenAI(api_key=self.api_key, base_url=base_url) self.llm = OpenAI(api_key=self.api_key, base_url=base_url)
# Draft params # Draft params
self.draft_system_prompt = draft_system_prompt self.draft_system_prompt = draft_system_prompt
self.draft_prompt_template = draft_prompt_template self.draft_prompt_template = draft_prompt_template
@ -61,7 +75,7 @@ class OpenAI_DSP_Engine(Engine):
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128)) @retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def autoformalize_prob( def autoformalize_prob(
eng, eng,
data_pt: dict, data_pt: dict,
verbose: bool = False, verbose: bool = False,
): ):
@ -70,13 +84,13 @@ def autoformalize_prob(
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128)) @retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def draft( def draft(
eng, eng,
data_pt: dict, data_pt: dict,
verbose: bool = False, verbose: bool = False,
) -> list: ) -> list:
""" """
Creates (informal nl) draft (nl soln, nl proof sketch) for latter use in a formal proof sketch. Creates (informal nl) draft (nl soln, nl proof sketch) for latter use in a formal proof sketch.
y_pred_nl ~ draft(eng, x_nl_prob, P_draft) y_pred_nl ~ draft(eng, x_nl_prob, P_draft)
""" """
# Make prompt from template # Make prompt from template
nl_problem: str = data_pt['nl_problem'][0] nl_problem: str = data_pt['nl_problem'][0]
@ -98,7 +112,7 @@ def draft(
completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message
drafts: list[str] = completions drafts: list[str] = completions
return drafts return drafts
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128)) @retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def sketch( def sketch(
eng, eng,
@ -107,10 +121,10 @@ def sketch(
autoformalize_prob_in_prompt: bool = False, autoformalize_prob_in_prompt: bool = False,
verbose: bool = False, verbose: bool = False,
) -> list: ) -> list:
""" """
Creates (formal fl) sketch (fl proof sketch) for latter use in a formal proof sketch. Creates (formal fl) sketch (fl proof sketch) for latter use in a formal proof sketch.
z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch) z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch)
""" """
assert len(drafts) == 1, f"For now only 1 draft." assert len(drafts) == 1, f"For now only 1 draft."
# Make prompt from template # Make prompt from template
x_nl_problem: str = data_pt['nl_problem'][0] x_nl_problem: str = data_pt['nl_problem'][0]
@ -133,86 +147,92 @@ def sketch(
top_p=eng.sketch_sampling_params.top_p, top_p=eng.sketch_sampling_params.top_p,
n=eng.sketch_sampling_params.n, n=eng.sketch_sampling_params.n,
# stop=eng.sketch_sampling_params.stop[:3], # stop=eng.sketch_sampling_params.stop[:3],
) )
# Get all completions for single prompt # Get all completions for single prompt
completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message
sketches: list[str] = completions sketches: list[str] = completions
# Return # Return
return sketches, x_fl_problem return sketches, x_fl_problem
def prove(
    eng,
    fl_prob: str,
    fl_sketch: list[str],
) -> bool:
    """
    Complete formal sketch and check if it proves the theorem.

    fl_prob --> Lean4 theorem (problem)
    fl_sketch --> Lean4 Form Sketch --> have x have ha

    The prover itself is not implemented yet: the inputs are printed for
    debugging and the call always raises, so no value is ever returned.

    Raises:
        NotImplementedError: always. It is a subclass of RuntimeError, so
            callers that catch RuntimeError keep working.
    """
    print(f"fl_prob={fl_prob}")
    print(f"fl_sketch={fl_sketch}")
    raise NotImplementedError("Not implemented")
# -- DSP for Lean # -- DSP for Lean
def single_proof_search_dsp_lean(
    eng: Engine,
    server: Server,
    data_pt: dict,
) -> bool:
    """
    Run one Draft -> Sketch -> Prove pass for a single data point.

    `server` is part of the signature but is not used by this body.
    Returns whatever `prove` reports for the sketched proof.
    """
    # Draft: informal (natural-language) solution candidates.
    nl_drafts = draft(eng, data_pt)
    # Sketch: formal sketches plus the formal problem statement they target.
    fl_sketches, fl_problem = sketch(eng, data_pt, nl_drafts)
    # Prove: check the sketches against the formal problem.
    return prove(eng, fl_problem, fl_sketches)
def full_proof_search_dsp_lean(
    eng: Engine,
    server: Server,
    path_2_eval_dataset: Path,
):
    """
    Run the DSP proof search over every data point of an evaluation dataset.

    The dataset file is parsed as JSON and iterated as a list of data points;
    each one is passed to `single_proof_search_dsp_lean` together with `eng`
    and `server`. Per-data-point results are not collected; returns None.
    """
    # -- Get eval data
    # Use a context manager so the file handle is closed once parsing is done.
    with open(path_2_eval_dataset, 'r') as dataset_file:
        eval_dataset: list[dict] = json.load(dataset_file)
    print(f'{len(eval_dataset)=}')
    # -- Proof search by DSP over all eval data
    data_pt: dict
    for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
        print(f'{data_pt=}')
        single_proof_search_dsp_lean(eng, server, data_pt)
    return
# Directory containing this script; the default dataset path of main() is built relative to it.
experiment_dir = Path(__file__).resolve().parent
# -- Main # -- Main
def main( def main(
path_2_eval_dataset: str = '~/PyPantograph/examples/lean4_dsp/debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json', path_2_eval_dataset: str = experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
# model: str = 'deepseek-ai/deepseek-math-7b-instruct', # model: str = 'deepseek-ai/deepseek-math-7b-instruct',
# model: str = 'gpt2', # model: str = 'gpt2',
# model: str = 'gpt-3.5-turbo', # model: str = 'gpt-3.5-turbo',
model: str = 'gpt-4o', model: str = 'gpt-4o',
start: int = 0, start: int = 0,
end: int = sys.maxsize, end: int = sys.maxsize,
# end: int = 10, # do 10 so enough boxed qs are there # end: int = 10, # do 10 so enough boxed qs are there
batch_size: int = 10, # putnam has 348 batch_size: int = 10, # putnam has 348
n: int = 1, # num seqs to return for given prompt n_samples: int = 1, # num seqs to return for given prompt
max_tokens: int = 2048, max_tokens: int = 2048,
top_p: float = 0.95, top_p: float = 0.95,
temperature: float = 0.8, temperature: float = 0.8,
mode: str = "dryrun", mode: str = "dryrun",
): ):
path_2_eval_dataset = Path(path_2_eval_dataset).expanduser() path_2_eval_dataset = Path(path_2_eval_dataset).expanduser()
print(f'{path_2_eval_dataset=}') print(f'{path_2_eval_dataset=}')
server = Server()
# - Start wandb run # - Start wandb run
# print(f'\n\n-- Setup params') # print(f'\n\n-- Setup params')
# CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES") # CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES")
@ -226,18 +246,20 @@ def main(
# print(f'\n Config: \n{config=}') # print(f'\n Config: \n{config=}')
# - Run DSP for Lean # - Run DSP for Lean
if 'gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model: api_key = os.environ['OPENAI_API_KEY']
api_key = os.environ['OPENAI_API_KEY'] draft_sampling_params = SamplingParams(n=n_samples, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_DRAFT_V0)
SamplingParams = namedtuple('SamplingParams', ['n', 'max_tokens', 'top_p', 'temperature', 'stop']) sketch_sampling_params = SamplingParams(n=n_samples, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_SKETCH_V0)
draft_sampling_params = SamplingParams(n=n, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_DRAFT_V0) eng: OpenAI_DSP_Engine = OpenAI_DSP_Engine(
sketch_sampling_params = SamplingParams(n=n, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_SKETCH_V0) model=model,
eng: OpenAI_DSP_Engine = OpenAI_DSP_Engine(model=model, api_key=api_key, verbose_init=True, draft_sampling_params=draft_sampling_params, sketch_sampling_params=sketch_sampling_params) api_key=api_key,
else: verbose_init=True,
raise ValueError(f"Model {model=} not supported.") draft_sampling_params=draft_sampling_params,
sketch_sampling_params=sketch_sampling_params,
)
# - Full proof search with DSP # - Full proof search with DSP
print(f'\n\n-- Full proof search with DSP') print(f'\n\n-- Full proof search with DSP')
full_proof_search_dsp_lean(eng, path_2_eval_dataset) full_proof_search_dsp_lean(eng, server, path_2_eval_dataset)
# - End run # - End run
# wandb.config.update(config) # wandb.config.update(config)