Pantograph/pantograph/search_llm.py

80 lines
2.4 KiB
Python
Raw Normal View History

2024-06-04 19:59:59 -07:00
from typing import Optional
2024-06-04 19:29:21 -07:00
import collections, unittest
2024-06-04 19:59:59 -07:00
from pantograph.search import Agent
2024-06-04 19:29:21 -07:00
from pantograph.server import Server, TacticFailure, ServerError
from pantograph.expr import Expr, Tactic, GoalState
from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic
2024-06-04 19:59:59 -07:00
import sglang as sgl
2024-06-04 19:29:21 -07:00
2024-06-04 19:59:59 -07:00
class LLMAgent(Agent):
    """
    Search agent that obtains tactics from `select_tactic` and keeps a small
    hand-written tactic table as a fallback.

    For every (state, goal) pair the agent remembers how many fallback
    tactics it has already handed out, so repeated calls on the same goal
    walk through the table and eventually return `None`.
    """

    def __init__(self, server):
        """
        Set up the agent.

        :param server: server handle forwarded to `select_tactic.run` on
            every tactic request.

        Side effect: installs `sgl.OpenAI("gpt-4")` as the sglang default
        backend.
        """
        super().__init__()
        # Maximum number of `select_tactic` attempts per `next_tactic` call.
        self.n_trials = 5
        self.server = server
        sgl.set_default_backend(sgl.OpenAI("gpt-4"))

        # (state_id, goal_id) -> index of the next fallback tactic to try.
        self.goal_tactic_id_map = collections.defaultdict(int)
        # Fallbacks for targets that start with a universal quantifier.
        self.intros = [
            "intro",
        ]
        # Fallbacks for other targets that contain a space.
        self.tactics = [
            "intro h",
            "cases h",
            "apply Or.inl",
            "apply Or.inr",
        ]
        # Fallbacks for targets without any space.
        self.no_space_tactics = [
            "assumption",
        ]

    def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
        """
        Pick the next tactic to try on goal `goal_id` of `state`.

        :param state: goal state whose `goals[goal_id].target` is inspected.
        :param goal_id: index of the goal inside `state.goals`.
        :return: the first truthy tactic produced by `select_tactic.run`
            within `self.n_trials` attempts; otherwise the next unused entry
            of the fallback table matching the target; `None` once that
            table has been exhausted for this (state, goal) pair.
        """
        key = (state.state_id, goal_id)
        i = self.goal_tactic_id_map[key]

        target = state.goals[goal_id].target
        # The prefix must be the quantifier symbol: an empty prefix matches
        # every string and would make the two branches below unreachable.
        if target.startswith('∀'):
            tactics = self.intros
        elif ' ' in target:
            tactics = self.tactics
        else:
            tactics = self.no_space_tactics

        if i >= len(tactics):
            return None

        # Reserve fallback slot `i` up front, even if select_tactic succeeds,
        # so each call on this goal consumes exactly one slot.
        self.goal_tactic_id_map[key] = i + 1
        for ii in range(self.n_trials):
            print(f"===============trail {ii}============")
            s = select_tactic.run(server=self.server, state=state, goal_id=goal_id)
            tactic, new_state = s.ret_value
            for m in s.messages():
                print(m["role"], ":", m["content"])

            print("\n-- new state --\n", new_state)
            if tactic:
                return tactic
        # No attempt produced a tactic: fall back to the table entry.
        return tactics[i]
2024-06-04 19:59:59 -07:00
class TestSearch(unittest.TestCase):
    # Each case builds a fresh Server and LLMAgent and asserts that
    # `agent.search` returns a truthy result for the given target.

    def test_solve(self):
        # Target: an implication whose conclusion is its own premise.
        srv = Server()
        solved = LLMAgent(srv).search(
            server=srv,
            target="∀ (p q: Prop), p -> p",
            verbose=True,
        )
        self.assertTrue(solved)

    def test_solve_big(self):
        # Target: swapping the two sides of a disjunction.
        srv = Server()
        solved = LLMAgent(srv).search(
            server=srv,
            target="∀ (p q: Prop), Or p q -> Or q p",
            verbose=True,
        )
        self.assertTrue(solved)
# Run the test cases above when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()