search llm passed
This commit is contained in:
parent
2d12e87126
commit
0bb8d55de2
|
@ -1,4 +1,4 @@
|
||||||
from pantograph.server import Server, ServerError
|
from pantograph.server import Server, ServerError, TacticFailure
|
||||||
from pantograph.expr import Variable, Goal, TacticCalc
|
from pantograph.expr import Variable, Goal, TacticCalc
|
||||||
import unittest
|
import unittest
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
|
@ -105,19 +105,24 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
|
||||||
tactic = extract_code_from_llm_output(tmp["tactic"])
|
tactic = extract_code_from_llm_output(tmp["tactic"])
|
||||||
s += sgl.assistant("```"+tactic+"```")
|
s += sgl.assistant("```"+tactic+"```")
|
||||||
success, new_state = apply_tactic(server, state, goal_id, tactic)
|
success, new_state = apply_tactic(server, state, goal_id, tactic)
|
||||||
|
print("===execute===")
|
||||||
|
print(success, new_state )
|
||||||
if not success:
|
if not success:
|
||||||
with s.user():
|
with s.user():
|
||||||
s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
|
s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
|
||||||
s += "Please try again by taking the Lean compiler feedback."
|
s += "Please try again by taking the Lean compiler feedback."
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return new_state
|
return tactic, new_state
|
||||||
|
return None, None
|
||||||
|
|
||||||
def apply_tactic(server, state, goal_id, tactic):
|
def apply_tactic(server, state, goal_id, tactic):
|
||||||
try:
|
try:
|
||||||
new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic)
|
new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic)
|
||||||
except ServerError as e:
|
except ServerError as e:
|
||||||
return False, e
|
return False, e
|
||||||
|
except TacticFailure as e:
|
||||||
|
return False, e
|
||||||
return True, new_state
|
return True, new_state
|
||||||
|
|
||||||
def extract_code_from_llm_output(reply):
|
def extract_code_from_llm_output(reply):
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import override, Optional
|
from typing import Optional
|
||||||
import collections, unittest
|
import collections, unittest
|
||||||
|
|
||||||
from pantograph.server import Server, TacticFailure
|
from pantograph.server import Server, TacticFailure
|
||||||
|
@ -140,7 +140,6 @@ class DumbAgent(Agent):
|
||||||
"assumption",
|
"assumption",
|
||||||
]
|
]
|
||||||
|
|
||||||
@override
|
|
||||||
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
|
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
|
||||||
key = (state.state_id, goal_id)
|
key = (state.state_id, goal_id)
|
||||||
i = self.goal_tactic_id_map[key]
|
i = self.goal_tactic_id_map[key]
|
||||||
|
|
|
@ -1,17 +1,18 @@
|
||||||
import search
|
from typing import Optional
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import override, Optional
|
|
||||||
import collections, unittest
|
import collections, unittest
|
||||||
|
from pantograph.search import Agent
|
||||||
from pantograph.server import Server, TacticFailure, ServerError
|
from pantograph.server import Server, TacticFailure, ServerError
|
||||||
from pantograph.expr import Expr, Tactic, GoalState
|
from pantograph.expr import Expr, Tactic, GoalState
|
||||||
from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic
|
from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic
|
||||||
|
import sglang as sgl
|
||||||
|
|
||||||
class LLMAgent(search.Agent):
|
class LLMAgent(Agent):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, server):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_trials = 5
|
self.n_trials = 5
|
||||||
|
self.server = server
|
||||||
|
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
|
||||||
|
|
||||||
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
|
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
|
||||||
self.intros = [
|
self.intros = [
|
||||||
|
@ -27,7 +28,6 @@ class LLMAgent(search.Agent):
|
||||||
"assumption",
|
"assumption",
|
||||||
]
|
]
|
||||||
|
|
||||||
@override
|
|
||||||
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
|
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
|
||||||
key = (state.state_id, goal_id)
|
key = (state.state_id, goal_id)
|
||||||
i = self.goal_tactic_id_map[key]
|
i = self.goal_tactic_id_map[key]
|
||||||
|
@ -45,19 +45,44 @@ class LLMAgent(search.Agent):
|
||||||
|
|
||||||
self.goal_tactic_id_map[key] = i + 1
|
self.goal_tactic_id_map[key] = i + 1
|
||||||
new_state = None
|
new_state = None
|
||||||
for i in range(self.n_trails):
|
for ii in range(self.n_trials):
|
||||||
print(f"===============trail {str(i)}============")
|
print(f"===============trail {str(ii)}============")
|
||||||
try:
|
try:
|
||||||
state = select_tactic.run(self.server, state, goal_id = 1)
|
state = select_tactic.run(server = self.server, state=state, goal_id = goal_id)
|
||||||
tactic, new_state = state.ret_value
|
tactic, new_state = state.ret_value
|
||||||
for m in state.messages():
|
for m in state.messages():
|
||||||
print(m["role"], ":", m["content"])
|
print(m["role"], ":", m["content"])
|
||||||
|
|
||||||
print("\n-- new state --\n", new_state)
|
print("\n-- new state --\n", new_state)
|
||||||
break
|
if tactic:
|
||||||
|
return tactic
|
||||||
|
|
||||||
except ServerError as e:
|
except ServerError as e:
|
||||||
print(f"server error: {e}")
|
print(f"server error: {e}")
|
||||||
continue
|
continue
|
||||||
|
except TacticFailure as e:
|
||||||
|
print(f"tactic failure: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
return tactics[i]
|
return tactics[i]
|
||||||
|
|
||||||
|
class TestSearch(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_solve(self):
|
||||||
|
|
||||||
|
server = Server()
|
||||||
|
agent = LLMAgent(server)
|
||||||
|
flag = agent.search(server=server, target="∀ (p q: Prop), p -> p", verbose=True)
|
||||||
|
#flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
|
||||||
|
self.assertTrue(flag)
|
||||||
|
def test_solve_big(self):
|
||||||
|
|
||||||
|
server = Server()
|
||||||
|
agent = LLMAgent(server)
|
||||||
|
flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
|
||||||
|
self.assertTrue(flag)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue