search llm passed

This commit is contained in:
Chuyue Sun 2024-06-04 19:59:59 -07:00
parent 2d12e87126
commit 0bb8d55de2
3 changed files with 44 additions and 15 deletions

View File

@ -1,4 +1,4 @@
from pantograph.server import Server, ServerError from pantograph.server import Server, ServerError, TacticFailure
from pantograph.expr import Variable, Goal, TacticCalc from pantograph.expr import Variable, Goal, TacticCalc
import unittest import unittest
import sglang as sgl import sglang as sgl
@ -105,19 +105,24 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
tactic = extract_code_from_llm_output(tmp["tactic"]) tactic = extract_code_from_llm_output(tmp["tactic"])
s += sgl.assistant("```"+tactic+"```") s += sgl.assistant("```"+tactic+"```")
success, new_state = apply_tactic(server, state, goal_id, tactic) success, new_state = apply_tactic(server, state, goal_id, tactic)
print("===execute===")
print(success, new_state )
if not success: if not success:
with s.user(): with s.user():
s += "This answer got Lean compile error:\n" + str(new_state) + "\n" s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
s += "Please try again by taking the Lean compiler feedback." s += "Please try again by taking the Lean compiler feedback."
else: else:
return new_state return tactic, new_state
return None, None
def apply_tactic(server, state, goal_id, tactic):
    """Apply ``tactic`` to one goal of ``state`` through ``server.goal_tactic``.

    Args:
        server: Object exposing ``goal_tactic(state, goal_id=..., tactic=...)``.
        state: Goal state handed to ``server.goal_tactic`` unchanged.
        goal_id: Identifier of the goal inside ``state`` to act on.
        tactic: Tactic passed to the server.

    Returns:
        ``(True, new_state)`` with whatever ``server.goal_tactic`` returned
        when the call succeeds, or ``(False, error)`` with the caught
        exception when the call raises ``ServerError`` or ``TacticFailure``.
        Any other exception propagates to the caller.
    """
    try:
        new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic)
    except (ServerError, TacticFailure) as e:
        # Both failure kinds are reported identically: the exception object is
        # returned (not raised) so the caller can stringify it as feedback.
        return False, e
    return True, new_state
def extract_code_from_llm_output(reply): def extract_code_from_llm_output(reply):

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import override, Optional from typing import Optional
import collections, unittest import collections, unittest
from pantograph.server import Server, TacticFailure from pantograph.server import Server, TacticFailure
@ -140,7 +140,6 @@ class DumbAgent(Agent):
"assumption", "assumption",
] ]
@override
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]: def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
key = (state.state_id, goal_id) key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key] i = self.goal_tactic_id_map[key]

View File

@ -1,17 +1,18 @@
import search from typing import Optional
from dataclasses import dataclass
from typing import override, Optional
import collections, unittest import collections, unittest
from pantograph.search import Agent
from pantograph.server import Server, TacticFailure, ServerError from pantograph.server import Server, TacticFailure, ServerError
from pantograph.expr import Expr, Tactic, GoalState from pantograph.expr import Expr, Tactic, GoalState
from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic
import sglang as sgl
class LLMAgent(search.Agent): class LLMAgent(Agent):
def __init__(self): def __init__(self, server):
super().__init__() super().__init__()
self.n_trials = 5 self.n_trials = 5
self.server = server
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
self.goal_tactic_id_map = collections.defaultdict(lambda : 0) self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
self.intros = [ self.intros = [
@ -27,7 +28,6 @@ class LLMAgent(search.Agent):
"assumption", "assumption",
] ]
@override
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]: def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
key = (state.state_id, goal_id) key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key] i = self.goal_tactic_id_map[key]
@ -45,19 +45,44 @@ class LLMAgent(search.Agent):
self.goal_tactic_id_map[key] = i + 1 self.goal_tactic_id_map[key] = i + 1
new_state = None new_state = None
for i in range(self.n_trails): for ii in range(self.n_trials):
print(f"===============trail {str(i)}============") print(f"===============trail {str(ii)}============")
try: try:
state = select_tactic.run(self.server, state, goal_id = 1) state = select_tactic.run(server = self.server, state=state, goal_id = goal_id)
tactic, new_state = state.ret_value tactic, new_state = state.ret_value
for m in state.messages(): for m in state.messages():
print(m["role"], ":", m["content"]) print(m["role"], ":", m["content"])
print("\n-- new state --\n", new_state) print("\n-- new state --\n", new_state)
break if tactic:
return tactic
except ServerError as e: except ServerError as e:
print(f"server error: {e}") print(f"server error: {e}")
continue continue
except TacticFailure as e:
print(f"tactic failure: {e}")
continue
return tactics[i] return tactics[i]
class TestSearch(unittest.TestCase):
    """Checks that ``LLMAgent.search`` returns a truthy flag for two targets.

    Each test builds a fresh ``Server`` and an ``LLMAgent`` bound to it.
    NOTE(review): ``LLMAgent.__init__`` selects the sglang OpenAI "gpt-4"
    backend, so these tests presumably need OpenAI access -- confirm.
    """

    def test_solve(self):
        """Search on the implication ``p -> p`` must report success."""
        lean_server = Server()
        llm_agent = LLMAgent(lean_server)
        solved = llm_agent.search(
            server=lean_server,
            target="∀ (p q: Prop), p -> p",
            verbose=True,
        )
        self.assertTrue(solved)

    def test_solve_big(self):
        """Search on the commutativity of ``Or`` must report success."""
        lean_server = Server()
        llm_agent = LLMAgent(lean_server)
        solved = llm_agent.search(
            server=lean_server,
            target="∀ (p q: Prop), Or p q -> Or q p",
            verbose=True,
        )
        self.assertTrue(solved)
# Run the unittest test cases when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()