2024-10-03 12:03:33 -07:00
import sys , os , json
2024-10-02 11:03:00 -07:00
from dataclasses import dataclass
2024-10-01 11:34:30 -07:00
from pathlib import Path
from typing import Union , Any
2024-09-26 20:20:32 -07:00
from collections import namedtuple
2024-07-11 15:49:37 -07:00
from tqdm import tqdm
2024-09-26 20:20:32 -07:00
from openai import OpenAI
2024-07-11 15:49:37 -07:00
import wandb
from tenacity import retry , stop_after_attempt , wait_exponential
2024-10-02 11:03:00 -07:00
from pantograph import Server
2024-07-11 15:49:37 -07:00
2024-10-02 16:10:52 -07:00
from solve . prompts import (
2024-10-03 12:26:42 -07:00
extract_lean_code ,
2024-10-02 16:10:52 -07:00
SYSTEM_PROMPT_DRAFT_V0 ,
SYSTEM_PROMPT_SKETCH_V0 ,
prompt_draft_template_lean4_v0 ,
prompt_sketch_template_lean4_v0 ,
STOP_TOKENS_DRAFT_V0 ,
STOP_TOKENS_SKETCH_V0 ,
get_prompt_sketch_template_4_lean_v0 ,
)
2024-07-11 15:49:37 -07:00
2024-09-26 20:20:32 -07:00
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
2024-10-02 11:03:00 -07:00
@dataclass
class SamplingParams:
    """Sampling options for the chat-completions request of one DSP stage."""

    # Number of completions requested for a single prompt.
    n: int
    # Maximum number of tokens in one sample.
    max_tokens: int
    # Nucleus-sampling threshold; a fraction (the CLI default in this file is 0.95), hence float.
    top_p: float
    # Sampling temperature.
    temperature: float
    # Stop sequences; callers pass a list of strings and slice it (``stop[:3]``).
    stop: list[str]
2024-07-11 15:49:37 -07:00
class Engine:
    """Base class for DSP engines; holds no state and does nothing when called."""

    def __init__(self):
        """Create the engine; the base class stores nothing."""
        return None

    def __call__(self, *args, **kwargs):
        """Accept arbitrary arguments, ignore them, and return None."""
        return None
