refactor: Clarify code in dsp
This commit is contained in:
parent
e942359666
commit
ce2d689b03
|
@ -1,4 +1,5 @@
|
||||||
import sys, os, json
|
import sys, os, json
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, Any
|
from typing import Union, Any
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
@ -7,12 +8,21 @@ from tqdm import tqdm
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import wandb
|
import wandb
|
||||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||||
|
from pantograph import Server
|
||||||
|
|
||||||
from solve.dsp_lean_prompts import SYSTEM_PROMPT_DRAFT_V0, prompt_draft_template_lean4_v0, STOP_TOKENS_DRAFT_V0
|
from solve.dsp_lean_prompts import SYSTEM_PROMPT_DRAFT_V0, prompt_draft_template_lean4_v0, STOP_TOKENS_DRAFT_V0
|
||||||
from solve.dsp_lean_prompts import SYSTEM_PROMPT_SKETCH_V0, prompt_sketch_template_lean4_v0, STOP_TOKENS_SKETCH_V0
|
from solve.dsp_lean_prompts import SYSTEM_PROMPT_SKETCH_V0, prompt_sketch_template_lean4_v0, STOP_TOKENS_SKETCH_V0
|
||||||
|
|
||||||
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
|
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
|
||||||
|
|
||||||
|
@dataclass
class SamplingParams:
    """Sampling settings for one LLM stage (draft or sketch).

    Instances are built with keyword arguments and read back by attribute,
    e.g. ``eng.sketch_sampling_params.top_p``.
    """
    # Number of sequences to return for a given prompt.
    n: int
    # Upper bound on the number of tokens generated per completion.
    max_tokens: int
    # Nucleus-sampling mass; callers pass fractions such as 0.95, so this is
    # a float (it was previously mis-annotated as int).
    top_p: float
    # Sampling temperature.
    temperature: float
    # Stop sequence(s) for generation.
    # NOTE(review): callers pass STOP_TOKENS_*_V0, presumably a list of
    # strings -- confirm against solve.dsp_lean_prompts.
    stop: Union[str, list[str]]
