feat: Handle max trials per goal and theorem formatting
This commit is contained in:
parent
7b9829e3d2
commit
20b19c8e6c
|
@ -3,7 +3,7 @@
|
||||||
import subprocess, json, argparse
|
import subprocess, json, argparse
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pantograph.server import Server
|
from pantograph.server import Server, ServerError
|
||||||
from pantograph.search import SearchResult
|
from pantograph.search import SearchResult
|
||||||
from pantograph.search_llm import LLMAgent
|
from pantograph.search_llm import LLMAgent
|
||||||
|
|
||||||
|
@ -17,22 +17,47 @@ def read_test_data(use_valid: bool):
|
||||||
with open(jsonl_path, 'r') as f:
|
with open(jsonl_path, 'r') as f:
|
||||||
return [json.loads(l) for l in list(f)]
|
return [json.loads(l) for l in list(f)]
|
||||||
|
|
||||||
def try_test_data(server, agent, entry: dict, max_steps: int) -> Optional[SearchResult]:
|
def inplace_to_statement(expr: str) -> str:
|
||||||
|
bracket = 0
|
||||||
|
i = 0
|
||||||
|
while i < len(expr):
|
||||||
|
if expr[i] == ':' and bracket == 0:
|
||||||
|
break
|
||||||
|
elif expr[i] == '(':
|
||||||
|
bracket += 1
|
||||||
|
elif expr[i] == ')':
|
||||||
|
bracket -= 1
|
||||||
|
i += 1
|
||||||
|
if i == 0:
|
||||||
|
return expr[1:]
|
||||||
|
if i == len(expr):
|
||||||
|
return expr
|
||||||
|
|
||||||
|
return 'forall ' + expr[:i] + ' , ' + expr[i+1:]
|
||||||
|
|
||||||
|
|
||||||
|
def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
|
||||||
e = entry["formal_statement"]
|
e = entry["formal_statement"]
|
||||||
|
print(e)
|
||||||
informal_stmt = entry["informal_stmt"]
|
informal_stmt = entry["informal_stmt"]
|
||||||
informal_proof = entry["informal_proof"]
|
informal_proof = entry["informal_proof"]
|
||||||
|
|
||||||
key_position = e.find('theorem')
|
key_position = e.find('theorem')
|
||||||
if key_position == -1:
|
if key_position != 0:
|
||||||
# Can't output anything for this one
|
# Can't output anything for this one
|
||||||
return None
|
return None
|
||||||
e = e[key_position:]
|
e = e[key_position:]
|
||||||
|
# remove the tail := sorry
|
||||||
|
e, tail = e.rsplit(':=', 1)
|
||||||
|
# remove the head
|
||||||
key_theorem, name, e = e.split(' ', 2)
|
key_theorem, name, e = e.split(' ', 2)
|
||||||
e, tail = e.split(':=', 1)
|
target = inplace_to_statement(e.strip())
|
||||||
target = "forall " + ','.join(e.rsplit(':', 1))
|
|
||||||
print(f"Target: {target}")
|
print(f"Target: {target}")
|
||||||
agent = LLMAgent(server)
|
try:
|
||||||
return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True, max_steps=max_steps)
|
return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True,
|
||||||
|
max_steps=max_steps, max_trials_per_goal=max_trials_per_goal)
|
||||||
|
except ServerError as e:
|
||||||
|
return None
|
||||||
|
|
||||||
def output_file_name(datum, use_hammer: bool, use_llm: bool):
|
def output_file_name(datum, use_hammer: bool, use_llm: bool):
|
||||||
name = datum["id"]
|
name = datum["id"]
|
||||||
|
@ -53,6 +78,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--validation', action='store_true')
|
parser.add_argument('--validation', action='store_true')
|
||||||
parser.add_argument('--use-llm', action='store_true')
|
parser.add_argument('--use-llm', action='store_true')
|
||||||
parser.add_argument('-s', '--max-steps', default=200)
|
parser.add_argument('-s', '--max-steps', default=200)
|
||||||
|
parser.add_argument('-t', '--max-trials-per-goal', default=4)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
project_path, lean_path = get_project_and_lean_path()
|
project_path, lean_path = get_project_and_lean_path()
|
||||||
|
@ -60,17 +86,20 @@ if __name__ == '__main__':
|
||||||
print(f"$LEAN_PATH: {lean_path}")
|
print(f"$LEAN_PATH: {lean_path}")
|
||||||
|
|
||||||
test_data = read_test_data(args.validation)
|
test_data = read_test_data(args.validation)
|
||||||
server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
|
|
||||||
agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
|
|
||||||
for datum in test_data:
|
for datum in test_data:
|
||||||
file_name = output_file_name(datum, args.use_hammer, args.use_llm)
|
file_name = output_file_name(datum, args.use_hammer, args.use_llm)
|
||||||
|
placeholder_file_name = file_name.with_suffix('.placeholder')
|
||||||
if file_name.is_file():
|
if file_name.is_file():
|
||||||
print(f"Skipping {datum['id']}")
|
print(f"Skipping {datum['id']}")
|
||||||
continue
|
continue
|
||||||
result = try_test_data(server, agent, datum, max_steps=args.max_steps)
|
server = Server(imports=["Example"], project_path=project_path, lean_path=lean_path, options=["maxHeartbeats=0"])
|
||||||
|
agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
|
||||||
|
result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal)
|
||||||
if result is None:
|
if result is None:
|
||||||
with open(file_name + '-placeholder', 'w') as f:
|
with open(placeholder_file_name, 'w') as f:
|
||||||
json.dump({ 'id': datum['id'] }, f)
|
json.dump({ 'id': datum['id'] }, f)
|
||||||
else:
|
else:
|
||||||
|
if placeholder_file_name.is_file():
|
||||||
|
placeholder_file_name.unlink()
|
||||||
with open(file_name, 'w') as f:
|
with open(file_name, 'w') as f:
|
||||||
json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)
|
json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)
|
||||||
|
|
|
@ -63,7 +63,7 @@ class Agent:
|
||||||
informal_stmt: str = "",
|
informal_stmt: str = "",
|
||||||
informal_proof: str = "",
|
informal_proof: str = "",
|
||||||
max_steps: int = 100,
|
max_steps: int = 100,
|
||||||
max_trial_per_goal: int = 5,
|
max_trials_per_goal: int = 5,
|
||||||
verbose: bool = False) -> SearchResult:
|
verbose: bool = False) -> SearchResult:
|
||||||
|
|
||||||
search_stack = [SearchState(state=server.goal_start(target),
|
search_stack = [SearchState(state=server.goal_start(target),
|
||||||
|
@ -99,7 +99,7 @@ class Agent:
|
||||||
# Find the unsolved goal with the highest priority
|
# Find the unsolved goal with the highest priority
|
||||||
goal_id = search_state.next_goal_id
|
goal_id = search_state.next_goal_id
|
||||||
|
|
||||||
if search_state.trials[goal_id] > max_trial_per_goal:
|
if search_state.trials[goal_id] > max_trials_per_goal:
|
||||||
# force halt the search
|
# force halt the search
|
||||||
tactic = None
|
tactic = None
|
||||||
else:
|
else:
|
||||||
|
@ -118,6 +118,7 @@ class Agent:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
search_state.trials[goal_id] += 1
|
||||||
state = search_state.state
|
state = search_state.state
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
|
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
|
||||||
|
|
|
@ -13,6 +13,8 @@ class LLMAgent(Agent):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_trials = 5
|
self.n_trials = 5
|
||||||
self.server = server
|
self.server = server
|
||||||
|
|
||||||
|
if use_llm:
|
||||||
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
|
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
|
||||||
|
|
||||||
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
|
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
|
||||||
|
@ -37,6 +39,7 @@ class LLMAgent(Agent):
|
||||||
if i >= len(self.tactics) and not self.use_llm:
|
if i >= len(self.tactics) and not self.use_llm:
|
||||||
return None
|
return None
|
||||||
elif i >= len(self.tactics):
|
elif i >= len(self.tactics):
|
||||||
|
assert self.use_llm
|
||||||
new_state = None
|
new_state = None
|
||||||
for ii in range(self.n_trials):
|
for ii in range(self.n_trials):
|
||||||
print(f"===============trail {str(ii)}============")
|
print(f"===============trail {str(ii)}============")
|
||||||
|
|
Loading…
Reference in New Issue