From 789452f7b79175614fce4821a174e12d5a486871 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sun, 6 Oct 2024 23:29:14 -0700 Subject: [PATCH] feat: Add stat function to show prove rate --- experiments/dsp/main.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/experiments/dsp/main.py b/experiments/dsp/main.py index 00ae10d..60b68b6 100644 --- a/experiments/dsp/main.py +++ b/experiments/dsp/main.py @@ -276,6 +276,7 @@ def full_proof_search_dsp_lean( ): print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"])) n_success = 0 + n_tried = 0 # -- Proof search by DSP over all eval data for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'): file_name = path_output / f"{i:03}.json" @@ -285,11 +286,12 @@ def full_proof_search_dsp_lean( obj = json.load(open(file_name, "r")) if obj['name'] != key: print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong")) - break + return print(f"Skipped {i}:", colored(key, "green")) continue + n_tried += 1 print(f"Problem {i}:", colored(key, "cyan")) result = single_proof_search_dsp_lean(eng, server_func, datum) @@ -298,7 +300,7 @@ def full_proof_search_dsp_lean( if result.success: n_success += 1 #server.gc() - print(f"Proved {n_success}/{len(data)} problems") + print(f"Proved {n_success}/{n_tried} problems") experiment_dir = Path(__file__).resolve().parent @@ -394,6 +396,25 @@ def main(args): # print(f"{wandb.config=}") # run.finish() +def stat(args): + path_output = Path(args.output) + data = load_data(args) + n_success = 0 + n_tried = 0 + for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'): + file_name = path_output / f"{i:03}.json" + key = str(datum) + # Detect if file exists + obj = json.load(open(file_name, "r")) + if obj['name'] != key: + print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong")) + return + + n_tried += 1 + if obj['success']: + n_success += 1 + print(f"Proved {n_success}/{n_tried} problems") + if __name__ == "__main__": import argparse @@ -405,7 +426,7 @@ if __name__ == "__main__": parser.add_argument( 'mode', help="Function", - choices=['eval', 'prompts'], + choices=['eval', 'prompts', 'stat'], ) parser.add_argument( "--dataset", @@ -441,6 +462,8 @@ if __name__ == "__main__": if args.mode == "eval": main(args) + elif args.mode == 'stat': + stat(args) elif args.mode == "prompts": prompt = get_prompt_sketch_template_4_lean_v0(verbose=args.verbose) print(prompt)