|
||||||
|
|
||||||
class Engine:
|
class Engine:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
@ -22,10 +32,10 @@ class Engine:
|
||||||
|
|
||||||
class OpenAI_DSP_Engine(Engine):
|
class OpenAI_DSP_Engine(Engine):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
base_url: str = None, # e.g., Mistral-7B-Instrcut-v0.2 on http://120.77.8.29:12345
|
base_url: str = None, # e.g., Mistral-7B-Instrcut-v0.2 on http://120.77.8.29:12345
|
||||||
# Draft Params
|
# Draft Params
|
||||||
draft_system_prompt: str = SYSTEM_PROMPT_DRAFT_V0, # 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.' (goal do draft)
|
draft_system_prompt: str = SYSTEM_PROMPT_DRAFT_V0, # 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.' (goal do draft)
|
||||||
draft_prompt_template: str = prompt_draft_template_lean4_v0,
|
draft_prompt_template: str = prompt_draft_template_lean4_v0,
|
||||||
|
@ -43,9 +53,13 @@ class OpenAI_DSP_Engine(Engine):
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
print(f'{api_key=}, {base_url=}') if verbose_init else None
|
print(f'{api_key=}, {base_url=}') if verbose_init else None
|
||||||
|
|
||||||
|
|
||||||
|
if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model):
|
||||||
|
raise ValueError(f"Model {model=} not supported.")
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.llm = OpenAI(api_key=self.api_key, base_url=base_url)
|
self.llm = OpenAI(api_key=self.api_key, base_url=base_url)
|
||||||
# Draft params
|
# Draft params
|
||||||
self.draft_system_prompt = draft_system_prompt
|
self.draft_system_prompt = draft_system_prompt
|
||||||
self.draft_prompt_template = draft_prompt_template
|
self.draft_prompt_template = draft_prompt_template
|
||||||
|
@ -61,7 +75,7 @@ class OpenAI_DSP_Engine(Engine):
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
||||||
def autoformalize_prob(
|
def autoformalize_prob(
|
||||||
eng,
|
eng,
|
||||||
data_pt: dict,
|
data_pt: dict,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -70,13 +84,13 @@ def autoformalize_prob(
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
||||||
def draft(
|
def draft(
|
||||||
eng,
|
eng,
|
||||||
data_pt: dict,
|
data_pt: dict,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> list:
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Creates (informal nl) draft (nl soln, nl proof sketch) for latter use in a formal proof sketch.
|
Creates (informal nl) draft (nl soln, nl proof sketch) for latter use in a formal proof sketch.
|
||||||
y_pred_nl ~ draft(eng, x_nl_prob, P_draft)
|
y_pred_nl ~ draft(eng, x_nl_prob, P_draft)
|
||||||
"""
|
"""
|
||||||
# Make prompt from template
|
# Make prompt from template
|
||||||
nl_problem: str = data_pt['nl_problem'][0]
|
nl_problem: str = data_pt['nl_problem'][0]
|
||||||
|
@ -98,7 +112,7 @@ def draft(
|
||||||
completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message
|
completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message
|
||||||
drafts: list[str] = completions
|
drafts: list[str] = completions
|
||||||
return drafts
|
return drafts
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
||||||
def sketch(
|
def sketch(
|
||||||
eng,
|
eng,
|
||||||
|
@ -107,10 +121,10 @@ def sketch(
|
||||||
autoformalize_prob_in_prompt: bool = False,
|
autoformalize_prob_in_prompt: bool = False,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> list:
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Creates (formal fl) sketch (fl proof sketch) for latter use in a formal proof sketch.
|
Creates (formal fl) sketch (fl proof sketch) for latter use in a formal proof sketch.
|
||||||
z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch)
|
z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch)
|
||||||
"""
|
"""
|
||||||
assert len(drafts) == 1, f"For now only 1 draft."
|
assert len(drafts) == 1, f"For now only 1 draft."
|
||||||
# Make prompt from template
|
# Make prompt from template
|
||||||
x_nl_problem: str = data_pt['nl_problem'][0]
|
x_nl_problem: str = data_pt['nl_problem'][0]
|
||||||
|
@ -133,86 +147,92 @@ def sketch(
|
||||||
top_p=eng.sketch_sampling_params.top_p,
|
top_p=eng.sketch_sampling_params.top_p,
|
||||||
n=eng.sketch_sampling_params.n,
|
n=eng.sketch_sampling_params.n,
|
||||||
# stop=eng.sketch_sampling_params.stop[:3],
|
# stop=eng.sketch_sampling_params.stop[:3],
|
||||||
)
|
)
|
||||||
# Get all completions for single prompt
|
# Get all completions for single prompt
|
||||||
completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message
|
completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message
|
||||||
sketches: list[str] = completions
|
sketches: list[str] = completions
|
||||||
# Return
|
# Return
|
||||||
return sketches, x_fl_problem
|
return sketches, x_fl_problem
|
||||||
|
|
||||||
def prove(
    eng,
    fl_prob: str,
    fl_sketch: list[str],
) -> bool:
    """
    Complete formal sketch and check if it proves the theorem.

    fl_prob --> Lean4 theorem (problem)
    fl_sketch --> Lean4 Form Sketch --> have x have ha

    The proving step is still a stub: the inputs are printed for debugging
    and the function then raises. Callers expect a bool (True when the
    completed sketch proves the theorem) once this is implemented.

    Raises:
        NotImplementedError: always. It subclasses RuntimeError, so code
            that already catches RuntimeError keeps working.
    """
    print(f"fl_prob={fl_prob}")
    print(f"fl_sketch={fl_sketch}")
    # NotImplementedError is the conventional signal for an unfinished stub;
    # the statements that used to follow the raise were unreachable.
    raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
# -- DSP for Lean
|
# -- DSP for Lean
|
||||||
|
|
||||||
def single_proof_search_dsp_lean(
    eng: Engine,
    server: Server,
    data_pt: dict,
) -> bool:
    """Run one Draft -> Sketch -> Prove pass for a single data point.

    Args:
        eng: engine handed to each of the three stages.
        server: accepted for the caller's benefit; not used in this body.
        data_pt: one benchmark entry, forwarded to ``draft`` and ``sketch``.

    Returns:
        Whatever ``prove`` returns for the sketched problem.
    """
    # -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
    nl_drafts = draft(eng, data_pt)

    # -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
    sketch_output = sketch(eng, data_pt, nl_drafts)
    fl_sketches, fl_problem = sketch_output

    # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
    return prove(eng, fl_problem, fl_sketches)
|
||||||
|
|
||||||
|
|
||||||
def full_proof_search_dsp_lean(
    eng: Engine,
    server: Server,
    path_2_eval_dataset: Union[str, Path],
):
    """Run the DSP proof search over every data point of an eval dataset.

    Args:
        eng: engine forwarded to ``single_proof_search_dsp_lean``.
        server: forwarded unchanged to ``single_proof_search_dsp_lean``.
        path_2_eval_dataset: JSON file holding a list of data-point dicts.
            Accepts ``str`` or ``Path``; a leading ``~`` is expanded.

    Returns:
        None. Per-data-point results are not collected.
    """
    # -- Get eval data
    # Normalise here as well, so direct callers may pass a plain or '~' path.
    path_2_eval_dataset = Path(path_2_eval_dataset).expanduser()
    # Context manager closes the handle even if parsing fails. UTF-8 is given
    # explicitly because JSON is UTF-8 by spec and the locale default may not be.
    with open(path_2_eval_dataset, 'r', encoding='utf-8') as f:
        eval_dataset: list[dict] = json.load(f)
    print(f'{len(eval_dataset)=}')
    # -- Proof search by DSP over all eval data
    for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
        print(f'{data_pt=}')
        # The result is intentionally not bound: nothing here consumes it.
        single_proof_search_dsp_lean(eng, server, data_pt)
    return
|
||||||
|
|
||||||
|
experiment_dir = Path(__file__).resolve().parent
|
||||||
|
|
||||||
# -- Main
|
# -- Main
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
path_2_eval_dataset: str = '~/PyPantograph/examples/lean4_dsp/debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
|
path_2_eval_dataset: str = experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
|
||||||
# model: str = 'deepseek-ai/deepseek-math-7b-instruct',
|
# model: str = 'deepseek-ai/deepseek-math-7b-instruct',
|
||||||
# model: str = 'gpt2',
|
# model: str = 'gpt2',
|
||||||
# model: str = 'gpt-3.5-turbo',
|
# model: str = 'gpt-3.5-turbo',
|
||||||
model: str = 'gpt-4o',
|
model: str = 'gpt-4o',
|
||||||
start: int = 0,
|
start: int = 0,
|
||||||
end: int = sys.maxsize,
|
end: int = sys.maxsize,
|
||||||
# end: int = 10, # do 10 so enough boxed qs are there
|
# end: int = 10, # do 10 so enough boxed qs are there
|
||||||
batch_size: int = 10, # putnam has 348
|
batch_size: int = 10, # putnam has 348
|
||||||
n: int = 1, # num seqs to return for given prompt
|
n_samples: int = 1, # num seqs to return for given prompt
|
||||||
max_tokens: int = 2048,
|
max_tokens: int = 2048,
|
||||||
top_p: float = 0.95,
|
top_p: float = 0.95,
|
||||||
temperature: float = 0.8,
|
temperature: float = 0.8,
|
||||||
mode: str = "dryrun",
|
mode: str = "dryrun",
|
||||||
):
|
):
|
||||||
path_2_eval_dataset = Path(path_2_eval_dataset).expanduser()
|
path_2_eval_dataset = Path(path_2_eval_dataset).expanduser()
|
||||||
print(f'{path_2_eval_dataset=}')
|
print(f'{path_2_eval_dataset=}')
|
||||||
|
|
||||||
|
server = Server()
|
||||||
|
|
||||||
# - Start wandb run
|
# - Start wandb run
|
||||||
# print(f'\n\n-- Setup params')
|
# print(f'\n\n-- Setup params')
|
||||||
# CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES")
|
# CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES")
|
||||||
|
@ -226,18 +246,20 @@ def main(
|
||||||
# print(f'\n Config: \n{config=}')
|
# print(f'\n Config: \n{config=}')
|
||||||
|
|
||||||
# - Run DSP for Lean
|
# - Run DSP for Lean
|
||||||
if 'gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model:
|
api_key = os.environ['OPENAI_API_KEY']
|
||||||
api_key = os.environ['OPENAI_API_KEY']
|
draft_sampling_params = SamplingParams(n=n_samples, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_DRAFT_V0)
|
||||||
SamplingParams = namedtuple('SamplingParams', ['n', 'max_tokens', 'top_p', 'temperature', 'stop'])
|
sketch_sampling_params = SamplingParams(n=n_samples, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_SKETCH_V0)
|
||||||
draft_sampling_params = SamplingParams(n=n, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_DRAFT_V0)
|
eng: OpenAI_DSP_Engine = OpenAI_DSP_Engine(
|
||||||
sketch_sampling_params = SamplingParams(n=n, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_SKETCH_V0)
|
model=model,
|
||||||
eng: OpenAI_DSP_Engine = OpenAI_DSP_Engine(model=model, api_key=api_key, verbose_init=True, draft_sampling_params=draft_sampling_params, sketch_sampling_params=sketch_sampling_params)
|
api_key=api_key,
|
||||||
else:
|
verbose_init=True,
|
||||||
raise ValueError(f"Model {model=} not supported.")
|
draft_sampling_params=draft_sampling_params,
|
||||||
|
sketch_sampling_params=sketch_sampling_params,
|
||||||
|
)
|
||||||
|
|
||||||
# - Full proof search with DSP
|
# - Full proof search with DSP
|
||||||
print(f'\n\n-- Full proof search with DSP')
|
print(f'\n\n-- Full proof search with DSP')
|
||||||
full_proof_search_dsp_lean(eng, path_2_eval_dataset)
|
full_proof_search_dsp_lean(eng, server, path_2_eval_dataset)
|
||||||
|
|
||||||
# - End run
|
# - End run
|
||||||
# wandb.config.update(config)
|
# wandb.config.update(config)
|
||||||
|
|
Loading…
Reference in New Issue