update poetry lock; add llm feedback

This commit is contained in:
Chuyue Sun 2024-06-02 14:16:15 -07:00
parent 3d2e737e0c
commit 90a3a7bd3d
2 changed files with 60 additions and 34 deletions

View File

@ -1,4 +1,4 @@
from pantograph.server import Server from pantograph.server import Server, ServerError
from pantograph.expr import Variable, Goal, TacticCalc from pantograph.expr import Variable, Goal, TacticCalc
import unittest import unittest
import sglang as sgl import sglang as sgl
@ -16,22 +16,36 @@ def multi_turn_question(s, question_1, question_2):
@sgl.function @sgl.function
def select_tactic(s, state): def select_tactic(s, server, state, goal_id, n_tries = 5):
s += sgl.system("You are an expert in Lean. Choose the next one tactic to run given the current proof state and goals.") s += sgl.system("You are an expert in Lean. Choose the next one tactic to run given the current proof state and goals.")
s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])") s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant("```intros a b h```") s += sgl.assistant("```intros a b h```")
s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])") s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")') s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
s += sgl.user("The current proof state: " + str(state)) s += sgl.user("The current proof state: " + str(state))
with s.copy() as tmp: for i in range(n_tries):
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) with s.copy() as tmp:
print("==tmp===") tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
print(tmp["tactic"]) print("==tmp===")
tactic = extract_code_from_llm_output(tmp["tactic"]) print(tmp["tactic"])
s += sgl.assistant("```"+tactic+"```") tactic = extract_code_from_llm_output(tmp["tactic"])
return tactic s += sgl.assistant("```"+tactic+"```")
success, new_state = apply_tactic(server, state, goal_id, tactic)
if not success:
with s.user():
s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
s += "Please try again by taking the Lean compiler feedback."
else:
return new_state
def apply_tactic(server, state, goal_id, tactic):
try:
new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic)
except ServerError as e:
return False, e
return True, new_state
def extract_code_from_llm_output(reply): def extract_code_from_llm_output(reply):
i = reply.find("```lean") i = reply.find("```lean")
@ -51,6 +65,7 @@ def extract_code_from_llm_output(reply):
class TestServerSGL(unittest.TestCase): class TestServerSGL(unittest.TestCase):
def test_conv_calc_sgl(self): def test_conv_calc_sgl(self):
n_trails = 5
sgl.set_default_backend(sgl.OpenAI("gpt-4")) sgl.set_default_backend(sgl.OpenAI("gpt-4"))
server = Server() server = Server()
@ -80,14 +95,21 @@ class TestServerSGL(unittest.TestCase):
target="a + 1 + 1 = a + b", target="a + 1 + 1 = a + b",
), ),
]) ])
state = select_tactic.run(str(state2)) state3 = None
tactic = state.ret_value for i in range(n_trails):
for m in state.messages(): print(f"===============trail {str(i)}============")
print(m["role"], ":", m["content"]) try:
state = select_tactic.run(server, state2, goal_id = 1)
state3 = state.ret_value
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- new state --\n", state3)
except ServerError as e:
print(f"server error: {e}")
continue
print("\n-- tactic --\n", tactic)
state3 = server.goal_tactic(state2, goal_id=1, tactic=tactic)
print("==========state3============") print("==========state3============")
print(state3) print(state3)
# state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") # state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]")
@ -119,19 +141,19 @@ class TestServerSGL(unittest.TestCase):
# print() # print()
def test_sglang_openai(self): # def test_sglang_openai(self):
sgl.set_default_backend(sgl.OpenAI("gpt-4")) # sgl.set_default_backend(sgl.OpenAI("gpt-4"))
print('\n----- Test sglang ---') # print('\n----- Test sglang ---')
state = multi_turn_question.run( # state = multi_turn_question.run(
question_1="What is the capital of the United States?", # question_1="What is the capital of the United States?",
question_2="List two local attractions.", # question_2="List two local attractions.",
) # )
for m in state.messages(): # for m in state.messages():
print(m["role"], ":", m["content"]) # print(m["role"], ":", m["content"])
print("\n-- answer_1 --\n", state["answer_1"]) # print("\n-- answer_1 --\n", state["answer_1"])
if __name__ == '__main__': if __name__ == '__main__':

18
poetry.lock generated
View File

@ -1,10 +1,15 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]] [[package]]
name = "pexpect" name = "pexpect"
version = "4.9.0" version = "4.9.0"
description = "Pexpect allows easy control of interactive console applications." description = "Pexpect allows easy control of interactive console applications."
category = "main"
optional = false optional = false
python-versions = "*" python-versions = "*"
files = [
{file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"},
{file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"},
]
[package.dependencies] [package.dependencies]
ptyprocess = ">=0.5" ptyprocess = ">=0.5"
@ -13,15 +18,14 @@ ptyprocess = ">=0.5"
name = "ptyprocess" name = "ptyprocess"
version = "0.7.0" version = "0.7.0"
description = "Run a subprocess in a pseudo terminal" description = "Run a subprocess in a pseudo terminal"
category = "main"
optional = false optional = false
python-versions = "*" python-versions = "*"
files = [
{file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"},
{file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"},
]
[metadata] [metadata]
lock-version = "1.1" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "54cb66612c110a515f024e54d2b2a0af54ffcbe06602ad7f4ea6a446699d419a" content-hash = "54cb66612c110a515f024e54d2b2a0af54ffcbe06602ad7f4ea6a446699d419a"
[metadata.files]
pexpect = []
ptyprocess = []