feat: Handle max trials per goal and theorem formatting

This commit is contained in:
Leni Aniva 2024-06-05 15:20:36 -07:00
parent 7b9829e3d2
commit 20b19c8e6c
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
3 changed files with 47 additions and 14 deletions

View File

@ -3,7 +3,7 @@
import subprocess, json, argparse import subprocess, json, argparse
from typing import Optional from typing import Optional
from pathlib import Path from pathlib import Path
from pantograph.server import Server from pantograph.server import Server, ServerError
from pantograph.search import SearchResult from pantograph.search import SearchResult
from pantograph.search_llm import LLMAgent from pantograph.search_llm import LLMAgent
@ -17,22 +17,47 @@ def read_test_data(use_valid: bool):
with open(jsonl_path, 'r') as f: with open(jsonl_path, 'r') as f:
return [json.loads(l) for l in list(f)] return [json.loads(l) for l in list(f)]
def try_test_data(server, agent, entry: dict, max_steps: int) -> Optional[SearchResult]: def inplace_to_statement(expr: str) -> str:
bracket = 0
i = 0
while i < len(expr):
if expr[i] == ':' and bracket == 0:
break
elif expr[i] == '(':
bracket += 1
elif expr[i] == ')':
bracket -= 1
i += 1
if i == 0:
return expr[1:]
if i == len(expr):
return expr
return 'forall ' + expr[:i] + ' , ' + expr[i+1:]
def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
e = entry["formal_statement"] e = entry["formal_statement"]
print(e)
informal_stmt = entry["informal_stmt"] informal_stmt = entry["informal_stmt"]
informal_proof = entry["informal_proof"] informal_proof = entry["informal_proof"]
key_position = e.find('theorem') key_position = e.find('theorem')
if key_position == -1: if key_position != 0:
# Can't output anything for this one # Can't output anything for this one
return None return None
e = e[key_position:] e = e[key_position:]
# remove the tail := sorry
e, tail = e.rsplit(':=', 1)
# remove the head
key_theorem, name, e = e.split(' ', 2) key_theorem, name, e = e.split(' ', 2)
e, tail = e.split(':=', 1) target = inplace_to_statement(e.strip())
target = "forall " + ','.join(e.rsplit(':', 1))
print(f"Target: {target}") print(f"Target: {target}")
agent = LLMAgent(server) try:
return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True, max_steps=max_steps) return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True,
max_steps=max_steps, max_trials_per_goal=max_trials_per_goal)
except ServerError as e:
return None
def output_file_name(datum, use_hammer: bool, use_llm: bool): def output_file_name(datum, use_hammer: bool, use_llm: bool):
name = datum["id"] name = datum["id"]
@ -53,6 +78,7 @@ if __name__ == '__main__':
parser.add_argument('--validation', action='store_true') parser.add_argument('--validation', action='store_true')
parser.add_argument('--use-llm', action='store_true') parser.add_argument('--use-llm', action='store_true')
parser.add_argument('-s', '--max-steps', default=200) parser.add_argument('-s', '--max-steps', default=200)
parser.add_argument('-t', '--max-trials-per-goal', default=4)
args = parser.parse_args() args = parser.parse_args()
project_path, lean_path = get_project_and_lean_path() project_path, lean_path = get_project_and_lean_path()
@ -60,17 +86,20 @@ if __name__ == '__main__':
print(f"$LEAN_PATH: {lean_path}") print(f"$LEAN_PATH: {lean_path}")
test_data = read_test_data(args.validation) test_data = read_test_data(args.validation)
server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
for datum in test_data: for datum in test_data:
file_name = output_file_name(datum, args.use_hammer, args.use_llm) file_name = output_file_name(datum, args.use_hammer, args.use_llm)
placeholder_file_name = file_name.with_suffix('.placeholder')
if file_name.is_file(): if file_name.is_file():
print(f"Skipping {datum['id']}") print(f"Skipping {datum['id']}")
continue continue
result = try_test_data(server, agent, datum, max_steps=args.max_steps) server = Server(imports=["Example"], project_path=project_path, lean_path=lean_path, options=["maxHeartbeats=0"])
agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal)
if result is None: if result is None:
with open(file_name + '-placeholder', 'w') as f: with open(placeholder_file_name, 'w') as f:
json.dump({ 'id': datum['id'] }, f) json.dump({ 'id': datum['id'] }, f)
else: else:
if placeholder_file_name.is_file():
placeholder_file_name.unlink()
with open(file_name, 'w') as f: with open(file_name, 'w') as f:
json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f) json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)

View File

@ -63,7 +63,7 @@ class Agent:
informal_stmt: str = "", informal_stmt: str = "",
informal_proof: str = "", informal_proof: str = "",
max_steps: int = 100, max_steps: int = 100,
max_trial_per_goal: int = 5, max_trials_per_goal: int = 5,
verbose: bool = False) -> SearchResult: verbose: bool = False) -> SearchResult:
search_stack = [SearchState(state=server.goal_start(target), search_stack = [SearchState(state=server.goal_start(target),
@ -99,7 +99,7 @@ class Agent:
# Find the unsolved goal with the highest priority # Find the unsolved goal with the highest priority
goal_id = search_state.next_goal_id goal_id = search_state.next_goal_id
if search_state.trials[goal_id] > max_trial_per_goal: if search_state.trials[goal_id] > max_trials_per_goal:
# force halt the search # force halt the search
tactic = None tactic = None
else: else:
@ -118,6 +118,7 @@ class Agent:
continue continue
try: try:
search_state.trials[goal_id] += 1
state = search_state.state state = search_state.state
if verbose: if verbose:
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}") print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")

View File

@ -13,6 +13,8 @@ class LLMAgent(Agent):
super().__init__() super().__init__()
self.n_trials = 5 self.n_trials = 5
self.server = server self.server = server
if use_llm:
sgl.set_default_backend(sgl.OpenAI("gpt-4")) sgl.set_default_backend(sgl.OpenAI("gpt-4"))
self.goal_tactic_id_map = collections.defaultdict(lambda : 0) self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
@ -37,6 +39,7 @@ class LLMAgent(Agent):
if i >= len(self.tactics) and not self.use_llm: if i >= len(self.tactics) and not self.use_llm:
return None return None
elif i >= len(self.tactics): elif i >= len(self.tactics):
assert self.use_llm
new_state = None new_state = None
for ii in range(self.n_trials): for ii in range(self.n_trials):
print(f"===============trail {str(ii)}============") print(f"===============trail {str(ii)}============")