feat: Add limit on goal tactic trials

This commit is contained in:
Leni Aniva 2024-06-05 14:36:51 -07:00
parent e6421dafc3
commit 7b9829e3d2
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
3 changed files with 37 additions and 10 deletions

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import subprocess, json, argparse import subprocess, json, argparse
from typing import Optional
from pathlib import Path from pathlib import Path
from pantograph.server import Server from pantograph.server import Server
from pantograph.search import SearchResult from pantograph.search import SearchResult
@ -16,10 +17,16 @@ def read_test_data(use_valid: bool):
with open(jsonl_path, 'r') as f: with open(jsonl_path, 'r') as f:
return [json.loads(l) for l in list(f)] return [json.loads(l) for l in list(f)]
def try_test_data(server, agent, entry: dict, max_steps: int) -> SearchResult: def try_test_data(server, agent, entry: dict, max_steps: int) -> Optional[SearchResult]:
e = entry["formal_statement"] e = entry["formal_statement"]
informal_stmt = entry["informal_stmt"] informal_stmt = entry["informal_stmt"]
informal_proof = entry["informal_proof"] informal_proof = entry["informal_proof"]
key_position = e.find('theorem')
if key_position == -1:
# Can't output anything for this one
return None
e = e[key_position:]
key_theorem, name, e = e.split(' ', 2) key_theorem, name, e = e.split(' ', 2)
e, tail = e.split(':=', 1) e, tail = e.split(':=', 1)
target = "forall " + ','.join(e.rsplit(':', 1)) target = "forall " + ','.join(e.rsplit(':', 1))
@ -45,7 +52,7 @@ if __name__ == '__main__':
# Command-line switches for the evaluation run.
parser.add_argument('--use-hammer', action='store_true')
parser.add_argument('--validation', action='store_true')
parser.add_argument('--use-llm', action='store_true')
# type=int is required: without it argparse hands back a str whenever the
# option is given explicitly, while the step budget is consumed as an int
# (see the `max_steps: int` annotation on try_test_data).
parser.add_argument('-s', '--max-steps', type=int, default=200)
args = parser.parse_args()
project_path, lean_path = get_project_and_lean_path() project_path, lean_path = get_project_and_lean_path()
@ -56,7 +63,14 @@ if __name__ == '__main__':
server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
for datum in test_data:
    file_name = output_file_name(datum, args.use_hammer, args.use_llm)
    # A result file from an earlier run means this datum is already done.
    if file_name.is_file():
        print(f"Skipping {datum['id']}")
        continue

    result = try_test_data(server, agent, datum, max_steps=args.max_steps)
    if result is None:
        # try_test_data could not produce a result for this entry; leave a
        # marker file next to where the result would have gone.
        # file_name is a path object (it supports .is_file()), so the suffix
        # has to be appended to its final component: `file_name + '...'`
        # raises TypeError for Path operands.
        placeholder = file_name.with_name(file_name.name + '-placeholder')
        with open(placeholder, 'w') as f:
            json.dump({ 'id': datum['id'] }, f)
    else:
        with open(file_name, 'w') as f:
            json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)

View File

@ -17,6 +17,13 @@ class SearchState:
def __post_init__(self):
    """Set up per-goal bookkeeping once the dataclass fields are populated.

    Checks that there is exactly one priority per goal of the wrapped state,
    then marks every goal as unsolved with zero tactic trials so far.
    """
    n_goals = len(self.state.goals)
    assert len(self.priorities) == n_goals
    # One slot per goal, in goal order.
    self.solved = [False] * n_goals
    self.trials = [0] * n_goals
@property
def next_goal_id(self) -> int:
    """Index of the unsolved goal with the highest priority.

    When several unsolved goals share the top priority, the one with the
    lowest index wins (``max`` keeps the first maximum it meets).

    Raises:
        ValueError: if every goal is already solved, so there is nothing
            left to pick.
    """
    unsolved = [i for i, _ in enumerate(self.priorities) if not self.solved[i]]
    if not unsolved:
        # Same exception type max() would raise on an empty sequence, but
        # with a message that says what actually went wrong.
        raise ValueError("next_goal_id: no unsolved goal remains in this search state")
    return max(unsolved, key=lambda i: self.priorities[i])
@property @property
def is_root(self) -> bool: def is_root(self) -> bool:
@ -55,7 +62,8 @@ class Agent:
target: Expr, target: Expr,
informal_stmt: str = "", informal_stmt: str = "",
informal_proof: str = "", informal_proof: str = "",
max_steps: int = 1000, max_steps: int = 100,
max_trial_per_goal: int = 5,
verbose: bool = False) -> SearchResult: verbose: bool = False) -> SearchResult:
search_stack = [SearchState(state=server.goal_start(target), search_stack = [SearchState(state=server.goal_start(target),
@ -89,11 +97,15 @@ class Agent:
continue continue
# Find the unsolved goal with the highest priority # Find the unsolved goal with the highest priority
goal_id, _ = max([(i, prio) for i, prio in enumerate(search_state.priorities) if not search_state.solved[i]], goal_id = search_state.next_goal_id
key=lambda x:x[1])
if search_state.trials[goal_id] > max_trial_per_goal:
# force halt the search
tactic = None
else:
# Generate tactic for this goal # Generate tactic for this goal
tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof) tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
print("????next tactic: ", tactic) print("????next tactic: ", tactic)
if not tactic: if not tactic:
# pop the current state and continue to the next # pop the current state and continue to the next

View File

@ -98,6 +98,7 @@ class Server:
def goal_start(self, expr: Expr) -> GoalState:
    """Open a new goal state whose single goal is ``expr``.

    Sends a ``goal.start`` command with the stringified expression through
    ``self.run``. If the reply contains an ``"error"`` key, the offending
    expression is printed and ``ServerError`` is raised with the reply's
    ``"desc"`` field. Otherwise the returned ``GoalState`` carries the
    reply's ``"stateId"`` and one goal built by ``Goal.sentence(expr)``.
    """
    response = self.run('goal.start', {"expr": str(expr)})
    if "error" not in response:
        return GoalState(
            state_id=response["stateId"],
            goals=[Goal.sentence(expr)],
            _sentinel=self.to_remove_goal_states,
        )
    # Show which expression was rejected before surfacing the error.
    print(f"Cannot start goal: {expr}")
    raise ServerError(response["desc"])