search llm passed

This commit is contained in:
Chuyue Sun 2024-06-04 19:59:59 -07:00
parent 2d12e87126
commit 0bb8d55de2
3 changed files with 44 additions and 15 deletions

View File

@ -1,4 +1,4 @@
from pantograph.server import Server, ServerError
from pantograph.server import Server, ServerError, TacticFailure
from pantograph.expr import Variable, Goal, TacticCalc
import unittest
import sglang as sgl
@ -105,19 +105,24 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
tactic = extract_code_from_llm_output(tmp["tactic"])
s += sgl.assistant("```"+tactic+"```")
success, new_state = apply_tactic(server, state, goal_id, tactic)
print("===execute===")
print(success, new_state )
if not success:
with s.user():
s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
s += "Please try again by taking the Lean compiler feedback."
else:
return new_state
return tactic, new_state
return None, None
def apply_tactic(server, state, goal_id, tactic):
try:
new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic)
except ServerError as e:
return False, e
except TacticFailure as e:
return False, e
return True, new_state
def extract_code_from_llm_output(reply):

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import override, Optional
from typing import Optional
import collections, unittest
from pantograph.server import Server, TacticFailure
@ -140,7 +140,6 @@ class DumbAgent(Agent):
"assumption",
]
@override
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key]

View File

@ -1,17 +1,18 @@
import search
from dataclasses import dataclass
from typing import override, Optional
from typing import Optional
import collections, unittest
from pantograph.search import Agent
from pantograph.server import Server, TacticFailure, ServerError
from pantograph.expr import Expr, Tactic, GoalState
from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic
import sglang as sgl
class LLMAgent(search.Agent):
class LLMAgent(Agent):
def __init__(self):
def __init__(self, server):
super().__init__()
self.n_trials = 5
self.server = server
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
self.intros = [
@ -27,7 +28,6 @@ class LLMAgent(search.Agent):
"assumption",
]
@override
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key]
@ -45,19 +45,44 @@ class LLMAgent(search.Agent):
self.goal_tactic_id_map[key] = i + 1
new_state = None
for i in range(self.n_trails):
print(f"===============trail {str(i)}============")
for ii in range(self.n_trials):
print(f"===============trail {str(ii)}============")
try:
state = select_tactic.run(self.server, state, goal_id = 1)
state = select_tactic.run(server = self.server, state=state, goal_id = goal_id)
tactic, new_state = state.ret_value
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- new state --\n", new_state)
break
if tactic:
return tactic
except ServerError as e:
print(f"server error: {e}")
continue
except TacticFailure as e:
print(f"tactic failure: {e}")
continue
return tactics[i]
class TestSearch(unittest.TestCase):
def test_solve(self):
server = Server()
agent = LLMAgent(server)
flag = agent.search(server=server, target="∀ (p q: Prop), p -> p", verbose=True)
#flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
self.assertTrue(flag)
def test_solve_big(self):
server = Server()
agent = LLMAgent(server)
flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
self.assertTrue(flag)
if __name__ == '__main__':
unittest.main()