2024-10-04 18:53:00 -07:00
import sys , os , json , subprocess
2024-10-06 19:14:38 -07:00
from dataclasses import dataclass , asdict , field
2024-10-01 11:34:30 -07:00
from pathlib import Path
2024-10-06 19:14:38 -07:00
from typing import Union , Any , Tuple , Optional
2024-07-11 15:49:37 -07:00
from tqdm import tqdm
2024-09-26 20:20:32 -07:00
from openai import OpenAI
2024-07-11 15:49:37 -07:00
import wandb
from tenacity import retry , stop_after_attempt , wait_exponential
2024-10-06 19:14:38 -07:00
from pantograph import Server , ServerError
from pantograph . search import SearchResult
2024-10-04 18:36:52 -07:00
from termcolor import colored
2024-07-11 15:49:37 -07:00
2024-10-02 16:10:52 -07:00
from solve . prompts import (
2024-10-03 12:26:42 -07:00
extract_lean_code ,
2024-10-02 16:10:52 -07:00
SYSTEM_PROMPT_DRAFT_V0 ,
SYSTEM_PROMPT_SKETCH_V0 ,
prompt_draft_template_lean4_v0 ,
prompt_sketch_template_lean4_v0 ,
STOP_TOKENS_DRAFT_V0 ,
STOP_TOKENS_SKETCH_V0 ,
get_prompt_sketch_template_4_lean_v0 ,
)
2024-10-04 18:36:52 -07:00
from solve . prove import HammerAgent
2024-10-04 21:55:47 -07:00
from solve . data import Datum
2024-07-11 15:49:37 -07:00
2024-09-26 20:20:32 -07:00
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
2024-10-02 11:03:00 -07:00
@dataclass
class SamplingParams:
    """LLM sampling hyperparameters shared by the draft and sketch stages."""
    n: int            # number of completions to request per prompt
    max_tokens: int   # per-completion length cap
    top_p: float      # nucleus sampling mass (e.g. 0.95); was mis-annotated as int
    temperature: float
    stop: list[str]   # stop sequences (e.g. STOP_TOKENS_DRAFT_V0); was mis-annotated as str
2024-10-06 19:14:38 -07:00
@dataclass(frozen=True)
class SketchParseFailure:
    """Records a model sketch that could not be loaded as Lean code."""
    # Human-readable description of what went wrong.
    error: str
    # The raw sketch text that failed to parse/load.
    sketch: str
@dataclass(frozen=True)
class DatumResult:
    """
    Result from one DSP data point
    """
    name: str
    # True when at least one proof attempt succeeded (see single_proof_search_dsp_lean).
    success: Optional[bool] = False
    # One entry per sketch: a SearchResult from the prover, or a
    # SketchParseFailure when the sketch could not be parsed/loaded.
    proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list)
2024-07-11 15:49:37 -07:00
class Engine:
    """Abstract base for DSP engines; subclasses wire up an LLM backend."""

    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):  # fixed typo: was **kwards
        pass
