diff --git a/experiments/minif2f/README.md b/experiments/minif2f/README.md index 560eb4d..5a4e92f 100644 --- a/experiments/minif2f/README.md +++ b/experiments/minif2f/README.md @@ -1,10 +1,17 @@ # MiniF2F This is an experiment on running a LLM prover on miniF2F data. Build the project -`MiniF2F` with `lake build`, and run with +`MiniF2F` with `lake build`. Check the environment and data with + +``` sh +python3 experiments/minif2f/main.py check +python3 experiments/minif2f/main.py list +``` + +and run experiments with ```sh -python3 experiments/minif2f/main.py [--dry-run] [--use-llm] +python3 experiments/minif2f/main.py eval [--use-llm] [--use-hammer] ``` Read the help message carefully. diff --git a/experiments/minif2f/main.py b/experiments/minif2f/main.py index 1ddf467..a2726f2 100755 --- a/experiments/minif2f/main.py +++ b/experiments/minif2f/main.py @@ -29,10 +29,12 @@ def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goa goal_states = server.load_sorry(command) - if len(goal_states) == 0: + if len(goal_states) != 1: return None goal_state, = goal_states + if isinstance(goal_state, list): + return None try: return agent.search( server=server, @@ -103,16 +105,28 @@ def run_eval(args): with open(file_name, 'w') as f: json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f) +def run_check(args): + project_path, lean_path = get_project_and_lean_path() + print(f"$PWD: {project_path}") + print(f"$LEAN_PATH: {lean_path}") + server = Server( + imports=["Mathlib", "Aesop"], + project_path=project_path, + lean_path=lean_path, + core_options=CORE_OPTIONS, + ) + if __name__ == '__main__': parser = argparse.ArgumentParser( prog='MiniF2F Search', description='Executes LLM on MiniF2F Search', ) - parser.add_argument('--use-hammer', action='store_true') parser.add_argument( - '--dry-run', - action='store_true', - help="List the data used, but don't run") + 'mode', + help='Function', + choices=['list', 'eval', 'check'], + ) + parser.add_argument('--use-hammer', action='store_true') parser.add_argument('--validation', action='store_true') parser.add_argument('--use-llm', action='store_true') parser.add_argument('--max-steps', default=50) @@ -120,7 +134,11 @@ if __name__ == '__main__': parser.add_argument('--feedback-turns', default=2) args = parser.parse_args() - if args.dry_run: + if args.mode == "list": dry_run(args) - else: + elif args.mode == "eval": run_eval(args) + elif args.mode == "check": + run_check(args) + else: + raise ValueError(f"Invalid mode: {args.mode}") diff --git a/experiments/minif2f/model/llm_agent.py b/experiments/minif2f/model/llm_agent.py index 9c069a8..14c3b39 100644 --- a/experiments/minif2f/model/llm_agent.py +++ b/experiments/minif2f/model/llm_agent.py @@ -1,5 +1,6 @@ from typing import Optional import collections, unittest +from termcolor import colored from pantograph.search import Agent from pantograph.server import Server, TacticFailure, ServerError from pantograph.expr import Expr, Tactic, GoalState @@ -57,20 +58,27 @@ class LLMAgent(Agent): new_state = None for ii in range(self.n_trials): print(f"===============trail {str(ii)}============") - s = select_tactic.run( - server=self.server, - state=state, - goal_id=goal_id, - informal_stmt=self.informal_stmt, - informal_proof=self.informal_proof, - feedback_turns=self.feedback_turns) - tactic, new_state = s.ret_value - for m in s.messages(): - print(m["role"], ":", m["content"]) + try: + s = select_tactic.run( + server=self.server, + state=state, + goal_id=goal_id, + informal_stmt=self.informal_stmt, + informal_proof=self.informal_proof, + feedback_turns=self.feedback_turns) + tactic, new_state = s.ret_value + for m in s.messages(): + print(m["role"], ":", m["content"]) - print("\n-- new state --\n", new_state) - if tactic: - return tactic + print("\n-- new state --\n", new_state) + if tactic: + if not isinstance(tactic, Tactic): + print(colored("[Tactic] Failed:", "red"), tactic) + return None + return tactic + except Exception as e: + print(colored(str(e), "red")) + return None return None else: self.goal_tactic_id_map[key] = i + 1