feat: Add ablation testing

This commit is contained in:
Leni Aniva 2024-06-05 14:19:18 -07:00
parent 4e678c7b97
commit ce633fecda
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
3 changed files with 55 additions and 17 deletions

View File

@ -1,8 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import subprocess, json import subprocess, json, argparse
from pathlib import Path from pathlib import Path
from pantograph.server import Server from pantograph.server import Server
from pantograph.search import SearchResult
from pantograph.search_llm import LLMAgent from pantograph.search_llm import LLMAgent
def get_project_and_lean_path(): def get_project_and_lean_path():
@ -15,7 +16,7 @@ def read_test_data():
with open(jsonl_path, 'r') as f: with open(jsonl_path, 'r') as f:
return [json.loads(l) for l in list(f)] return [json.loads(l) for l in list(f)]
def try_test_data(server, agent, entry) -> bool: def try_test_data(server, agent, entry: dict, max_steps: int) -> SearchResult:
e = entry["formal_statement"] e = entry["formal_statement"]
informal_stmt = entry["informal_stmt"] informal_stmt = entry["informal_stmt"]
informal_proof = entry["informal_proof"] informal_proof = entry["informal_proof"]
@ -24,14 +25,37 @@ def try_test_data(server, agent, entry) -> bool:
target = "forall " + ','.join(e.rsplit(':', 1)) target = "forall " + ','.join(e.rsplit(':', 1))
print(f"Target: {target}") print(f"Target: {target}")
agent = LLMAgent(server) agent = LLMAgent(server)
return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True) return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True, max_steps=max_steps)
def output_file_name(datum, use_hammer: bool, use_llm: bool):
    """Return the JSON result path for one test entry.

    The output directory sits next to this script and is named ``output``
    followed by ``-hammer`` and/or ``-llm`` for each enabled flag, so every
    ablation configuration writes into its own folder. The directory is
    created if it does not exist yet.

    Args:
        datum: Test entry; its ``"id"`` value becomes the file stem.
        use_hammer: Whether the hammer tactics are enabled for this run.
        use_llm: Whether the LLM is enabled for this run.

    Returns:
        Path of the form ``<script dir>/<folder>/<id>.json``.
    """
    stem = datum["id"]
    suffixes = [
        tag
        for tag, enabled in (('-hammer', use_hammer), ('-llm', use_llm))
        if enabled
    ]
    out_dir = Path(__file__).parent / ''.join(['output', *suffixes])
    out_dir.mkdir(exist_ok=True, parents=True)
    return out_dir / f"{stem}.json"
if __name__ == '__main__':
    # The two switches select which parts of the agent are active, so the
    # same script can be run once per ablation configuration.
    parser = argparse.ArgumentParser(
        prog='MiniF2F Search',
        description='Executes LLM on MiniF2F Search')
    parser.add_argument('--use-hammer', action='store_true')
    parser.add_argument('--use-llm', action='store_true')
    # type=int is required: argparse leaves command-line values as str
    # unless told otherwise, whereas the search signature declares
    # max_steps as an int (and the default here is an int too).
    parser.add_argument('-s', '--max-steps', type=int, default=1000)
    args = parser.parse_args()

    project_path, lean_path = get_project_and_lean_path()
    print(f"$PWD: {project_path}")
    print(f"$LEAN_PATH: {lean_path}")

    test_data = read_test_data()
    server = Server(imports=["Mathlib"], project_path=project_path, lean_path=lean_path)
    agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)

    # Only the first test entry is attempted; each result is written to its
    # own JSON file in the folder matching the active flags.
    for datum in test_data[:1]:
        result = try_test_data(server, agent, datum, max_steps=args.max_steps)
        file_name = output_file_name(datum, args.use_hammer, args.use_llm)
        with open(file_name, 'w') as f:
            json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)

View File

@ -26,6 +26,11 @@ class SearchState:
def is_solved(self) -> bool:
    """Return True when no entry of ``self.solved`` is falsy."""
    return not any(not flag for flag in self.solved)
@dataclass(frozen=True)
class SearchResult:
    """Immutable outcome of one proof search."""

    # True when the search ended with the root state solved.
    success: bool
    # Step count reported by the search when it stopped.
    steps: int
class Agent: class Agent:
@ -51,7 +56,7 @@ class Agent:
informal_stmt: str = "", informal_stmt: str = "",
informal_proof: str = "", informal_proof: str = "",
max_steps: int = 1000, max_steps: int = 1000,
verbose: bool = False) -> bool: verbose: bool = False) -> SearchResult:
search_stack = [SearchState(state=server.goal_start(target), search_stack = [SearchState(state=server.goal_start(target),
parent=None, parent=None,
@ -76,7 +81,7 @@ class Agent:
if verbose: if verbose:
print("Search complete: Root state solved") print("Search complete: Root state solved")
self.reset() self.reset()
return True return SearchResult(success=True, steps=i_step)
search_stack.pop(-1) search_stack.pop(-1)
assert not search_stack[search_state.parent].solved[search_state.parent_goal_id] assert not search_stack[search_state.parent].solved[search_state.parent_goal_id]
@ -97,7 +102,7 @@ class Agent:
if verbose: if verbose:
print("Tactic list has been exhausted") print("Tactic list has been exhausted")
self.reset() self.reset()
return False return SearchResult(success=False, steps=i_step)
continue continue
try: try:
@ -123,7 +128,7 @@ class Agent:
print("Search iteration limit exhausted") print("Search iteration limit exhausted")
self.reset() self.reset()
return False return SearchResult(success=False, steps=max_steps)
class DumbAgent(Agent): class DumbAgent(Agent):

View File

@ -8,26 +8,35 @@ import sglang as sgl
class LLMAgent(Agent): class LLMAgent(Agent):
def __init__(self, server,
             use_hammer=True, use_llm=True):
    """Set up the agent and choose which tactic sources are enabled.

    Args:
        server: Server handle, kept on ``self.server``.
        use_hammer: When true, the fixed tactic list contains "aesop";
            otherwise the list is empty.
        use_llm: Stored on ``self.use_llm``; ``next_tactic`` checks it once
            the fixed tactic list has run out.
    """
    super().__init__()
    self.n_trials = 5
    self.server = server
    sgl.set_default_backend(sgl.OpenAI("gpt-4"))
    # Maps a (state id, goal id) key to the index of the next fixed tactic
    # to try; unseen keys start at 0.
    self.goal_tactic_id_map = collections.defaultdict(int)
    self.use_hammer = use_hammer
    self.use_llm = use_llm
    # "simp", "rfl" and "decide" are currently left out of the list.
    self.tactics = ["aesop"] if use_hammer else []
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]: def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]:
key = (state.state_id, goal_id) key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key] i = self.goal_tactic_id_map[key]
target = state.goals[goal_id].target target = state.goals[goal_id].target
if i >= len(self.tactics): if i >= len(self.tactics) and not self.use_llm:
return None
elif i >= len(self.tactics):
new_state = None new_state = None
for ii in range(self.n_trials): for ii in range(self.n_trials):
print(f"===============trail {str(ii)}============") print(f"===============trail {str(ii)}============")