class OpenAI_DSP_Engine(Engine):
    """DSP engine backed by an OpenAI-compatible chat-completions endpoint."""

    def __init__(
        self,
        model: str,
        api_key: str = None,
        base_url: str = None,  # e.g., Mistral-7B-Instruct-v0.2 on http://120.77.8.29:12345
        # Draft Params
        draft_system_prompt: str = SYSTEM_PROMPT_DRAFT_V0,  # 'You are an expert mathematician and an expert in the Lean 4 Proof Assistant.' (goal do draft)
        draft_prompt_template: str = prompt_draft_template_lean4_v0,
        draft_sampling_params=None,  # expected: SamplingParams
        draft_stop_tokens: list[str] = STOP_TOKENS_DRAFT_V0,
        # Sketch Params
        sketch_system_prompt: str = SYSTEM_PROMPT_SKETCH_V0,
        sketch_prompt_template: str = prompt_sketch_template_lean4_v0,
        sketch_sampling_params=None,  # expected: SamplingParams
        sketch_stop_tokens: list[str] = STOP_TOKENS_SKETCH_V0,
        # Prove Params
        # ...TODO not sure if needed right now...
        # Misc
        verbose_init: bool = True,
    ):
        """Validate the model name and build the OpenAI client plus per-stage prompt config."""
        super().__init__()
        # Idiom fix: was `print(...) if verbose_init else None` (conditional
        # expression used purely for a side effect).
        if verbose_init:
            print(f'{base_url=}')

        # Only OpenAI chat models are supported by this engine.
        if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model):
            raise ValueError(f"Model {model=} not supported.")

        self.model = model
        self.api_key = api_key
        self.llm = OpenAI(api_key=self.api_key, base_url=base_url)

        # Draft params
        self.draft_system_prompt = draft_system_prompt
        self.draft_prompt_template = draft_prompt_template
        self.draft_sampling_params = draft_sampling_params
        # self.draft_sampling_params.stop = draft_stop_tokens

        # Sketch params
        self.sketch_system_prompt = sketch_system_prompt
        self.sketch_prompt_template = sketch_prompt_template
        self.sketch_sampling_params = sketch_sampling_params
        # self.sketch_sampling_params.stop = sketch_stop_tokens

        # Prove params
        # ...TODO not sure if needed right now...
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def autoformalize_prob(
    eng: Engine,
    datum: Datum,
    verbose: bool = False,
):
    """ Autoformalize natural language problem to formal language problem. """
    # TODO: unimplemented stub — currently returns None, so callers (see
    # `sketch`, which falls back to this when `datum.fl_problem` is absent)
    # receive None as the formal problem.
    pass
