diff --git a/experiments/dsp/.gitignore b/experiments/dsp/.gitignore new file mode 100644 index 0000000..c4a847d --- /dev/null +++ b/experiments/dsp/.gitignore @@ -0,0 +1 @@ +/result diff --git a/experiments/dsp/main.py b/experiments/dsp/main.py index 6f3ee1c..1dd5d26 100644 --- a/experiments/dsp/main.py +++ b/experiments/dsp/main.py @@ -213,13 +213,16 @@ def full_proof_search_dsp_lean( eng: Engine, server: Server, data: list[Datum], + path_output: Path, ): print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"])) # -- Proof search by DSP over all eval data - for datum in tqdm(data, total=len(data), desc='DSP proof loop per data point in benchmark.'): - print("Problem:", colored(datum.id, "cyan")) + for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'): + print(f"Problem {i}:", colored(str(datum), "cyan")) result = single_proof_search_dsp_lean(eng, server, datum) - print(result) + file_name = path_output / f"{i:03}.json" + with open(file_name, 'w') as f: + json.dump({ 'name': str(datum), 'success': result.success, 'steps': result.steps }, f) #server.gc() return @@ -253,9 +256,13 @@ def load_data(args) -> list[Datum]: # -- Main def main(args): - import time + import time, datetime start_time = time.time() + + # Setup paths and data data_eval = load_data(args) + path_output = Path(args.output) + path_output.mkdir(exist_ok=True, parents=True) # Start server project_path, lean_path = get_project_and_lean_path() @@ -302,9 +309,10 @@ def main(args): ) # - Full proof search with DSP - full_proof_search_dsp_lean(eng, server, data_eval) - msg = f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a" - print(colored(msg, "magenta")) + full_proof_search_dsp_lean(eng, server, data_eval, path_output) + + dt = datetime.timedelta(seconds=time.time() - start_time) + print(colored(f"Time elapsed: {dt}", "magenta")) # - End run # wandb.config.update(config) @@ -329,6 +337,11 @@ if __name__ == "__main__": help="Evaluation dataset path", default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json', ) + parser.add_argument( + "--output", + help="Result directory", + default=experiment_dir / 'result', + ) parser.add_argument( "--model", help="Model", diff --git a/experiments/dsp/solve/data.py b/experiments/dsp/solve/data.py index 64389c6..2cf4e94 100644 --- a/experiments/dsp/solve/data.py +++ b/experiments/dsp/solve/data.py @@ -20,7 +20,7 @@ class Datum: def __str__(self): if self.id: return self.id - return str(self.nl_problem) + return self.nl_problem_str @property def nl_problem_str(self) -> Optional[str]: