Merge pull request #24 from lenianiva/experiments/minif2f

experiment: MiniF2F speedup
This commit is contained in:
Leni Aniva 2024-10-13 19:21:16 -07:00 committed by GitHub
commit 8585e3dd9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 55 additions and 22 deletions

View File

@ -1,10 +1,17 @@
# MiniF2F # MiniF2F
This is an experiment on running a LLM prover on miniF2F data. Build the project This is an experiment on running a LLM prover on miniF2F data. Build the project
`MiniF2F` with `lake build`, and run with `MiniF2F` with `lake build`. Check the environment and data with
``` sh
python3 experiments/minif2f/main.py check
python3 experiments/minif2f/main.py list
```
and run experiments with
```sh ```sh
python3 experiments/minif2f/main.py [--dry-run] [--use-llm] python3 experiments/minif2f/main.py eval [--use-llm] [--use-hammer]
``` ```
Read the help message carefully. Read the help message carefully.

View File

@ -29,10 +29,12 @@ def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goa
goal_states = server.load_sorry(command) goal_states = server.load_sorry(command)
if len(goal_states) == 0: if len(goal_states) != 1:
return None return None
goal_state, = goal_states goal_state, = goal_states
if isinstance(goal_state, list):
return None
try: try:
return agent.search( return agent.search(
server=server, server=server,
@ -103,16 +105,28 @@ def run_eval(args):
with open(file_name, 'w') as f: with open(file_name, 'w') as f:
json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f) json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)
def run_check(args):
project_path, lean_path = get_project_and_lean_path()
print(f"$PWD: {project_path}")
print(f"$LEAN_PATH: {lean_path}")
server = Server(
imports=["Mathlib", "Aesop"],
project_path=project_path,
lean_path=lean_path,
core_options=CORE_OPTIONS,
)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='MiniF2F Search', prog='MiniF2F Search',
description='Executes LLM on MiniF2F Search', description='Executes LLM on MiniF2F Search',
) )
parser.add_argument('--use-hammer', action='store_true')
parser.add_argument( parser.add_argument(
'--dry-run', 'mode',
action='store_true', help='Function',
help="List the data used, but don't run") choices=['list', 'eval', 'check'],
)
parser.add_argument('--use-hammer', action='store_true')
parser.add_argument('--validation', action='store_true') parser.add_argument('--validation', action='store_true')
parser.add_argument('--use-llm', action='store_true') parser.add_argument('--use-llm', action='store_true')
parser.add_argument('--max-steps', default=50) parser.add_argument('--max-steps', default=50)
@ -120,7 +134,11 @@ if __name__ == '__main__':
parser.add_argument('--feedback-turns', default=2) parser.add_argument('--feedback-turns', default=2)
args = parser.parse_args() args = parser.parse_args()
if args.dry_run: if args.mode == "list":
dry_run(args) dry_run(args)
else: elif args.mode == "eval":
run_eval(args) run_eval(args)
elif args.mode == "check":
run_check(args)
else:
raise ValueError(f"Invalid mode: {args.mode}")

View File

@ -1,5 +1,6 @@
from typing import Optional from typing import Optional
import collections, unittest import collections, unittest
from termcolor import colored
from pantograph.search import Agent from pantograph.search import Agent
from pantograph.server import Server, TacticFailure, ServerError from pantograph.server import Server, TacticFailure, ServerError
from pantograph.expr import Expr, Tactic, GoalState from pantograph.expr import Expr, Tactic, GoalState
@ -57,6 +58,7 @@ class LLMAgent(Agent):
new_state = None new_state = None
for ii in range(self.n_trials): for ii in range(self.n_trials):
print(f"===============trail {str(ii)}============") print(f"===============trail {str(ii)}============")
try:
s = select_tactic.run( s = select_tactic.run(
server=self.server, server=self.server,
state=state, state=state,
@ -70,7 +72,13 @@ class LLMAgent(Agent):
print("\n-- new state --\n", new_state) print("\n-- new state --\n", new_state)
if tactic: if tactic:
if not isinstance(tactic, Tactic):
print(colored("[Tactic] Failed:", "red"), tactic)
return None
return tactic return tactic
except Exception as e:
print(colored(str(e), "red"))
return None
return None return None
else: else:
self.goal_tactic_id_map[key] = i + 1 self.goal_tactic_id_map[key] = i + 1