2024-07-11 15:49:37 -07:00
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def draft(
    eng: Engine,
    datum: Datum,
    verbose: bool = False,
) -> list:
    """
    Produce informal (natural-language) draft solutions / proof sketches for
    later use in a formal proof sketch.
        y_pred_nl ~ draft(eng, x_nl_prob, P_draft)
    """
    # Substitute the NL problem statement into the draft prompt template.
    nl_problem: str = datum.nl_problem_str
    prompt = eng.draft_prompt_template.replace('{nl_problem}', nl_problem)

    # Single prompt in -> n completions out.
    # ref: https://platform.openai.com/docs/api-reference/chat/object
    params = eng.draft_sampling_params
    response: Any = eng.llm.chat.completions.create(
        model=eng.model,
        messages=[
            {"role": "system", "content": eng.draft_system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=params.temperature,
        top_p=params.top_p,
        n=params.n,
        stop=params.stop[:3],
    )

    # Each choice carries one draft; collect their texts.
    drafts: list[str] = [choice.message.content for choice in response.choices]
    return drafts
2024-10-02 11:03:00 -07:00
2024-07-11 15:49:37 -07:00
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def sketch(
    eng: Engine,
    datum: Datum,
    drafts: list[str],
    autoformalize_prob_in_prompt: bool = False,
    verbose: bool = False,
) -> Tuple[list[str], str]:
    """
    Creates (formal fl) sketch (fl proof sketch) for later use in a formal proof.
        z_pred_fl ~ sketch(eng, x_nl_prob, y_pred_nl, x_fl_prob, P_sketch)

    Returns (sketches, x_fl_problem).
    """
    assert len(drafts) == 1, f"For now only 1 draft."
    # Make prompt from template
    x_nl_problem: str = datum.nl_problem_str
    y_nl_solution: str = drafts[0]
    x_fl_problem = None
    if autoformalize_prob_in_prompt:
        # prompt = eng.sketch_prompt_template.replace('{nl_problem}', x_nl_problem).replace('{nl_solution}', y_nl_solution)
        # BUG FIX: was `not NotImplemented`, a no-op expression that fell
        # through with `prompt` unbound.
        raise NotImplementedError("autoformalize_prob_in_prompt is not implemented")
    else:
        x_fl_problem = datum.fl_problem if datum.fl_problem else autoformalize_prob(eng, datum)
    # BUG FIX: both .replace calls used the same '{fl_problem}' placeholder, so
    # the second replace was a no-op and the draft solution was never inserted.
    # Placeholder names follow the commented template line above.
    # TODO(review): confirm these placeholder names against the sketch template.
    prompt = (
        eng.sketch_prompt_template
        .replace('{nl_problem}', x_nl_problem)
        .replace('{nl_solution}', y_nl_solution)
    )
    # Get all **completions** to single prompt, one (in) -> many (out), ref: https://platform.openai.com/docs/api-reference/chat/object
    response: Any = eng.llm.chat.completions.create(
        model=eng.model,
        messages=[
            {"role": "system", "content": eng.sketch_system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=eng.sketch_sampling_params.temperature,
        top_p=eng.sketch_sampling_params.top_p,
        n=eng.sketch_sampling_params.n,
        # stop=eng.sketch_sampling_params.stop[:3],
    )
    # Get all completions for single prompt
    sketches: list[str] = [completion.message.content for completion in response.choices]  # response.choices[i].message
    # Return
    return sketches, x_fl_problem
def prove(
    eng: Engine,
    server: Server,
    fl_prob: str,
    fl_sketch: str,
) -> Union[SearchResult, SketchParseFailure]:
    """
    Complete a formal sketch and check if it proves the theorem.

    fl_prob   --> Lean4 theorem (problem); currently unused by this function
    fl_sketch --> Lean4 formal sketch --> have x have ha ...

    Returns a SearchResult from the hammer search, or a SketchParseFailure when
    the sketch cannot be loaded. (BUG FIX: the return annotation previously
    claimed list[SearchResult], but a single result is returned — the caller
    wraps results in a list itself.)
    """
    # If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
    print(colored("Sketch:", "yellow"), fl_sketch)
    lean_code = "\n".join(extract_lean_code(fl_sketch))
    print(colored("Lean code:", "light_grey"), lean_code)

    try:
        states = server.load_sorry(lean_code)
    except ServerError as e:
        msg = f"Encountered exception: {e}"
        print(colored(msg, "red"))
        return SketchParseFailure(
            sketch=fl_sketch,
            error=msg,
        )

    if len(states) != 1:
        print(colored("Model must output one compilation unit", "red"))
        return SketchParseFailure(
            sketch=fl_sketch,
            error="Model must output one compilation unit",
        )

    state = states[0]
    if isinstance(state, list) and len(state) > 0:
        # This means `state` contains error messages
        msg = "\n".join(state)
        print(colored("Sketch failed:", "red"), msg)
        return SketchParseFailure(
            sketch=fl_sketch,
            error=f"Sketch failed: {msg}",
        )

    # Drive the goal state to completion with the hammer tactic agent.
    agent = HammerAgent()
    result = agent.search(
        server,
        state,
        max_steps=1000,
        max_trials_per_goal=len(agent.tactics) + 1,
    )
    print(colored(f"Result: {result}", "blue"))

    return result
2024-10-02 11:03:00 -07:00
2024-07-11 15:49:37 -07:00
# -- DSP for Lean
def single_proof_search_dsp_lean(
    eng: Engine,
    server_func,
    datum: Datum,
) -> DatumResult:
    """Run one Draft-Sketch-Prove round for a single datum and collect the results."""
    # -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
    y_nl_pred_drafts = draft(eng, datum)

    # -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
    z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts)

    assert len(z_fl_pred_sketches) == 1

    # Fresh server per datum.
    server = server_func()
    # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
    # BUG FIX: the loop variable was named `sketch`, shadowing the module-level
    # `sketch()` function.
    prove_result = [
        prove(eng, server, x_fl_prob, fl_sketch)
        for fl_sketch in z_fl_pred_sketches
    ]
    return DatumResult(
        name=str(datum),
        # Success iff at least one attempt reached the prover and succeeded;
        # SketchParseFailure entries are excluded.
        success=any(
            x.success for x in prove_result
            if isinstance(x, SearchResult)
        ),
        proves=prove_result,
    )
2024-07-11 15:49:37 -07:00
def full_proof_search_dsp_lean(
    eng: Engine,
    server_func,
    data: list[Datum],
    path_output: Path,
):
    """Run DSP over every datum, writing one JSON result file per datum.

    Existing result files are skipped, so interrupted runs can be resumed.
    """
    print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))
    n_success = 0
    n_tried = 0
    # -- Proof search by DSP over all eval data
    for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
        file_name = path_output / f"{i:03}.json"
        key = str(datum)
        # Detect if file exists
        if file_name.is_file():
            # BUG FIX: was json.load(open(...)) which leaked the file handle.
            with open(file_name, "r") as f:
                obj = json.load(f)
            if obj['name'] != key:
                # BUG FIX: colored() was missing its color argument, unlike
                # every other error print in this file.
                print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong", "red"))
                return
            print(f"Skipped {i}:", colored(key, "green"))
            continue
        n_tried += 1
        print(f"Problem {i}:", colored(key, "cyan"))
        result = single_proof_search_dsp_lean(eng, server_func, datum)
        with open(file_name, 'w') as f:
            json.dump(asdict(result), f)
        if result.success:
            n_success += 1
        #server.gc()
    print(f"Proved {n_success} / {n_tried} problems")
2024-07-11 15:49:37 -07:00
2024-10-04 18:53:00 -07:00
2024-10-02 11:03:00 -07:00
# Directory containing this script; anchors the Lean project and default dataset paths.
experiment_dir = Path(__file__).resolve().parent
2024-10-04 18:53:00 -07:00
def get_project_and_lean_path():
    """Return (project_path, LEAN_PATH) for the bundled Lean project.

    NOTE(review): subprocess.check_output returns *bytes* (with a trailing
    newline); the value is passed through unmodified to Server — confirm the
    Server API accepts it in this form.
    """
    cwd = experiment_dir / 'lean_src_proj'
    # Ask `lake` for the project's LEAN_PATH environment value.
    p = subprocess.check_output(['lake', 'env', 'printenv', 'LEAN_PATH'], cwd=cwd)
    return cwd, p
2024-10-04 21:55:47 -07:00
def load_data(args) -> list[Datum]:
    """Load the evaluation dataset from args.dataset.

    Supports a JSON array (.json) or one JSON object per line (.jsonl); each
    object is decoded via Datum.load(..., data_format=args.format). Falsy
    (rejected) datums are filtered out. Raises ValueError on other suffixes.
    """
    p = Path(args.dataset).expanduser()
    data = None
    if p.suffix == ".json":
        # BUG FIX: was json.load(open(p, 'r')) which leaked the file handle.
        with open(p, 'r') as f:
            data = [
                Datum.load(obj, data_format=args.format)
                for obj in json.load(f)
            ]
    elif p.suffix == ".jsonl":
        with open(p, 'r') as f:
            # Iterate the file directly; no need to materialize list(f).
            data = [
                Datum.load(json.loads(line), data_format=args.format)
                for line in f
            ]
    else:
        raise ValueError(f"Unknown data suffix: {p.suffix}")
    data = [datum for datum in data if datum]
    return data
2024-07-11 15:49:37 -07:00
# -- Main
2024-10-03 12:03:33 -07:00
def main(args):
    """Entry point for `eval` mode: run the full DSP proof search over the dataset."""
    import time, datetime
    start_time = time.time()

    # Setup paths and data
    data_eval = load_data(args)
    path_output = Path(args.output)
    path_output.mkdir(exist_ok=True, parents=True)

    # Start server
    project_path, lean_path = get_project_and_lean_path()
    def server_func():
        # Factory: a fresh Pantograph server is created per data point
        # (see single_proof_search_dsp_lean).
        return Server(
            imports=["Mathlib", "Aesop"],
            project_path=project_path,
            lean_path=lean_path,
        )

    # - Start wandb run
    # print(f'\n\n-- Setup params')
    # CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES")
    # current_tmux_session = os.environ.get("TMUX", "").split(",")[-1]
    # today = datetime.datetime.now().strftime("%Y-m%m-d%d-t%Hh_%Mm_%Ss")
    # config = {'today': today, "CUDA_VISIBLE_DEVICES": CUDA_VISIBLE_DEVICES, "current_tmux_session": current_tmux_session, "model": model, "path_2_eval_dataset": path_2_eval_dataset}
    # project: str = 'pypantograph'
    # run_name = f"{project}: ({config})"
    # run = wandb.init(mode=mode, project=project, name=run_name, save_code=True, config=config)
    # print(f"{run.url=}")
    # print(f'\n Config: \n{config=}')

    # - Run DSP for Lean
    api_key = os.environ['OPENAI_API_KEY']
    # Same sampling settings for both stages; only the stop tokens differ.
    draft_sampling_params = SamplingParams(
        n=args.n_samples,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        temperature=args.temperature,
        stop=STOP_TOKENS_DRAFT_V0,
    )
    sketch_sampling_params = SamplingParams(
        n=args.n_samples,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        temperature=args.temperature,
        stop=STOP_TOKENS_SKETCH_V0,
    )
    eng: OpenAI_DSP_Engine = OpenAI_DSP_Engine(
        model=args.model,
        api_key=api_key,
        verbose_init=True,
        draft_sampling_params=draft_sampling_params,
        sketch_sampling_params=sketch_sampling_params,
    )

    # - Full proof search with DSP
    full_proof_search_dsp_lean(eng, server_func, data_eval, path_output)
    dt = datetime.timedelta(seconds=time.time() - start_time)
    print(colored(f"Time elapsed: {dt}", "magenta"))

    # - End run
    # wandb.config.update(config)
    # print(f"{wandb.config=}")
    # run.finish()
2024-07-11 15:49:37 -07:00
2024-10-06 23:29:14 -07:00
def stat(args):
    """Entry point for `stat` mode: tally successes from existing result files.

    Reads the per-datum JSON files produced by full_proof_search_dsp_lean and
    prints the proved/tried counts.
    """
    path_output = Path(args.output)
    data = load_data(args)
    n_success = 0
    n_tried = 0
    for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
        file_name = path_output / f"{i:03}.json"
        key = str(datum)
        # Detect if file exists
        # BUG FIX: the existence check promised by the comment was missing, so
        # a partially-completed run crashed here; skip absent result files.
        if not file_name.is_file():
            continue
        # BUG FIX: was json.load(open(...)) which leaked the file handle.
        with open(file_name, "r") as f:
            obj = json.load(f)
        if obj['name'] != key:
            # BUG FIX: colored() was missing its color argument.
            print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong", "red"))
            return
        n_tried += 1
        if obj['success']:
            n_success += 1
    print(f"Proved {n_success} / {n_tried} problems")
2024-07-11 15:49:37 -07:00
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        prog='DSP',
        description="Draft-Sketch-Prove on Lean code",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        'mode',
        help="Function",
        choices=['eval', 'prompts', 'stat'],
    )
    parser.add_argument(
        "--dataset",
        help="Evaluation dataset path",
        default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
    )
    parser.add_argument(
        "--output",
        help="Result directory",
        default=experiment_dir / 'result',
    )
    parser.add_argument(
        "--model",
        help="Model",
        default="gpt-4o",
        choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"],
    )
    parser.add_argument(
        "--format",
        help="Data format",
        default="default",
        choices=["default", "minif2f"],
    )
    # BUG FIX: type= added to the numeric options below; without it,
    # CLI-supplied values arrived as str (argparse default) while the
    # in-code defaults were int/float, and the strings flowed into
    # SamplingParams and the OpenAI API call.
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=sys.maxsize)
    parser.add_argument("--batchsize", type=int, default=10, help="putnam has 348")
    parser.add_argument("--n-samples", type=int, default=1, help="num seqs to return for given prompt")
    parser.add_argument("--max-tokens", type=int, default=2048, help="Maximum number of tokens in one sample")
    parser.add_argument("--top-p", type=float, default=0.95, help="Sampling top p")
    parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    parser.add_argument("--verbose", action='store_true')
    args = parser.parse_args()
    if args.mode == "eval":
        main(args)
    elif args.mode == 'stat':
        stat(args)
    elif args.mode == "prompts":
        prompt = get_prompt_sketch_template_4_lean_v0(verbose=args.verbose)
        print(prompt)
    else:
        raise ValueError(f"Unknown mode: {args.mode}")