fix: Filter out val/test data

This commit is contained in:
Leni Aniva 2024-10-09 18:23:21 -07:00
parent 68cb408a27
commit cd05b67c10
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
2 changed files with 34 additions and 9 deletions

View File

@ -1,10 +1,17 @@
# MiniF2F
This is an experiment on running a LLM prover on miniF2F data. Build the project
`MiniF2F` with `lake build`, and run with
`MiniF2F` with `lake build`. Check the environment and data with
``` sh
python3 experiments/minif2f/main.py check
python3 experiments/minif2f/main.py list
```
and run experiments with
```sh
python3 experiments/minif2f/main.py [--dry-run] [--use-llm]
python3 experiments/minif2f/main.py eval [--use-llm] [--use-hammer]
```
Read the help message carefully.

View File

@ -29,10 +29,12 @@ def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goa
goal_states = server.load_sorry(command)
if len(goal_states) == 0:
if len(goal_states) != 1:
return None
goal_state, = goal_states
if isinstance(goal_state, list):
return None
try:
return agent.search(
server=server,
@ -103,16 +105,28 @@ def run_eval(args):
with open(file_name, 'w') as f:
json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)
def run_check(args):
project_path, lean_path = get_project_and_lean_path()
print(f"$PWD: {project_path}")
print(f"$LEAN_PATH: {lean_path}")
server = Server(
imports=["Mathlib", "Aesop"],
project_path=project_path,
lean_path=lean_path,
core_options=CORE_OPTIONS,
)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
prog='MiniF2F Search',
description='Executes LLM on MiniF2F Search',
)
parser.add_argument('--use-hammer', action='store_true')
parser.add_argument(
'--dry-run',
action='store_true',
help="List the data used, but don't run")
'mode',
help='Function',
choices=['list', 'eval', 'check'],
)
parser.add_argument('--use-hammer', action='store_true')
parser.add_argument('--validation', action='store_true')
parser.add_argument('--use-llm', action='store_true')
parser.add_argument('--max-steps', default=50)
@ -120,7 +134,11 @@ if __name__ == '__main__':
parser.add_argument('--feedback-turns', default=2)
args = parser.parse_args()
if args.dry_run:
if args.mode == "list":
dry_run(args)
else:
elif args.mode == "eval":
run_eval(args)
elif args.mode == "check":
run_check(args)
else:
raise ValueError(f"Invalid mode: {args.mode}")