feat: Add limit on goal tactic trials
This commit is contained in:
parent
e6421dafc3
commit
7b9829e3d2
|
@ -1,6 +1,7 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import subprocess, json, argparse
|
import subprocess, json, argparse
|
||||||
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pantograph.server import Server
|
from pantograph.server import Server
|
||||||
from pantograph.search import SearchResult
|
from pantograph.search import SearchResult
|
||||||
|
@ -16,10 +17,16 @@ def read_test_data(use_valid: bool):
|
||||||
with open(jsonl_path, 'r') as f:
|
with open(jsonl_path, 'r') as f:
|
||||||
return [json.loads(l) for l in list(f)]
|
return [json.loads(l) for l in list(f)]
|
||||||
|
|
||||||
def try_test_data(server, agent, entry: dict, max_steps: int) -> SearchResult:
|
def try_test_data(server, agent, entry: dict, max_steps: int) -> Optional[SearchResult]:
|
||||||
e = entry["formal_statement"]
|
e = entry["formal_statement"]
|
||||||
informal_stmt = entry["informal_stmt"]
|
informal_stmt = entry["informal_stmt"]
|
||||||
informal_proof = entry["informal_proof"]
|
informal_proof = entry["informal_proof"]
|
||||||
|
|
||||||
|
key_position = e.find('theorem')
|
||||||
|
if key_position == -1:
|
||||||
|
# Can't output anything for this one
|
||||||
|
return None
|
||||||
|
e = e[key_position:]
|
||||||
key_theorem, name, e = e.split(' ', 2)
|
key_theorem, name, e = e.split(' ', 2)
|
||||||
e, tail = e.split(':=', 1)
|
e, tail = e.split(':=', 1)
|
||||||
target = "forall " + ','.join(e.rsplit(':', 1))
|
target = "forall " + ','.join(e.rsplit(':', 1))
|
||||||
|
@ -45,7 +52,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--use-hammer', action='store_true')
|
parser.add_argument('--use-hammer', action='store_true')
|
||||||
parser.add_argument('--validation', action='store_true')
|
parser.add_argument('--validation', action='store_true')
|
||||||
parser.add_argument('--use-llm', action='store_true')
|
parser.add_argument('--use-llm', action='store_true')
|
||||||
parser.add_argument('-s', '--max-steps', default=1000)
|
# type=int: without it a value given on the command line is a str, while the
# default is an int, so downstream step-count comparisons would see mixed types.
parser.add_argument('-s', '--max-steps', type=int, default=200)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
project_path, lean_path = get_project_and_lean_path()
|
project_path, lean_path = get_project_and_lean_path()
|
||||||
|
@ -56,7 +63,14 @@ if __name__ == '__main__':
|
||||||
server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
|
server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
|
||||||
agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
|
agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
|
||||||
for datum in test_data:
|
for datum in test_data:
|
||||||
result = try_test_data(server, agent, datum, max_steps=args.max_steps)
|
|
||||||
file_name = output_file_name(datum, args.use_hammer, args.use_llm)
|
file_name = output_file_name(datum, args.use_hammer, args.use_llm)
|
||||||
|
if file_name.is_file():
|
||||||
|
print(f"Skipping {datum['id']}")
|
||||||
|
continue
|
||||||
|
result = try_test_data(server, agent, datum, max_steps=args.max_steps)
|
||||||
|
if result is None:
    # No search result was produced for this datum; write a marker file
    # holding only its id. `file_name` has .is_file() called on it, so it is
    # a Path, and Path does not support `+` with a str -- build the sibling
    # name explicitly instead.
    placeholder = file_name.with_name(file_name.name + '-placeholder')
    with open(placeholder, 'w') as f:
        json.dump({ 'id': datum['id'] }, f)
else:
    # Record the outcome of the search for this datum.
    with open(file_name, 'w') as f:
        json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)
|
||||||
|
|
|
@ -17,6 +17,13 @@ class SearchState:
|
||||||
def __post_init__(self):
    """Check that there is one priority per goal and set up per-goal bookkeeping.

    Creates `solved` (all False) and `trials` (all 0), each with one entry
    per goal in `self.state.goals`.
    """
    goal_count = len(self.state.goals)
    assert len(self.priorities) == goal_count
    self.solved = [False] * goal_count
    self.trials = [0] * goal_count
|
||||||
|
|
||||||
|
@property
def next_goal_id(self) -> int:
    """Index of the unsolved goal with the highest priority.

    When several unsolved goals share the highest priority, the one with
    the lowest index is returned.

    Raises:
        ValueError: if every goal is already marked solved.
    """
    unsolved = [i for i, _ in enumerate(self.priorities) if not self.solved[i]]
    if not unsolved:
        # max() would raise ValueError here anyway; keep the exception type
        # but say what actually went wrong.
        raise ValueError("no unsolved goal left in this search state")
    return max(unsolved, key=lambda i: self.priorities[i])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_root(self) -> bool:
|
def is_root(self) -> bool:
|
||||||
|
@ -55,7 +62,8 @@ class Agent:
|
||||||
target: Expr,
|
target: Expr,
|
||||||
informal_stmt: str = "",
|
informal_stmt: str = "",
|
||||||
informal_proof: str = "",
|
informal_proof: str = "",
|
||||||
max_steps: int = 1000,
|
max_steps: int = 100,
|
||||||
|
max_trial_per_goal: int = 5,
|
||||||
verbose: bool = False) -> SearchResult:
|
verbose: bool = False) -> SearchResult:
|
||||||
|
|
||||||
search_stack = [SearchState(state=server.goal_start(target),
|
search_stack = [SearchState(state=server.goal_start(target),
|
||||||
|
@ -89,11 +97,15 @@ class Agent:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Find the unsolved goal with the highest priority
|
# Find the unsolved goal with the highest priority
|
||||||
goal_id, _ = max([(i, prio) for i, prio in enumerate(search_state.priorities) if not search_state.solved[i]],
|
goal_id = search_state.next_goal_id
|
||||||
key=lambda x:x[1])
|
|
||||||
|
|
||||||
|
if search_state.trials[goal_id] > max_trial_per_goal:
|
||||||
|
# force halt the search
|
||||||
|
tactic = None
|
||||||
|
else:
|
||||||
# Generate tactic for this goal
|
# Generate tactic for this goal
|
||||||
tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
|
tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
|
||||||
|
|
||||||
print("????next tactic: ", tactic)
|
print("????next tactic: ", tactic)
|
||||||
if not tactic:
|
if not tactic:
|
||||||
# pop the current state and continue to the next
|
# pop the current state and continue to the next
|
||||||
|
|
|
@ -98,6 +98,7 @@ class Server:
|
||||||
def goal_start(self, expr: Expr) -> GoalState:
    """Start a new goal state on the server with `expr` as its only goal.

    Sends the 'goal.start' command with the stringified expression. If the
    reply contains an "error" key, the expression is printed and ServerError
    is raised with the reply's "desc"; otherwise a GoalState wrapping the
    returned "stateId" is handed back.
    """
    response = self.run('goal.start', {"expr": str(expr)})
    if "error" not in response:
        return GoalState(
            state_id=response["stateId"],
            goals=[Goal.sentence(expr)],
            _sentinel=self.to_remove_goal_states,
        )
    print(f"Cannot start goal: {expr}")
    raise ServerError(response["desc"])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue