feat: Add ablation testing
This commit is contained in:
parent
4e678c7b97
commit
ce633fecda
|
@ -1,8 +1,9 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import subprocess, json
|
import subprocess, json, argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pantograph.server import Server
|
from pantograph.server import Server
|
||||||
|
from pantograph.search import SearchResult
|
||||||
from pantograph.search_llm import LLMAgent
|
from pantograph.search_llm import LLMAgent
|
||||||
|
|
||||||
def get_project_and_lean_path():
|
def get_project_and_lean_path():
|
||||||
|
@ -15,7 +16,7 @@ def read_test_data():
|
||||||
with open(jsonl_path, 'r') as f:
|
with open(jsonl_path, 'r') as f:
|
||||||
return [json.loads(l) for l in list(f)]
|
return [json.loads(l) for l in list(f)]
|
||||||
|
|
||||||
def try_test_data(server, agent, entry) -> bool:
|
def try_test_data(server, agent, entry: dict, max_steps: int) -> SearchResult:
|
||||||
e = entry["formal_statement"]
|
e = entry["formal_statement"]
|
||||||
informal_stmt = entry["informal_stmt"]
|
informal_stmt = entry["informal_stmt"]
|
||||||
informal_proof = entry["informal_proof"]
|
informal_proof = entry["informal_proof"]
|
||||||
|
@ -24,14 +25,37 @@ def try_test_data(server, agent, entry) -> bool:
|
||||||
target = "forall " + ','.join(e.rsplit(':', 1))
|
target = "forall " + ','.join(e.rsplit(':', 1))
|
||||||
print(f"Target: {target}")
|
print(f"Target: {target}")
|
||||||
agent = LLMAgent(server)
|
agent = LLMAgent(server)
|
||||||
return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True)
|
return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True, max_steps=max_steps)
|
||||||
|
|
||||||
|
def output_file_name(datum, use_hammer: bool, use_llm: bool):
    """Return the JSON result path for one test datum, creating its folder.

    The folder lives next to this script and is named ``output`` followed by
    ``-hammer`` and/or ``-llm`` for each enabled flag, so every ablation
    configuration writes into its own directory. The file itself is named
    after ``datum["id"]``.
    """
    stem = datum["id"]
    # Build the per-configuration suffix in a fixed order: hammer, then llm.
    enabled_tags = ''.join(
        tag
        for tag, enabled in (('-hammer', use_hammer), ('-llm', use_llm))
        if enabled
    )
    target_dir = Path(__file__).parent / ('output' + enabled_tags)
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir / f"{stem}.json"
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
# Command-line interface for the run: two switches select which tactic
# sources the agent may use, and -s/--max-steps sets the step budget that is
# forwarded to the search.
parser = argparse.ArgumentParser(
    prog='MiniF2F Search',
    description='Executes LLM on MiniF2F Search')
parser.add_argument('--use-hammer', action='store_true')
parser.add_argument('--use-llm', action='store_true')
# type=int is required: without it a value given on the command line is kept
# as a string, while the search declares max_steps as an int.
parser.add_argument('-s', '--max-steps', type=int, default=1000)
args = parser.parse_args()
|
||||||
|
|
||||||
project_path, lean_path = get_project_and_lean_path()
|
project_path, lean_path = get_project_and_lean_path()
|
||||||
print(f"$PWD: {project_path}")
|
print(f"$PWD: {project_path}")
|
||||||
print(f"$LEAN_PATH: {lean_path}")
|
print(f"$LEAN_PATH: {lean_path}")
|
||||||
|
|
||||||
test_data = read_test_data()
|
test_data = read_test_data()
|
||||||
server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
|
server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
|
||||||
agent = LLMAgent(server)
|
agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
|
||||||
try_test_data(server, agent, test_data[0])
|
# Run the search on the selected test entries (currently only the first one)
# and write one JSON summary per entry into the configuration's output folder.
for datum in test_data[:1]:
    outcome = try_test_data(server, agent, datum, max_steps=args.max_steps)
    destination = output_file_name(datum, args.use_hammer, args.use_llm)
    with open(destination, 'w') as sink:
        summary = {
            'id': datum['id'],
            'success': outcome.success,
            'steps': outcome.steps,
        }
        json.dump(summary, sink)
|
||||||
|
|
|
@ -26,6 +26,11 @@ class SearchState:
|
||||||
def is_solved(self) -> bool:
|
def is_solved(self) -> bool:
|
||||||
return all(self.solved)
|
return all(self.solved)
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SearchResult:
    """Immutable summary of a single proof-search run."""

    # True when the search ended with the root state solved; False when the
    # tactic list or the iteration limit was exhausted.
    success: bool
    # Step count reported by the search at the point it returned.
    steps: int
|
||||||
|
|
||||||
class Agent:
|
class Agent:
|
||||||
|
|
||||||
|
@ -51,7 +56,7 @@ class Agent:
|
||||||
informal_stmt: str = "",
|
informal_stmt: str = "",
|
||||||
informal_proof: str = "",
|
informal_proof: str = "",
|
||||||
max_steps: int = 1000,
|
max_steps: int = 1000,
|
||||||
verbose: bool = False) -> bool:
|
verbose: bool = False) -> SearchResult:
|
||||||
|
|
||||||
search_stack = [SearchState(state=server.goal_start(target),
|
search_stack = [SearchState(state=server.goal_start(target),
|
||||||
parent=None,
|
parent=None,
|
||||||
|
@ -76,7 +81,7 @@ class Agent:
|
||||||
if verbose:
|
if verbose:
|
||||||
print("Search complete: Root state solved")
|
print("Search complete: Root state solved")
|
||||||
self.reset()
|
self.reset()
|
||||||
return True
|
return SearchResult(success=True, steps=i_step)
|
||||||
|
|
||||||
search_stack.pop(-1)
|
search_stack.pop(-1)
|
||||||
assert not search_stack[search_state.parent].solved[search_state.parent_goal_id]
|
assert not search_stack[search_state.parent].solved[search_state.parent_goal_id]
|
||||||
|
@ -97,7 +102,7 @@ class Agent:
|
||||||
if verbose:
|
if verbose:
|
||||||
print("Tactic list has been exhausted")
|
print("Tactic list has been exhausted")
|
||||||
self.reset()
|
self.reset()
|
||||||
return False
|
return SearchResult(success=False, steps=i_step)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -123,7 +128,7 @@ class Agent:
|
||||||
print("Search iteration limit exhausted")
|
print("Search iteration limit exhausted")
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
return False
|
return SearchResult(success=False, steps=max_steps)
|
||||||
|
|
||||||
|
|
||||||
class DumbAgent(Agent):
|
class DumbAgent(Agent):
|
||||||
|
|
|
@ -8,26 +8,35 @@ import sglang as sgl
|
||||||
|
|
||||||
class LLMAgent(Agent):
|
class LLMAgent(Agent):
|
||||||
|
|
||||||
def __init__(self, server):
|
def __init__(self, server,
|
||||||
|
use_hammer=True, use_llm=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_trials = 5
|
self.n_trials = 5
|
||||||
self.server = server
|
self.server = server
|
||||||
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
|
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
|
||||||
|
|
||||||
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
|
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
|
||||||
|
|
||||||
|
self.use_hammer = use_hammer
|
||||||
|
self.use_llm = use_llm
|
||||||
|
if use_hammer:
|
||||||
self.tactics = [
|
self.tactics = [
|
||||||
"aesop",
|
"aesop",
|
||||||
#"simp",
|
#"simp",
|
||||||
#"rfl",
|
#"rfl",
|
||||||
#"decide",
|
#"decide",
|
||||||
]
|
]
|
||||||
|
else:
|
||||||
|
self.tactics = []
|
||||||
|
|
||||||
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]:
|
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]:
|
||||||
key = (state.state_id, goal_id)
|
key = (state.state_id, goal_id)
|
||||||
i = self.goal_tactic_id_map[key]
|
i = self.goal_tactic_id_map[key]
|
||||||
|
|
||||||
target = state.goals[goal_id].target
|
target = state.goals[goal_id].target
|
||||||
if i >= len(self.tactics):
|
if i >= len(self.tactics) and not self.use_llm:
|
||||||
|
return None
|
||||||
|
elif i >= len(self.tactics):
|
||||||
new_state = None
|
new_state = None
|
||||||
for ii in range(self.n_trials):
|
for ii in range(self.n_trials):
|
||||||
print(f"===============trail {str(ii)}============")
|
print(f"===============trail {str(ii)}============")
|
||||||
|
|
Loading…
Reference in New Issue