import sys, os, json
from dataclasses import dataclass
from pathlib import Path
from typing import Union, Any
from collections import namedtuple

import fire
from tqdm import tqdm
from openai import OpenAI
import wandb
from tenacity import retry, stop_after_attempt, wait_exponential

from pantograph import Server

from solve.dsp_lean_prompts import SYSTEM_PROMPT_DRAFT_V0, prompt_draft_template_lean4_v0, STOP_TOKENS_DRAFT_V0
from solve.dsp_lean_prompts import SYSTEM_PROMPT_SKETCH_V0, prompt_sketch_template_lean4_v0, STOP_TOKENS_SKETCH_V0

# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"

@dataclass
class SamplingParams:
    n: int
    max_tokens: int
    top_p: float
    temperature: float
    stop: list[str]
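
# Example (values mirror the defaults used in main() below):
#   draft_params = SamplingParams(n=1, max_tokens=2048, top_p=0.95, temperature=0.8, stop=STOP_TOKENS_DRAFT_V0)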

class Engine:
    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):
        pass

class OpenAI_DSP_Engine(Engine):
    def __init__(
        self,
        model: str,
        api_key: str = None,
        base_url: str = None,  # e.g., Mistral-7B-Instruct-v0.2 served at http://120.77.8.29:12345
        # Draft params
        draft_system_prompt: str = SYSTEM_PROMPT_DRAFT_V0,  # e.g., 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.' (goal: draft)
        draft_prompt_template: str = prompt_draft_template_lean4_v0,
        draft_sampling_params=None,
        draft_stop_tokens: list[str] = STOP_TOKENS_DRAFT_V0,
        # Sketch params
        sketch_system_prompt: str = SYSTEM_PROMPT_SKETCH_V0,
        sketch_prompt_template: str = prompt_sketch_template_lean4_v0,
        sketch_sampling_params=None,
        sketch_stop_tokens: list[str] = STOP_TOKENS_SKETCH_V0,
        # Prove params
        # ...TODO: not sure if needed right now...
        # Misc
        verbose_init: bool = True,
    ):
        super().__init__()
        if verbose_init:
            print(f'{api_key=}, {base_url=}')

        if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model):
            raise ValueError(f"Model {model=} not supported.")

        self.model = model
        self.api_key = api_key
        self.llm = OpenAI(api_key=self.api_key, base_url=base_url)

        # Draft params
        self.draft_system_prompt = draft_system_prompt
        self.draft_prompt_template = draft_prompt_template
        self.draft_sampling_params = draft_sampling_params
        # self.draft_sampling_params.stop = draft_stop_tokens

        # Sketch params
        self.sketch_system_prompt = sketch_system_prompt
        self.sketch_prompt_template = sketch_prompt_template
        self.sketch_sampling_params = sketch_sampling_params
        # self.sketch_sampling_params.stop = sketch_stop_tokens

        # Prove params
        # ...TODO: not sure if needed right now...

@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def autoformalize_prob(
    eng,
    data_pt: dict,
    verbose: bool = False,
):
    """Autoformalize a natural-language problem into a formal-language problem."""
    ...

@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def draft(
    eng,
    data_pt: dict,
    verbose: bool = False,
) -> list:
    """
    Create an informal (NL) draft (NL solution, NL proof sketch) for later use in a formal proof sketch.
        y_pred_nl ~ draft(eng, x_nl_prob, P_draft)
    """
    # Make prompt from template
    nl_problem: str = data_pt['nl_problem'][0]
    prompt = eng.draft_prompt_template.replace('{nl_problem}', nl_problem)
    # Get all **completions** for a single prompt: one (in) -> many (out)
    # ref: https://platform.openai.com/docs/api-reference/chat/object
    response: Any = eng.llm.chat.completions.create(
        model=eng.model,
        messages=[
            {"role": "system", "content": eng.draft_system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=eng.draft_sampling_params.temperature,
        top_p=eng.draft_sampling_params.top_p,
        n=eng.draft_sampling_params.n,
        stop=eng.draft_sampling_params.stop[:3],  # the OpenAI API accepts at most 4 stop sequences
    )
    # Collect all completions for the single prompt
    completions: list[str] = [completion.message.content for completion in response.choices]  # response.choices[i].message
    drafts: list[str] = completions
    return drafts
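
# Example usage (hypothetical data point; 'nl_problem' holds a single-element list):
#   drafts = draft(eng, {'nl_problem': ['Prove that for any natural number n, n + 0 = n.']})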

@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def sketch(
    eng,
    data_pt: dict,
    drafts: list,
    autoformalize_prob_in_prompt: bool = False,
    verbose: bool = False,
) -> tuple:
    """
    Create a formal (FL) sketch (FL proof sketch) for later use in a formal proof.
        z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch)
    """
    assert len(drafts) == 1, "For now, only 1 draft is supported."
    # Make prompt from template
    x_nl_problem: str = data_pt['nl_problem'][0]
    y_nl_solution: str = drafts[0]
    x_fl_problem = None
    if autoformalize_prob_in_prompt:
        # prompt = eng.sketch_prompt_template.replace('{nl_problem}', x_nl_problem).replace('{nl_solution}', y_nl_solution)
        raise NotImplementedError("Autoformalizing the problem inside the prompt is not implemented yet.")
    else:
        x_fl_problem = data_pt['fl_problem'][0] if 'fl_problem' in data_pt else autoformalize_prob(eng, data_pt)
        # NOTE: the original code replaced '{fl_problem}' twice (the second call was a no-op);
        # the '{nl_solution}' placeholder below is assumed from the commented-out line above.
        prompt = eng.sketch_prompt_template.replace('{fl_problem}', x_nl_problem).replace('{nl_solution}', y_nl_solution)
    # Get all **completions** for a single prompt: one (in) -> many (out)
    # ref: https://platform.openai.com/docs/api-reference/chat/object
    response: Any = eng.llm.chat.completions.create(
        model=eng.model,
        messages=[
            {"role": "system", "content": eng.sketch_system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=eng.sketch_sampling_params.temperature,
        top_p=eng.sketch_sampling_params.top_p,
        n=eng.sketch_sampling_params.n,
        # stop=eng.sketch_sampling_params.stop[:3],
    )
    # Collect all completions for the single prompt
    completions: list[str] = [completion.message.content for completion in response.choices]  # response.choices[i].message
    sketches: list[str] = completions
    # Return
    return sketches, x_fl_problem
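
# Example usage (hypothetical data point): sketches, fl_problem = sketch(eng, data_pt, drafts)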

def prove(
    eng,
    fl_prob: str,
    fl_sketch: list[str],
):
    """
    Complete the formal sketch and check whether it proves the theorem.
        fl_prob   --> Lean 4 theorem (problem)
        fl_sketch --> Lean 4 formal sketch --> have x ..., have ha ...
    """
    print(f"fl_prob={fl_prob}")
    print(f"fl_sketch={fl_sketch}")
    raise RuntimeError("Not implemented")
    # -- Prove (unreachable until the proving step above is implemented)
    correct: bool = False
    # -- Return
    return correct
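
# A possible way to finish `prove` with Pantograph, kept as a comment because it is
# only a sketch: it assumes `Server.load_sorry` returns compilation units whose
# `goal_state` is None once no goals remain, and `server` would have to be passed
# into `prove` (single_proof_search_dsp_lean already receives it but does not use it):
#   units = server.load_sorry('\n'.join(fl_sketch))
#   correct = all(unit.goal_state is None for unit in units)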

# -- DSP for Lean
def single_proof_search_dsp_lean(
    eng: Engine,
    server: Server,
    data_pt: dict,
) -> bool:
    # -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
    y_nl_pred_drafts = draft(eng, data_pt)
    # -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
    z_fl_pred_sketches, x_fl_prob = sketch(eng, data_pt, y_nl_pred_drafts)
    # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
    correct: bool = prove(eng, x_fl_prob, z_fl_pred_sketches)
    # -- Return
    return correct

def full_proof_search_dsp_lean(
    eng: Engine,
    server: Server,
    path_2_eval_dataset: Path,
):
    # -- Get eval data
    with open(path_2_eval_dataset, 'r') as f:
        eval_dataset: list[dict] = json.load(f)
    print(f'{len(eval_dataset)=}')
    # -- Proof search by DSP over all eval data
    for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
        print(f'{data_pt=}')
        flag = single_proof_search_dsp_lean(eng, server, data_pt)
    return

experiment_dir = Path(__file__).resolve().parent

# -- Main
def main(
    path_2_eval_dataset: str = experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
    # model: str = 'deepseek-ai/deepseek-math-7b-instruct',
    # model: str = 'gpt2',
    # model: str = 'gpt-3.5-turbo',
    model: str = 'gpt-4o',
    start: int = 0,
    end: int = sys.maxsize,
    # end: int = 10,  # do 10 so that enough boxed questions are present
    batch_size: int = 10,  # Putnam has 348 problems
    n_samples: int = 1,  # number of sequences to return for a given prompt
    max_tokens: int = 2048,
    top_p: float = 0.95,
    temperature: float = 0.8,
    mode: str = "dryrun",
):
    path_2_eval_dataset = Path(path_2_eval_dataset).expanduser()
    print(f'{path_2_eval_dataset=}')

    server = Server()

    # - Start wandb run
    # print(f'\n\n-- Setup params')
    # CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES")
    # current_tmux_session = os.environ.get("TMUX", "").split(",")[-1]
    # today = datetime.datetime.now().strftime("%Y-m%m-d%d-t%Hh_%Mm_%Ss")
    # config = {'today': today, "CUDA_VISIBLE_DEVICES": CUDA_VISIBLE_DEVICES, "current_tmux_session": current_tmux_session, "model": model, "path_2_eval_dataset": path_2_eval_dataset}
    # project: str = 'pypantograph'
    # run_name = f"{project}: ({config})"
    # run = wandb.init(mode=mode, project=project, name=run_name, save_code=True, config=config)
    # print(f"{run.url=}")
    # print(f'\n Config: \n{config=}')

    # - Run DSP for Lean
    api_key = os.environ['OPENAI_API_KEY']
    draft_sampling_params = SamplingParams(n=n_samples, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_DRAFT_V0)
    sketch_sampling_params = SamplingParams(n=n_samples, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stop=STOP_TOKENS_SKETCH_V0)
    eng: OpenAI_DSP_Engine = OpenAI_DSP_Engine(
        model=model,
        api_key=api_key,
        verbose_init=True,
        draft_sampling_params=draft_sampling_params,
        sketch_sampling_params=sketch_sampling_params,
    )

    # - Full proof search with DSP
    print('\n\n-- Full proof search with DSP')
    full_proof_search_dsp_lean(eng, server, path_2_eval_dataset)

    # - End run
    # wandb.config.update(config)
    # print(f"{wandb.config=}")
    # run.finish()
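
# Example CLI invocation via `fire` (flags mirror main()'s parameters; the script
# name is a placeholder for this file):
#   python dsp_lean.py --model gpt-4o --mode dryrun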

if __name__ == "__main__":
    import time
    start_time = time.time()
    fire.Fire(main)
    elapsed = time.time() - start_time
    print(f"Time taken: {elapsed:.2f} seconds, or {elapsed / 60:.2f} minutes, or {elapsed / 3600:.2f} hours.\a")