Pantograph/examples_search/miniF2F_search.py

62 lines
2.4 KiB
Python
Raw Normal View History

2024-06-05 03:52:43 -07:00
#!/usr/bin/env python3
2024-06-05 14:19:18 -07:00
import subprocess, json, argparse
2024-06-05 03:52:43 -07:00
from pathlib import Path
from pantograph.server import Server
2024-06-05 14:19:18 -07:00
from pantograph.search import SearchResult
2024-06-05 03:52:43 -07:00
from pantograph.search_llm import LLMAgent
def get_project_and_lean_path():
    """Locate the example Lean project and query its ``LEAN_PATH``.

    Runs ``lake env printenv LEAN_PATH`` with the working directory set to
    the ``Example`` folder that sits next to this script.

    Returns:
        A ``(project_dir, lean_path)`` tuple. ``project_dir`` is the resolved
        path of the ``Example`` folder; ``lean_path`` is the command's
        captured standard output exactly as ``subprocess.check_output``
        returns it (``bytes``, neither decoded nor stripped).

    Raises:
        subprocess.CalledProcessError: If the ``lake`` command exits with a
            non-zero status.
    """
    project_dir = Path(__file__).parent.resolve() / 'Example'
    lean_path = subprocess.check_output(
        ['lake', 'env', 'printenv', 'LEAN_PATH'],
        cwd=project_dir,
    )
    return project_dir, lean_path
2024-06-05 11:19:12 -07:00
def read_test_data():
    """Load the benchmark entries from ``test.jsonl`` next to this script.

    Every non-blank line of the file is parsed as one JSON document.

    Returns:
        A list holding one decoded JSON value per non-blank line, in file
        order.

    Raises:
        FileNotFoundError: If ``test.jsonl`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    jsonl_path = Path(__file__).parent / 'test.jsonl'
    # Explicit UTF-8: the data may contain non-ASCII symbols, and the
    # platform default encoding is not guaranteed to decode them.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        # Iterate the file directly (no intermediate list) and skip blank
        # lines, e.g. a trailing empty line, which json.loads would reject.
        return [json.loads(line) for line in f if line.strip()]
2024-06-05 14:19:18 -07:00
def try_test_data(server, agent, entry: dict, max_steps: int) -> SearchResult:
    """Run a proof search for a single benchmark entry.

    The theorem header in ``entry["formal_statement"]`` is rewritten into a
    single ``forall`` expression, which is then handed to ``agent.search``
    together with the informal statement and proof.

    Args:
        server: Forwarded to ``agent.search`` as its ``server`` argument.
        agent: Object whose ``search`` method performs the search. The
            caller's agent is used as given, so any configuration it was
            constructed with is respected.
        entry: Mapping with the keys ``"formal_statement"``,
            ``"informal_stmt"`` and ``"informal_proof"``.
        max_steps: Forwarded to ``agent.search`` as ``max_steps``.

    Returns:
        Whatever ``agent.search`` returns.

    Raises:
        KeyError: If ``entry`` lacks one of the required keys.
        ValueError: If the formal statement does not contain at least two
            spaces followed by a ``:=``.
    """
    statement = entry["formal_statement"]
    informal_stmt = entry["informal_stmt"]
    informal_proof = entry["informal_proof"]

    # Expected layout: "<keyword> <name> <binders> : <conclusion> := <body>".
    # The keyword, the name and the body after ':=' are not needed.
    _keyword, _name, signature = statement.split(' ', 2)
    signature, _body = signature.split(':=', 1)
    # Replace the last ':' (binders/conclusion separator) with ','.
    # NOTE: this splits on the textually last ':', so a ':' inside the
    # conclusion itself would be picked instead.
    target = "forall " + ','.join(signature.rsplit(':', 1))
    print(f"Target: {target}")

    # Use the agent supplied by the caller; do not build a fresh one here,
    # which would throw away the caller's configuration.
    return agent.search(
        server=server,
        target=target,
        informal_stmt=informal_stmt,
        informal_proof=informal_proof,
        verbose=True,
        max_steps=max_steps,
    )
def output_file_name(datum, use_hammer: bool, use_llm: bool):
    """Return the result-file path for ``datum``, creating its directory.

    The directory lives next to this script and is named ``output``, with
    ``-hammer`` appended when ``use_hammer`` is truthy and then ``-llm``
    appended when ``use_llm`` is truthy. It is created if missing.

    Args:
        datum: Mapping whose ``"id"`` value becomes the file stem.
        use_hammer: Whether to add the ``-hammer`` suffix.
        use_llm: Whether to add the ``-llm`` suffix.

    Returns:
        ``<script dir>/<directory name>/<id>.json`` as a ``Path``.

    Raises:
        KeyError: If ``datum`` has no ``"id"`` key.
    """
    # Look the id up first so a missing key fails before any directory is made.
    stem = datum["id"]
    dir_name = 'output' + ('-hammer' if use_hammer else '') + ('-llm' if use_llm else '')
    out_dir = Path(__file__).parent / dir_name
    out_dir.mkdir(exist_ok=True, parents=True)
    return out_dir / f"{stem}.json"
2024-06-05 11:19:12 -07:00
2024-06-05 03:52:43 -07:00
if __name__ == '__main__':
    # Command-line entry point: parse the options, start a server on the
    # example project, run the search on the test data and write one JSON
    # result file per processed entry.
    parser = argparse.ArgumentParser(
        prog='MiniF2F Search',
        description='Executes LLM on MiniF2F Search')
    parser.add_argument('--use-hammer', action='store_true')
    parser.add_argument('--use-llm', action='store_true')
    # type=int so a value given on the command line is an integer, like the
    # default, instead of the raw string argparse would otherwise store.
    parser.add_argument('-s', '--max-steps', type=int, default=1000)
    args = parser.parse_args()

    project_path, lean_path = get_project_and_lean_path()
    print(f"$PWD: {project_path}")
    print(f"$LEAN_PATH: {lean_path}")

    test_data = read_test_data()
    server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
    agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
    # Only the first entry of the test data is processed.
    for datum in test_data[:1]:
        result = try_test_data(server, agent, datum, max_steps=args.max_steps)
        file_name = output_file_name(datum, args.use_hammer, args.use_llm)
        with open(file_name, 'w') as f:
            json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)