Add a --validation switch to the MiniF2F search example.

read_test_data() now takes use_valid and loads valid.jsonl when it is true,
test.jsonl otherwise. The main loop now iterates over every datum instead of
only the first one (test_data[:1]).

NOTE(review): this patch was recovered from a copy whose newlines had been
collapsed. Line structure matches the hunk headers, but the leading
indentation of the Python lines is reconstructed; confirm it against the
working tree, or apply with `git apply --ignore-whitespace`.

NOTE(review): the existing -s/--max-steps option has no type=, so a value
given on the command line is passed on as str while the default stays an
int. Consider type=int in a follow-up; it is left untouched here.

diff --git a/examples_search/miniF2F_search.py b/examples_search/miniF2F_search.py
index d9e4b2a..0bde569 100755
--- a/examples_search/miniF2F_search.py
+++ b/examples_search/miniF2F_search.py
@@ -11,8 +11,8 @@ def get_project_and_lean_path():
     p = subprocess.check_output(['lake', 'env', 'printenv', 'LEAN_PATH'], cwd=cwd)
     return cwd, p
 
-def read_test_data():
-    jsonl_path = Path(__file__).parent / 'test.jsonl'
+def read_test_data(use_valid: bool):
+    jsonl_path = Path(__file__).parent / ('valid.jsonl' if use_valid else 'test.jsonl')
     with open(jsonl_path, 'r') as f:
         return [json.loads(l) for l in list(f)]
 
@@ -43,6 +43,7 @@ if __name__ == '__main__':
         prog='MiniF2F Search',
         description='Executes LLM on MiniF2F Search')
     parser.add_argument('--use-hammer', action='store_true')
+    parser.add_argument('--validation', action='store_true')
     parser.add_argument('--use-llm', action='store_true')
     parser.add_argument('-s', '--max-steps', default=1000)
     args = parser.parse_args()
@@ -51,10 +52,10 @@ if __name__ == '__main__':
     print(f"$PWD: {project_path}")
     print(f"$LEAN_PATH: {lean_path}")
 
-    test_data = read_test_data()
+    test_data = read_test_data(args.validation)
     server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
     agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
-    for datum in test_data[:1]:
+    for datum in test_data:
         result = try_test_data(server, agent, datum, max_steps=args.max_steps)
         file_name = output_file_name(datum, args.use_hammer, args.use_llm)
         with open(file_name, 'w') as f: