update poetry lock; add llm feedback

This commit is contained in:
Chuyue Sun 2024-06-02 14:16:15 -07:00
parent 3d2e737e0c
commit 90a3a7bd3d
2 changed files with 60 additions and 34 deletions

View File

@ -1,4 +1,4 @@
from pantograph.server import Server
from pantograph.server import Server, ServerError
from pantograph.expr import Variable, Goal, TacticCalc
import unittest
import sglang as sgl
@ -16,22 +16,36 @@ def multi_turn_question(s, question_1, question_2):
@sgl.function
def select_tactic(s, state):
def select_tactic(s, server, state, goal_id, n_tries = 5):
s += sgl.system("You are an expert in Lean. Choose the next one tactic to run given the current proof state and goals.")
s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant("```intros a b h```")
s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
s += sgl.user("The current proof state: " + str(state))
for i in range(n_tries):
with s.copy() as tmp:
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
print("==tmp===")
print(tmp["tactic"])
tactic = extract_code_from_llm_output(tmp["tactic"])
s += sgl.assistant("```"+tactic+"```")
return tactic
success, new_state = apply_tactic(server, state, goal_id, tactic)
if not success:
with s.user():
s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
s += "Please try again by taking the Lean compiler feedback."
else:
return new_state
def apply_tactic(server, state, goal_id, tactic):
try:
new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic)
except ServerError as e:
return False, e
return True, new_state
def extract_code_from_llm_output(reply):
i = reply.find("```lean")
@ -51,6 +65,7 @@ def extract_code_from_llm_output(reply):
class TestServerSGL(unittest.TestCase):
def test_conv_calc_sgl(self):
n_trails = 5
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
server = Server()
@ -80,14 +95,21 @@ class TestServerSGL(unittest.TestCase):
target="a + 1 + 1 = a + b",
),
])
state = select_tactic.run(str(state2))
tactic = state.ret_value
state3 = None
for i in range(n_trails):
print(f"===============trail {str(i)}============")
try:
state = select_tactic.run(server, state2, goal_id = 1)
state3 = state.ret_value
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- tactic --\n", tactic)
print("\n-- new state --\n", state3)
except ServerError as e:
print(f"server error: {e}")
continue
state3 = server.goal_tactic(state2, goal_id=1, tactic=tactic)
print("==========state3============")
print(state3)
# state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]")
@ -119,19 +141,19 @@ class TestServerSGL(unittest.TestCase):
# print()
def test_sglang_openai(self):
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
# def test_sglang_openai(self):
# sgl.set_default_backend(sgl.OpenAI("gpt-4"))
print('\n----- Test sglang ---')
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
)
# print('\n----- Test sglang ---')
# state = multi_turn_question.run(
# question_1="What is the capital of the United States?",
# question_2="List two local attractions.",
# )
for m in state.messages():
print(m["role"], ":", m["content"])
# for m in state.messages():
# print(m["role"], ":", m["content"])
print("\n-- answer_1 --\n", state["answer_1"])
# print("\n-- answer_1 --\n", state["answer_1"])
if __name__ == '__main__':

18
poetry.lock generated
View File

@ -1,10 +1,15 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "pexpect"
version = "4.9.0"
description = "Pexpect allows easy control of interactive console applications."
category = "main"
optional = false
python-versions = "*"
files = [
{file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"},
{file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"},
]
[package.dependencies]
ptyprocess = ">=0.5"
@ -13,15 +18,14 @@ ptyprocess = ">=0.5"
name = "ptyprocess"
version = "0.7.0"
description = "Run a subprocess in a pseudo terminal"
category = "main"
optional = false
python-versions = "*"
files = [
{file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"},
{file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"},
]
[metadata]
lock-version = "1.1"
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "54cb66612c110a515f024e54d2b2a0af54ffcbe06602ad7f4ea6a446699d419a"
[metadata.files]
pexpect = []
ptyprocess = []