feat: Add stat function to show prove rate

This commit is contained in:
Leni Aniva 2024-10-06 23:29:14 -07:00
parent 402df63395
commit 789452f7b7
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
1 changed file with 26 additions and 3 deletions

View File

@ -276,6 +276,7 @@ def full_proof_search_dsp_lean(
): ):
print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"])) print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))
n_success = 0 n_success = 0
n_tried = 0
# -- Proof search by DSP over all eval data # -- Proof search by DSP over all eval data
for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'): for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
file_name = path_output / f"{i:03}.json" file_name = path_output / f"{i:03}.json"
@ -285,11 +286,12 @@ def full_proof_search_dsp_lean(
obj = json.load(open(file_name, "r")) obj = json.load(open(file_name, "r"))
if obj['name'] != key: if obj['name'] != key:
print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong")) print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong"))
break return
print(f"Skipped {i}:", colored(key, "green")) print(f"Skipped {i}:", colored(key, "green"))
continue continue
n_tried += 1
print(f"Problem {i}:", colored(key, "cyan")) print(f"Problem {i}:", colored(key, "cyan"))
result = single_proof_search_dsp_lean(eng, server_func, datum) result = single_proof_search_dsp_lean(eng, server_func, datum)
@ -298,7 +300,7 @@ def full_proof_search_dsp_lean(
if result.success: if result.success:
n_success += 1 n_success += 1
#server.gc() #server.gc()
print(f"Proved {n_success}/{len(data)} problems") print(f"Proved {n_success}/{n_tried} problems")
experiment_dir = Path(__file__).resolve().parent experiment_dir = Path(__file__).resolve().parent
@ -394,6 +396,25 @@ def main(args):
# print(f"{wandb.config=}") # print(f"{wandb.config=}")
# run.finish() # run.finish()
def stat(args):
    """Print how many already-evaluated problems were proved.

    For every datum at index ``i`` of ``load_data(args)``, the result file
    ``<args.output>/<i:03>.json`` is read. A datum counts as tried when its
    result file exists and as proved when that file's ``success`` entry is
    truthy. The totals are printed as ``Proved <n_success>/<n_tried> problems``.

    Args:
        args: Parsed command-line arguments. ``args.output`` is the directory
            holding the per-problem JSON results; ``args`` is also passed
            unchanged to ``load_data``.

    Returns:
        None. If a result file's ``name`` entry differs from ``str(datum)``,
        a warning is printed and the function returns early without printing
        the summary, because the output directory likely belongs to another
        dataset.
    """
    path_output = Path(args.output)
    data = load_data(args)
    n_success = 0
    n_tried = 0
    for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
        file_name = path_output / f"{i:03}.json"
        key = str(datum)
        # A missing result file means this problem has no recorded attempt
        # yet; skip it instead of aborting the whole report.
        try:
            # The context manager closes the handle even if parsing fails.
            with open(file_name, "r") as f:
                obj = json.load(f)
        except FileNotFoundError:
            continue
        if obj['name'] != key:
            print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong"))
            return
        n_tried += 1
        if obj['success']:
            n_success += 1
    print(f"Proved {n_success}/{n_tried} problems")
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
@ -405,7 +426,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
'mode', 'mode',
help="Function", help="Function",
choices=['eval', 'prompts'], choices=['eval', 'prompts', 'stat'],
) )
parser.add_argument( parser.add_argument(
"--dataset", "--dataset",
@ -441,6 +462,8 @@ if __name__ == "__main__":
if args.mode == "eval": if args.mode == "eval":
main(args) main(args)
elif args.mode == 'stat':
stat(args)
elif args.mode == "prompts": elif args.mode == "prompts":
prompt = get_prompt_sketch_template_4_lean_v0(verbose=args.verbose) prompt = get_prompt_sketch_template_4_lean_v0(verbose=args.verbose)
print(prompt) print(prompt)