Pantograph/experiments/minif2f/main.py

122 lines
4.1 KiB
Python
Raw Normal View History

2024-06-05 03:52:43 -07:00
#!/usr/bin/env python3
2024-06-05 14:19:18 -07:00
import subprocess, json, argparse
2024-06-05 14:36:51 -07:00
from typing import Optional
2024-06-05 03:52:43 -07:00
from pathlib import Path
from termcolor import colored
from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS
2024-06-05 14:19:18 -07:00
from pantograph.search import SearchResult
from model.llm_agent import LLMAgent
from model.options import CORE_OPTIONS
# Absolute path of the directory containing this script. The data splits
# (valid.jsonl / test.jsonl), the MiniF2F Lean project directory and the
# output folders are all located relative to it.
PATH_EXPERIMENT = Path(__file__).parent.resolve()
2024-06-05 03:52:43 -07:00
def get_project_and_lean_path():
    """Locate the MiniF2F Lean project and ask Lake for its ``LEAN_PATH``.

    Runs ``lake env printenv LEAN_PATH`` inside the ``MiniF2F`` directory next
    to this script.

    Returns:
        A ``(project_dir, lean_path)`` tuple. ``lean_path`` is text with the
        trailing newline that ``printenv`` emits removed.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero (for
            example when ``LEAN_PATH`` is not set in Lake's environment).
        FileNotFoundError: if ``lake`` cannot be found on ``PATH``.
    """
    cwd = PATH_EXPERIMENT / 'MiniF2F'
    # check_output would otherwise return raw bytes ending in "\n". Decoding
    # and stripping keeps the newline out of the last search-path entry when
    # the value is handed to the Lean server, and makes it print cleanly.
    output = subprocess.check_output(
        ['lake', 'env', 'printenv', 'LEAN_PATH'],
        cwd=cwd,
        text=True,
    )
    return cwd, output.strip()
2024-06-05 14:20:03 -07:00
def read_test_data(use_valid: bool, data_dir: Optional[Path] = None):
    """Load the MiniF2F problems of one split from a JSON Lines file.

    Args:
        use_valid: Read ``valid.jsonl`` when true, ``test.jsonl`` otherwise.
        data_dir: Directory holding the split files. Defaults to the
            directory of this script (``PATH_EXPERIMENT``).

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.
    """
    if data_dir is None:
        data_dir = PATH_EXPERIMENT
    jsonl_path = Path(data_dir) / ('valid.jsonl' if use_valid else 'test.jsonl')
    # The statements contain non-ASCII math symbols, so decode explicitly as
    # UTF-8 rather than relying on the platform's locale encoding.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a stray empty trailing line) instead of
        # letting json.loads fail on them.
        return [json.loads(line) for line in f if line.strip()]
def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
    """Run the agent's proof search on a single MiniF2F entry.

    The entry's ``formal_statement`` is passed to ``server.load_sorry`` to
    obtain its goal state. ``agent.search`` is then run on that goal together
    with the entry's informal statement and informal proof.

    Args:
        server: Pantograph server used to load the statement and by the search.
        agent: Search agent exposing ``search(...)`` (an ``LLMAgent`` here).
        entry: One record of the data split. Reads ``formal_statement``,
            ``informal_stmt`` and ``informal_proof`` (and ``id`` for logging).
        max_steps: Forwarded to ``agent.search``.
        max_trials_per_goal: Forwarded to ``agent.search``.

    Returns:
        The agent's search result, or ``None`` if the statement produced no
        goal state or the server raised ``ServerError`` while loading the
        statement or during the search.

    Raises:
        ValueError: if ``load_sorry`` yields more than one goal state; this
            driver expects exactly one ``sorry`` per statement.
    """
    command = entry["formal_statement"]
    print(command)
    informal_stmt = entry["informal_stmt"]
    informal_proof = entry["informal_proof"]

    try:
        # load_sorry is inside the try so that a statement the server rejects
        # is reported as a failed attempt (None) instead of aborting the run.
        goal_states = server.load_sorry(command)
        if not goal_states:
            return None
        goal_state, = goal_states
        return agent.search(
            server=server,
            goal_state=goal_state,
            informal_stmt=informal_stmt,
            informal_proof=informal_proof,
            verbose=True,
            max_steps=max_steps,
            max_trials_per_goal=max_trials_per_goal
        )
    except ServerError as e:
        # Keep the best-effort semantics, but do not discard the reason.
        print(f"ServerError on {entry.get('id')}: {e}")
        return None
2024-06-05 14:19:18 -07:00
def output_file_name(datum, use_hammer: bool, use_llm: bool):
    """Return the path of the JSON result file for ``datum``.

    Results go under ``PATH_EXPERIMENT / "output[-hammer][-llm]"``. Each
    suffix is included only when its flag is truthy, in that order. The
    directory is created (with parents) if missing, and the file is named
    ``<datum["id"]>.json``.
    """
    entry_id = datum["id"]
    flagged_suffixes = (('-hammer', use_hammer), ('-llm', use_llm))
    dir_name = 'output' + ''.join(
        suffix for suffix, enabled in flagged_suffixes if enabled
    )
    out_dir = PATH_EXPERIMENT / dir_name
    out_dir.mkdir(parents=True, exist_ok=True)
    return out_dir / f"{entry_id}.json"
2024-06-05 11:19:12 -07:00
def dry_run(args):
    """Print the formal statement of every entry in the selected split.

    Nothing is evaluated; ``args.validation`` picks the split passed to
    ``read_test_data``.
    """
    statements = (entry["formal_statement"] for entry in read_test_data(args.validation))
    for statement in statements:
        print(statement)
2024-06-05 14:19:18 -07:00
def _record_outcome(datum, result, file_name, placeholder_file_name):
    """Persist the outcome of evaluating one datum.

    When ``result`` is ``None`` only a placeholder file holding the id is
    written. The result file is left absent, so a later run evaluates this
    datum again. Otherwise any stale placeholder is removed and
    ``id``/``success``/``steps`` are written to ``file_name``.
    """
    if result is None:
        with open(placeholder_file_name, 'w') as f:
            json.dump({'id': datum['id']}, f)
        return
    if placeholder_file_name.is_file():
        placeholder_file_name.unlink()
    with open(file_name, 'w') as f:
        json.dump({'id': datum['id'], 'success': result.success, 'steps': result.steps}, f)


def run_eval(args):
    """Evaluate the agent on every datum of the selected split, resumably.

    A datum whose result file already exists is skipped. For every other
    datum a new ``Server`` and ``LLMAgent`` are built, ``try_test_data`` runs
    the search, and the outcome is saved by ``_record_outcome``.
    """
    project_path, lean_path = get_project_and_lean_path()
    print(f"$PWD: {project_path}")
    print(f"$LEAN_PATH: {lean_path}")

    for datum in read_test_data(args.validation):
        result_path = output_file_name(datum, args.use_hammer, args.use_llm)
        placeholder_path = result_path.with_suffix('.placeholder')
        if result_path.is_file():
            print(colored(f"Skipping {datum['id']}", "green"))
            continue

        print(colored(f"Evaluating on {datum['id']} ...", "blue"))
        # NOTE(review): a fresh server is created for every datum (an explicit
        # server.gc() call was previously commented out here). Presumably this
        # isolates Lean state between problems; confirm.
        server = Server(
            imports=["Mathlib", "Aesop"],
            project_path=project_path,
            lean_path=lean_path,
            core_options=CORE_OPTIONS,
        )
        agent = LLMAgent(
            server,
            use_hammer=args.use_hammer,
            use_llm=args.use_llm,
            feedback_turns=args.feedback_turns,
        )
        result = try_test_data(
            server,
            agent,
            datum,
            max_steps=args.max_steps,
            max_trials_per_goal=args.max_trials_per_goal,
        )
        _record_outcome(datum, result, result_path, placeholder_path)
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this experiment driver.

    ``--max-steps``, ``--max-trials-per-goal`` and ``--feedback-turns`` are
    parsed with ``type=int``. Without it, argparse passes command-line values
    through as ``str`` while the defaults are ``int``, and the ``str`` values
    would reach the agent unconverted.
    """
    parser = argparse.ArgumentParser(
        prog='MiniF2F Search',
        description='Executes LLM on MiniF2F Search',
    )
    parser.add_argument('--use-hammer', action='store_true')
    parser.add_argument(
        '--dry-run',
        action='store_true',
        help="List the data used, but don't run")
    parser.add_argument(
        '--validation',
        action='store_true',
        help="Use valid.jsonl instead of test.jsonl")
    parser.add_argument('--use-llm', action='store_true')
    parser.add_argument('--max-steps', type=int, default=50)
    parser.add_argument('--max-trials-per-goal', type=int, default=2)
    parser.add_argument('--feedback-turns', type=int, default=2)
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    if args.dry_run:
        dry_run(args)
    else:
        run_eval(args)