class OpenAI_DSP_Engine(Engine):
    """DSP engine that talks to an OpenAI-style chat-completions client.

    Bundles the client with the system prompt, prompt template and sampling
    parameters used by the draft stage and by the sketch stage.
    """

    def __init__(
        self,
        model: str,
        api_key: Union[str, None] = None,
        base_url: Union[str, None] = None,  # e.g. the URL of a self-hosted server; None uses the client's default
        # Draft Params
        draft_system_prompt: str = SYSTEM_PROMPT_DRAFT_V0,
        draft_prompt_template: str = prompt_draft_template_lean4_v0,
        draft_sampling_params: Union["SamplingParams", None] = None,
        draft_stop_tokens: list[str] = STOP_TOKENS_DRAFT_V0,
        # Sketch Params
        sketch_system_prompt: str = SYSTEM_PROMPT_SKETCH_V0,
        sketch_prompt_template: str = prompt_sketch_template_lean4_v0,
        sketch_sampling_params: Union["SamplingParams", None] = None,
        sketch_stop_tokens: list[str] = STOP_TOKENS_SKETCH_V0,
        # Prove Params
        # ...TODO not sure if needed right now...
        # Misc
        verbose_init: bool = True,
    ):
        """Validate the model name, build the client and store the stage settings.

        Args:
            model: Model identifier; must contain 'gpt-4-', 'gpt-3.5-' or 'gpt-4o'.
            api_key: API key handed to the ``OpenAI`` client.
            base_url: Base URL handed to the ``OpenAI`` client.
            draft_system_prompt: System message for the draft stage.
            draft_prompt_template: User-prompt template for the draft stage.
            draft_sampling_params: Sampling parameters for the draft stage.
            draft_stop_tokens: Accepted but currently NOT stored or applied;
                the stop sequences in effect are those inside
                ``draft_sampling_params``.
            sketch_system_prompt: System message for the sketch stage.
            sketch_prompt_template: User-prompt template for the sketch stage.
            sketch_sampling_params: Sampling parameters for the sketch stage.
            sketch_stop_tokens: Accepted but currently NOT stored or applied.
            verbose_init: When true, print ``base_url`` during construction.

        Raises:
            ValueError: If ``model`` matches none of the supported name markers.
        """
        super().__init__()
        if verbose_init:
            print(f'{base_url=}')

        supported_markers = ('gpt-4-', 'gpt-3.5-', 'gpt-4o')
        if not any(marker in model for marker in supported_markers):
            raise ValueError(f"Model {model=} not supported.")

        self.model = model
        self.api_key = api_key
        self.llm = OpenAI(api_key=self.api_key, base_url=base_url)

        # Draft params
        self.draft_system_prompt = draft_system_prompt
        self.draft_prompt_template = draft_prompt_template
        self.draft_sampling_params = draft_sampling_params
        # Sketch params
        self.sketch_system_prompt = sketch_system_prompt
        self.sketch_prompt_template = sketch_prompt_template
        self.sketch_sampling_params = sketch_sampling_params
        # Prove params
        # ...TODO not sure if needed right now...
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def autoformalize_prob(
    eng,
    data_pt: dict,
    verbose: bool = False,
):
    """Translate a natural-language problem into a formal-language problem.

    Placeholder: no translation is implemented yet, so every call returns None.
    """
    return None
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def draft(
    eng,
    data_pt: dict,
    verbose: bool = False,
) -> list[str]:
    """Sample informal (natural-language) solution drafts for one problem.

    The drafts are later used to build a formal proof sketch:
    y_pred_nl ~ draft(eng, x_nl_prob, P_draft)

    Args:
        eng: Engine providing ``llm``, ``model``, ``draft_system_prompt``,
            ``draft_prompt_template`` and ``draft_sampling_params``.
        data_pt: Data point; ``data_pt['nl_problem'][0]`` is used as the
            natural-language problem statement.
        verbose: Currently unused.

    Returns:
        The message content of every choice in the response, one string per
        sampled completion.

    Note:
        ``@retry`` wraps the whole function, so any exception raised in here
        (not only a failed API request) triggers another attempt.
    """
    # Make prompt from template
    nl_problem: str = data_pt['nl_problem'][0]
    prompt = eng.draft_prompt_template.replace('{nl_problem}', nl_problem)

    # One prompt in, many completions out.
    # ref: https://platform.openai.com/docs/api-reference/chat/object
    params = eng.draft_sampling_params
    response: Any = eng.llm.chat.completions.create(
        model=eng.model,
        messages=[
            {"role": "system", "content": eng.draft_system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=params.temperature,
        top_p=params.top_p,
        n=params.n,
        # Previously collected from the CLI but never forwarded to the API.
        max_tokens=params.max_tokens,
        # Only the first three stop sequences are sent.
        # NOTE(review): presumably to respect an API limit on stop sequences - confirm.
        stop=params.stop[:3],
    )
    return [choice.message.content for choice in response.choices]
2024-10-02 11:03:00 -07:00
2024-07-11 15:49:37 -07:00
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def _request_sketch_completions(eng, prompt: str) -> list[str]:
    """Send the sketch prompt to the chat-completions endpoint, retrying on failure."""
    params = eng.sketch_sampling_params
    # One prompt in, many completions out.
    # ref: https://platform.openai.com/docs/api-reference/chat/object
    response: Any = eng.llm.chat.completions.create(
        model=eng.model,
        messages=[
            {"role": "system", "content": eng.sketch_system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=params.temperature,
        top_p=params.top_p,
        n=params.n,
        # Previously collected from the CLI but never forwarded to the API.
        max_tokens=params.max_tokens,
        # Stop sequences remain disabled for sketches, as before.
    )
    return [choice.message.content for choice in response.choices]


def sketch(
    eng,
    data_pt: dict,
    drafts: list,
    autoformalize_prob_in_prompt: bool = False,
    verbose: bool = False,
) -> tuple:
    """Sample formal (Lean 4) proof sketches from one informal draft.

    z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch)

    Args:
        eng: Engine providing ``llm``, ``model``, ``sketch_system_prompt``,
            ``sketch_prompt_template`` and ``sketch_sampling_params``.
        data_pt: Data point; ``data_pt['nl_problem'][0]`` is the informal
            problem and, when present, ``data_pt['fl_problem'][0]`` the formal
            one. Without ``'fl_problem'`` the statement is requested from
            ``autoformalize_prob``.
        drafts: Informal solutions; exactly one is supported for now.
        autoformalize_prob_in_prompt: Not implemented; must stay False.
        verbose: Currently unused.

    Returns:
        ``(sketches, x_fl_problem)``: the sampled sketches and the formal
        problem statement that was placed in the prompt.

    Raises:
        ValueError: If ``drafts`` does not hold exactly one draft, or no formal
            problem statement is available.
        NotImplementedError: If ``autoformalize_prob_in_prompt`` is true.
    """
    # Input problems are deterministic, so they are checked outside the retried
    # request helper and surface immediately instead of after the full backoff.
    if len(drafts) != 1:
        raise ValueError("For now only 1 draft.")
    if autoformalize_prob_in_prompt:
        # The original placeholder (`not NotImplemented`) did nothing and left
        # `prompt` unbound, which only failed later with a NameError.
        raise NotImplementedError(
            "Autoformalizing the problem inside the sketch prompt is not implemented."
        )

    x_nl_problem: str = data_pt['nl_problem'][0]
    y_nl_solution: str = drafts[0]
    if 'fl_problem' in data_pt:
        x_fl_problem = data_pt['fl_problem'][0]
    else:
        x_fl_problem = autoformalize_prob(eng, data_pt)
    if x_fl_problem is None:
        raise ValueError(
            "No formal problem statement available: data point has no 'fl_problem' "
            "and autoformalization returned None."
        )

    # Fill every placeholder with its own value. Previously '{fl_problem}' was
    # replaced twice (first with the informal problem), so neither the draft nor
    # the formal statement reached the prompt. Placeholders that are absent from
    # the template are simply left untouched by str.replace.
    prompt = (
        eng.sketch_prompt_template
        .replace('{nl_problem}', x_nl_problem)
        .replace('{nl_solution}', y_nl_solution)
        .replace('{fl_problem}', x_fl_problem)
    )

    sketches: list[str] = _request_sketch_completions(eng, prompt)
    return sketches, x_fl_problem
def prove(
    eng: Engine,
    server: Server,
    fl_prob: str,
    fl_sketch: list[str],
):
    """Complete formal sketch and check if it proves the theorem.

    fl_prob --> Lean4 theorem (problem)
    fl_sketch --> Lean4 Form Sketch --> have x have ha

    The proving step itself is not implemented yet: the function extracts the
    Lean code from the single sketch, loads it into the server, prints the
    resulting state and then always raises.

    Args:
        eng: DSP engine (currently unused).
        server: Server whose ``load_sorry`` receives the extracted Lean code.
        fl_prob: Formal problem statement (currently unused).
        fl_sketch: Formal sketches; exactly one is expected.

    Raises:
        ValueError: If ``fl_sketch`` does not hold exactly one sketch, or
            ``load_sorry`` does not return exactly one item (both are
            single-element unpackings).
        NotImplementedError: Always, once the state has been printed. It is a
            subclass of ``RuntimeError``, which this function raised before.
    """
    # If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
    lean_code, = [extract_lean_code(sketch_text)[0] for sketch_text in fl_sketch]
    state, = server.load_sorry(lean_code)
    print(state)

    raise NotImplementedError("Not implemented")
2024-10-02 11:03:00 -07:00
2024-07-11 15:49:37 -07:00
# -- DSP for Lean
def single_proof_search_dsp_lean(
    eng: Engine,
    server: Server,
    data_pt: dict,
) -> bool:
    """Run one draft -> sketch -> prove pass over a single data point."""
    # Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
    nl_drafts = draft(eng, data_pt)

    # Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
    fl_sketches, fl_problem = sketch(eng, data_pt, nl_drafts)

    # Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches); its verdict is the result.
    return prove(eng, server, fl_problem, fl_sketches)
2024-07-11 15:49:37 -07:00
def full_proof_search_dsp_lean(
    eng: Engine,
    server: Server,
    path_2_eval_dataset: Path,
):
    """Run the DSP proof search over every data point of an evaluation set.

    Args:
        eng: Engine passed on to the per-data-point search.
        server: Server passed on to the per-data-point search; its ``gc()``
            is called after each data point.
        path_2_eval_dataset: Path to a JSON file that is loaded as the list
            of data points.

    Returns:
        None. The per-data-point results are not collected.
    """
    # -- Get eval data (context manager so the file handle is closed right after loading)
    with open(path_2_eval_dataset, 'r') as dataset_file:
        eval_dataset: list[dict] = json.load(dataset_file)
    print(f'{len(eval_dataset)=}')

    # -- Proof search by DSP over all eval data
    for data_pt in tqdm(eval_dataset, total=len(eval_dataset), desc='DSP proof loop per data point in benchmark.'):
        print(f'{data_pt=}')
        single_proof_search_dsp_lean(eng, server, data_pt)
        # NOTE(review): presumably releases server-side state between problems - confirm.
        server.gc()
2024-10-02 11:03:00 -07:00
# Directory containing this file; the default --eval-dataset path is built relative to it.
experiment_dir = Path ( __file__ ) . resolve ( ) . parent
2024-07-11 15:49:37 -07:00
# -- Main
2024-10-03 12:03:33 -07:00
def main(args):
    """Run the full DSP evaluation described by the parsed CLI arguments.

    Args:
        args: Namespace providing ``eval_dataset``, ``model``, ``n_samples``,
            ``max_tokens``, ``top_p`` and ``temperature``.

    Raises:
        KeyError: If the ``OPENAI_API_KEY`` environment variable is not set.
    """
    import time  # kept function-local, as in the original

    start_time = time.time()
    path_2_eval_dataset = Path(args.eval_dataset).expanduser()
    print(f'{path_2_eval_dataset=}')

    server = Server()

    # NOTE: wandb run tracking (init / config update / finish) is currently disabled.

    # - Run DSP for Lean
    api_key = os.environ['OPENAI_API_KEY']
    draft_sampling_params = SamplingParams(
        n=args.n_samples,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        temperature=args.temperature,
        stop=STOP_TOKENS_DRAFT_V0,
    )
    sketch_sampling_params = SamplingParams(
        n=args.n_samples,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        temperature=args.temperature,
        stop=STOP_TOKENS_SKETCH_V0,
    )
    eng: OpenAI_DSP_Engine = OpenAI_DSP_Engine(
        model=args.model,
        api_key=api_key,
        verbose_init=True,
        draft_sampling_params=draft_sampling_params,
        sketch_sampling_params=sketch_sampling_params,
    )

    # - Full proof search with DSP
    print('\n\n-- Full proof search with DSP')
    full_proof_search_dsp_lean(eng, server, path_2_eval_dataset)

    # Measure once so the three reported figures describe the same instant.
    elapsed = time.time() - start_time
    print(f"Time taken: {elapsed:.2f} seconds, or {elapsed / 60:.2f} minutes, or {elapsed / 3600:.2f} hours.\a")
2024-07-11 15:49:37 -07:00
if __name__ == "__main__":
    # Command-line entry point: 'eval' runs the DSP evaluation, 'prompts'
    # prints the sketch prompt template.
    import argparse

    parser = argparse.ArgumentParser(
        prog='DSP',
        description="Draft-Sketch-Prove on Lean code",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        'mode',
        help="Function",
        choices=['eval', 'prompts'],
    )
    parser.add_argument(
        "--eval-dataset",
        help="Evaluation dataset path",
        default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
    )
    parser.add_argument("--model", help="Model", default="gpt-4o", choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"])
    # Explicit `type=`: without it argparse passes user-supplied values on as
    # strings (only the defaults are numbers), and those strings would reach the
    # sampling parameters sent to the API.
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=sys.maxsize)
    parser.add_argument("--batchsize", type=int, default=10, help="putnam has 348")
    parser.add_argument("--n-samples", type=int, default=1, help="num seqs to return for given prompt")
    parser.add_argument("--max-tokens", type=int, default=2048, help="Maximum number of tokens in one sample")
    parser.add_argument("--top-p", type=float, default=0.95, help="Sampling top p")
    parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    parser.add_argument("--verbose", action='store_true')
    args = parser.parse_args()

    if args.mode == "eval":
        main(args)
    elif args.mode == "prompts":
        prompt = get_prompt_sketch_template_4_lean_v0(verbose=args.verbose)
        print(prompt)
    else:
        # Unreachable while `choices` restricts `mode`; kept as a safeguard.
        raise ValueError(f"Unknown mode: {args.mode}")