From 9993019b69b314deb28f2c8e60bcc6f6a369db68 Mon Sep 17 00:00:00 2001 From: ChuyueSun Date: Sun, 19 May 2024 18:59:03 -0700 Subject: [PATCH 1/7] update readme --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index d44c1c2..f66d9f2 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,10 @@ Python interface to the Pantograph library ## Getting started +Update submodule +``` bash +git submodule update --init +``` Execute ```bash From fa7c4a89d34a864bc159240a3b53afca26eb1f62 Mon Sep 17 00:00:00 2001 From: ChuyueSun Date: Sun, 19 May 2024 19:16:05 -0700 Subject: [PATCH 2/7] add test_sglang --- test_sglang.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 test_sglang.py diff --git a/test_sglang.py b/test_sglang.py new file mode 100644 index 0000000..34b8148 --- /dev/null +++ b/test_sglang.py @@ -0,0 +1,57 @@ +# copy pasted from https://docs.vllm.ai/en/latest/getting_started/quickstart.html + +# do export VLLM_USE_MODELSCOPE=True +import argparse +from typing import Dict, List +import os +import sglang as sgl +from sglang import OpenAI, assistant, gen, set_default_backend, system, user + + +def test_pytorch(): + print('\n----- Test PyTorch ---') + # Print the PyTorch version and CUDA version + print(f"PyTorch version: {torch.__version__}") + print(f"CUDA version: {torch.version.cuda}") + + # Perform a matrix multiplication on CUDA and print the result + result = torch.randn(2, 4).cuda() @ torch.randn(4, 1).cuda() + print(f"Matrix multiplication result: {result}") + + # Check CUDA availability and device details + print(f'Number of CUDA devices: {torch.cuda.device_count()}') + if torch.cuda.device_count() > 0: + print(f'Device name: {torch.cuda.get_device_name(0)}') + else: + print("No CUDA devices available.") + +@sgl.function +def multi_turn_question(s, question_1, question_2): + s += sgl.system("You are a helpful assistant.") + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) + + + +def test_sglang(): + print('\n----- Test sglang ---') + state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", + ) + + for m in state.messages(): + print(m["role"], ":", m["content"]) + + print("\n-- answer_1 --\n", state["answer_1"]) + + +if __name__ == "__main__": + import time + start_time = time.time() + sgl.set_default_backend(sgl.OpenAI("gpt-4")) + + test_sglang() + print(f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a") \ No newline at end of file From 9ef093c9e2c0b657803d14de82b5cec08b401223 Mon Sep 17 00:00:00 2001 From: ChuyueSun Date: Sun, 19 May 2024 19:17:05 -0700 Subject: [PATCH 3/7] update --- test_sglang.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/test_sglang.py b/test_sglang.py index 34b8148..6b68cc8 100644 --- a/test_sglang.py +++ b/test_sglang.py @@ -1,6 +1,5 @@ -# copy pasted from https://docs.vllm.ai/en/latest/getting_started/quickstart.html -# do export VLLM_USE_MODELSCOPE=True + import argparse from typing import Dict, List import os @@ -8,22 +7,6 @@ import sglang as sgl from sglang import OpenAI, assistant, gen, set_default_backend, system, user -def test_pytorch(): - print('\n----- Test PyTorch ---') - # Print the PyTorch version and CUDA version - print(f"PyTorch version: {torch.__version__}") - print(f"CUDA version: {torch.version.cuda}") - - # Perform a matrix multiplication on CUDA and print the result - result = torch.randn(2, 4).cuda() @ torch.randn(4, 1).cuda() - print(f"Matrix multiplication result: {result}") - - # Check CUDA availability and device details - print(f'Number of CUDA devices: {torch.cuda.device_count()}') - if torch.cuda.device_count() > 0: - print(f'Device name: {torch.cuda.get_device_name(0)}') - else: - print("No CUDA devices available.") @sgl.function def multi_turn_question(s, question_1, question_2): From 095180589b8972fae85a0744d18a31311ab08aa0 Mon Sep 17 00:00:00 2001 From: ChuyueSun Date: Mon, 20 May 2024 00:02:33 -0700 Subject: [PATCH 4/7] add conv_sgl tests --- .gitignore | 1 + .gitmodules | 1 + pantograph/server.py | 155 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 156 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 4bd6cec..670b156 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ # Output /dist +/venv pantograph/pantograph pantograph/lean-toolchain diff --git a/.gitmodules b/.gitmodules index caa4dcb..6b49ada 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ [submodule "src"] path = src url = https://git.leni.sh/aniva/Pantograph.git + \ No newline at end of file diff --git a/pantograph/server.py b/pantograph/server.py index 52a7661..86c11ad 100644 --- a/pantograph/server.py +++ b/pantograph/server.py @@ -6,6 +6,41 @@ import json, pexpect, pathlib, unittest from pantograph.expr import Variable, Goal, GoalState, \ Tactic, TacticHave, TacticCalc + +import argparse +from typing import Dict, List +import os +import sglang as sgl + + + + +@sgl.function +def multi_turn_question(s, question_1, question_2): + s += sgl.system("You are a helpful assistant.") + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) + + +@sgl.function +def select_tactic(s, state): + s += sgl.system("You are an expert in Lean. Choose the next one tactic to run given the current proof state and goals.") + s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])") + s += sgl.assistant("```intros a b h```") + s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])") + s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")') + s += sgl.user("The current proof state: " + str(state)) + with s.copy() as tmp: + tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) + print("==tmp===") + print(tmp["tactic"]) + tactic = extract_code_from_llm_output(tmp["tactic"]) + s += sgl.assistant("```"+tactic+"```") + + + def _get_proc_cwd(): return pathlib.Path(__file__).parent def _get_proc_path(): @@ -112,7 +147,21 @@ def get_version(): stdout=subprocess.PIPE, cwd=_get_proc_cwd()) as p: return p.communicate()[0].decode('utf-8').strip() - + +def extract_code_from_llm_output(reply): + i = reply.find("```lean") + if i != -1: + reply = reply[i + 7:] + i = reply.find("```") + reply = reply[:i] + return reply + i = reply.find("```") + if i != -1: + reply = reply[i + 3:] + i = reply.find("```") + reply = reply[:i] + return reply + return reply class TestServer(unittest.TestCase): @@ -132,6 +181,75 @@ class TestServer(unittest.TestCase): )]) self.assertEqual(str(state1.goals[0]),"a : Prop\n⊢ ∀ (q : Prop), a ∨ q → q ∨ a") + def test_conv_calc_sgl(self): + sgl.set_default_backend(sgl.OpenAI("gpt-4")) + + server = Server() + state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b") + print("==========state0============") + print(state0) + variables = [ + Variable(name="a", t="Nat"), + Variable(name="b", t="Nat"), + Variable(name="h", t="b = 2"), + ] + + state1 = server.goal_tactic(state0, goal_id=0, tactic="intro a b h") + print("==========state1============") + print(state1) + state2 = server.goal_tactic(state1, goal_id=0, tactic=TacticCalc("1 + a + 1 = a + 1 + 1")) + print("==========state2============") + print(state2) + self.assertEqual(state2.goals, [ + Goal( + variables, + target="1 + a + 1 = a + 1 + 1", + name='calc', + ), + Goal( + variables, + target="a + 1 + 1 = a + b", + ), + ]) + state = select_tactic.run(str(state2)) + for m in state.messages(): + print(m["role"], ":", m["content"]) + + print("\n-- tactic --\n", state.stream_executor.variables) + print(state.stream_executor.arguments) + + + # print("==========state2============") + # print(state2) + # state_c1 = server.goal_conv_begin(state2, goal_id=0) + # print("==========state c1============") + # print(state_c1) + # state_c2 = server.goal_tactic(state_c1, goal_id=0, tactic="rhs") + # print("==========state c2============") + # print(state_c2) + # state_c3 = server.goal_tactic(state_c2, goal_id=0, tactic="rw [Nat.add_comm]") + # print("==========state c3============") + # print(state_c3) + # state_c4 = server.goal_conv_end(state_c3) + # print("==========state c4============") + # print(state_c4) + + # state_c5 = server.goal_tactic(state_c4, goal_id=0, tactic="rfl") + # print("==========state c5============") + # print(state_c5) + # self.assertTrue(state_c5.is_solved) + + # print() + + # state3 = server.goal_tactic(state2, goal_id=1, tactic=TacticCalc("_ = a + 2")) + # print("==========state3============") + # print(state3) + # state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") + # print("==========state4============") + # print(state4) + # self.assertTrue(state4.is_solved) + + def test_conv_calc(self): server = Server() state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b") @@ -154,17 +272,52 @@ class TestServer(unittest.TestCase): target="a + 1 + 1 = a + b", ), ]) + print("==========state2============") + print(state2) state_c1 = server.goal_conv_begin(state2, goal_id=0) + print("==========state c1============") + print(state_c1) state_c2 = server.goal_tactic(state_c1, goal_id=0, tactic="rhs") + print("==========state c2============") + print(state_c2) state_c3 = server.goal_tactic(state_c2, goal_id=0, tactic="rw [Nat.add_comm]") + print("==========state c3============") + print(state_c3) state_c4 = server.goal_conv_end(state_c3) + print("==========state c4============") + print(state_c4) + state_c5 = server.goal_tactic(state_c4, goal_id=0, tactic="rfl") + print("==========state c5============") + print(state_c5) self.assertTrue(state_c5.is_solved) + + print() state3 = server.goal_tactic(state2, goal_id=1, tactic=TacticCalc("_ = a + 2")) + print("==========state3============") + print(state3) state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") + print("==========state4============") + print(state4) self.assertTrue(state4.is_solved) + def test_sglang_openai(self): + sgl.set_default_backend(sgl.OpenAI("gpt-4")) + print('\n----- Test sglang ---') + state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", + ) + + for m in state.messages(): + print(m["role"], ":", m["content"]) + + print("\n-- answer_1 --\n", state["answer_1"]) + + if __name__ == '__main__': + unittest.main() + From 67b3ec969bd7c9982fe58f38f3c5e16a3191ea23 Mon Sep 17 00:00:00 2001 From: ChuyueSun Date: Mon, 20 May 2024 00:26:52 -0700 Subject: [PATCH 5/7] apply sgl tactic --- pantograph/server.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pantograph/server.py b/pantograph/server.py index 86c11ad..633913e 100644 --- a/pantograph/server.py +++ b/pantograph/server.py @@ -38,6 +38,7 @@ def select_tactic(s, state): print(tmp["tactic"]) tactic = extract_code_from_llm_output(tmp["tactic"]) s += sgl.assistant("```"+tactic+"```") + return tactic @@ -212,11 +213,19 @@ class TestServer(unittest.TestCase): ), ]) state = select_tactic.run(str(state2)) + tactic = state.ret_value for m in state.messages(): print(m["role"], ":", m["content"]) print("\n-- tactic --\n", state.stream_executor.variables) - print(state.stream_executor.arguments) + + state3 = server.goal_tactic(state2, goal_id=1, tactic=tactic) + print("==========state3============") + print(state3) + # state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") + # print("==========state4============") + # print(state4) + # self.assertTrue(state4.is_solved) # print("==========state2============") @@ -241,13 +250,6 @@ class TestServer(unittest.TestCase): # print() - # state3 = server.goal_tactic(state2, goal_id=1, tactic=TacticCalc("_ = a + 2")) - # print("==========state3============") - # print(state3) - # state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") - # print("==========state4============") - # print(state4) - # self.assertTrue(state4.is_solved) def test_conv_calc(self): From 46459763bebae7ebedfb71543a4879722952585e Mon Sep 17 00:00:00 2001 From: ChuyueSun Date: Mon, 20 May 2024 18:06:15 -0700 Subject: [PATCH 6/7] refactor --- pantograph/gen_tactic.py | 140 ++++++++++++++++++++++++++++++++++ pantograph/server.py | 159 +-------------------------------------- 2 files changed, 142 insertions(+), 157 deletions(-) create mode 100644 pantograph/gen_tactic.py diff --git a/pantograph/gen_tactic.py b/pantograph/gen_tactic.py new file mode 100644 index 0000000..68c7ccb --- /dev/null +++ b/pantograph/gen_tactic.py @@ -0,0 +1,140 @@ +from pantograph.server import Server +from pantograph.expr import Variable, Goal, TacticCalc +import unittest +import sglang as sgl + + + + +@sgl.function +def multi_turn_question(s, question_1, question_2): + s += sgl.system("You are a helpful assistant.") + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) + + +@sgl.function +def select_tactic(s, state): + s += sgl.system("You are an expert in Lean. Choose the next one tactic to run given the current proof state and goals.") + s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])") + s += sgl.assistant("```intros a b h```") + s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])") + s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")') + s += sgl.user("The current proof state: " + str(state)) + with s.copy() as tmp: + tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) + print("==tmp===") + print(tmp["tactic"]) + tactic = extract_code_from_llm_output(tmp["tactic"]) + s += sgl.assistant("```"+tactic+"```") + return tactic + + + +def extract_code_from_llm_output(reply): + i = reply.find("```lean") + if i != -1: + reply = reply[i + 7:] + i = reply.find("```") + reply = reply[:i] + return reply + i = reply.find("```") + if i != -1: + reply = reply[i + 3:] + i = reply.find("```") + reply = reply[:i] + return reply + return reply + +class TestServerSGL(unittest.TestCase): + + def test_conv_calc_sgl(self): + sgl.set_default_backend(sgl.OpenAI("gpt-4")) + + server = Server() + state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b") + print("==========state0============") + print(state0) + variables = [ + Variable(name="a", t="Nat"), + Variable(name="b", t="Nat"), + Variable(name="h", t="b = 2"), + ] + + state1 = server.goal_tactic(state0, goal_id=0, tactic="intro a b h") + print("==========state1============") + print(state1) + state2 = server.goal_tactic(state1, goal_id=0, tactic=TacticCalc("1 + a + 1 = a + 1 + 1")) + print("==========state2============") + print(state2) + self.assertEqual(state2.goals, [ + Goal( + variables, + target="1 + a + 1 = a + 1 + 1", + name='calc', + ), + Goal( + variables, + target="a + 1 + 1 = a + b", + ), + ]) + state = select_tactic.run(str(state2)) + tactic = state.ret_value + for m in state.messages(): + print(m["role"], ":", m["content"]) + + print("\n-- tactic --\n", tactic) + + state3 = server.goal_tactic(state2, goal_id=1, tactic=tactic) + print("==========state3============") + print(state3) + # state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") + # print("==========state4============") + # print(state4) + # self.assertTrue(state4.is_solved) + + + # print("==========state2============") + # print(state2) + # state_c1 = server.goal_conv_begin(state2, goal_id=0) + # print("==========state c1============") + # print(state_c1) + # state_c2 = server.goal_tactic(state_c1, goal_id=0, tactic="rhs") + # print("==========state c2============") + # print(state_c2) + # state_c3 = server.goal_tactic(state_c2, goal_id=0, tactic="rw [Nat.add_comm]") + # print("==========state c3============") + # print(state_c3) + # state_c4 = server.goal_conv_end(state_c3) + # print("==========state c4============") + # print(state_c4) + + # state_c5 = server.goal_tactic(state_c4, goal_id=0, tactic="rfl") + # print("==========state c5============") + # print(state_c5) + # self.assertTrue(state_c5.is_solved) + + # print() + + + def test_sglang_openai(self): + sgl.set_default_backend(sgl.OpenAI("gpt-4")) + + print('\n----- Test sglang ---') + state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", + ) + + for m in state.messages(): + print(m["role"], ":", m["content"]) + + print("\n-- answer_1 --\n", state["answer_1"]) + + +if __name__ == '__main__': + + unittest.main() + diff --git a/pantograph/server.py b/pantograph/server.py index 633913e..f652fbb 100644 --- a/pantograph/server.py +++ b/pantograph/server.py @@ -6,42 +6,6 @@ import json, pexpect, pathlib, unittest from pantograph.expr import Variable, Goal, GoalState, \ Tactic, TacticHave, TacticCalc - -import argparse -from typing import Dict, List -import os -import sglang as sgl - - - - -@sgl.function -def multi_turn_question(s, question_1, question_2): - s += sgl.system("You are a helpful assistant.") - s += sgl.user(question_1) - s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) - s += sgl.user(question_2) - s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) - - -@sgl.function -def select_tactic(s, state): - s += sgl.system("You are an expert in Lean. Choose the next one tactic to run given the current proof state and goals.") - s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])") - s += sgl.assistant("```intros a b h```") - s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])") - s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")') - s += sgl.user("The current proof state: " + str(state)) - with s.copy() as tmp: - tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) - print("==tmp===") - print(tmp["tactic"]) - tactic = extract_code_from_llm_output(tmp["tactic"]) - s += sgl.assistant("```"+tactic+"```") - return tactic - - - def _get_proc_cwd(): return pathlib.Path(__file__).parent def _get_proc_path(): @@ -148,21 +112,7 @@ def get_version(): stdout=subprocess.PIPE, cwd=_get_proc_cwd()) as p: return p.communicate()[0].decode('utf-8').strip() - -def extract_code_from_llm_output(reply): - i = reply.find("```lean") - if i != -1: - reply = reply[i + 7:] - i = reply.find("```") - reply = reply[:i] - return reply - i = reply.find("```") - if i != -1: - reply = reply[i + 3:] - i = reply.find("```") - reply = reply[:i] - return reply - return reply + class TestServer(unittest.TestCase): @@ -182,76 +132,6 @@ class TestServer(unittest.TestCase): )]) self.assertEqual(str(state1.goals[0]),"a : Prop\n⊢ ∀ (q : Prop), a ∨ q → q ∨ a") - def test_conv_calc_sgl(self): - sgl.set_default_backend(sgl.OpenAI("gpt-4")) - - server = Server() - state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b") - print("==========state0============") - print(state0) - variables = [ - Variable(name="a", t="Nat"), - Variable(name="b", t="Nat"), - Variable(name="h", t="b = 2"), - ] - - state1 = server.goal_tactic(state0, goal_id=0, tactic="intro a b h") - print("==========state1============") - print(state1) - state2 = server.goal_tactic(state1, goal_id=0, tactic=TacticCalc("1 + a + 1 = a + 1 + 1")) - print("==========state2============") - print(state2) - self.assertEqual(state2.goals, [ - Goal( - variables, - target="1 + a + 1 = a + 1 + 1", - name='calc', - ), - Goal( - variables, - target="a + 1 + 1 = a + b", - ), - ]) - state = select_tactic.run(str(state2)) - tactic = state.ret_value - for m in state.messages(): - print(m["role"], ":", m["content"]) - - print("\n-- tactic --\n", state.stream_executor.variables) - - state3 = server.goal_tactic(state2, goal_id=1, tactic=tactic) - print("==========state3============") - print(state3) - # state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") - # print("==========state4============") - # print(state4) - # self.assertTrue(state4.is_solved) - - - # print("==========state2============") - # print(state2) - # state_c1 = server.goal_conv_begin(state2, goal_id=0) - # print("==========state c1============") - # print(state_c1) - # state_c2 = server.goal_tactic(state_c1, goal_id=0, tactic="rhs") - # print("==========state c2============") - # print(state_c2) - # state_c3 = server.goal_tactic(state_c2, goal_id=0, tactic="rw [Nat.add_comm]") - # print("==========state c3============") - # print(state_c3) - # state_c4 = server.goal_conv_end(state_c3) - # print("==========state c4============") - # print(state_c4) - - # state_c5 = server.goal_tactic(state_c4, goal_id=0, tactic="rfl") - # print("==========state c5============") - # print(state_c5) - # self.assertTrue(state_c5.is_solved) - - # print() - - - def test_conv_calc(self): server = Server() state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b") @@ -274,52 +154,17 @@ class TestServer(unittest.TestCase): target="a + 1 + 1 = a + b", ), ]) - print("==========state2============") - print(state2) state_c1 = server.goal_conv_begin(state2, goal_id=0) - print("==========state c1============") - print(state_c1) state_c2 = server.goal_tactic(state_c1, goal_id=0, tactic="rhs") - print("==========state c2============") - print(state_c2) state_c3 = server.goal_tactic(state_c2, goal_id=0, tactic="rw [Nat.add_comm]") - print("==========state c3============") - print(state_c3) state_c4 = server.goal_conv_end(state_c3) - print("==========state c4============") - print(state_c4) - state_c5 = server.goal_tactic(state_c4, goal_id=0, tactic="rfl") - print("==========state c5============") - print(state_c5) self.assertTrue(state_c5.is_solved) - - print() state3 = server.goal_tactic(state2, goal_id=1, tactic=TacticCalc("_ = a + 2")) - print("==========state3============") - print(state3) state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") - print("==========state4============") - print(state4) self.assertTrue(state4.is_solved) - def test_sglang_openai(self): - sgl.set_default_backend(sgl.OpenAI("gpt-4")) - print('\n----- Test sglang ---') - state = multi_turn_question.run( - question_1="What is the capital of the United States?", - question_2="List two local attractions.", - ) - - for m in state.messages(): - print(m["role"], ":", m["content"]) - - print("\n-- answer_1 --\n", state["answer_1"]) - - if __name__ == '__main__': - - unittest.main() - + unittest.main() \ No newline at end of file From 8c2f27681ca696125a78aa5954a6cfc2807a472e Mon Sep 17 00:00:00 2001 From: ChuyueSun Date: Mon, 20 May 2024 18:06:59 -0700 Subject: [PATCH 7/7] remove unused --- test_sglang.py | 40 ---------------------------------------- 1 file changed, 40 deletions(-) delete mode 100644 test_sglang.py diff --git a/test_sglang.py b/test_sglang.py deleted file mode 100644 index 6b68cc8..0000000 --- a/test_sglang.py +++ /dev/null @@ -1,40 +0,0 @@ - - -import argparse -from typing import Dict, List -import os -import sglang as sgl -from sglang import OpenAI, assistant, gen, set_default_backend, system, user - - - -@sgl.function -def multi_turn_question(s, question_1, question_2): - s += sgl.system("You are a helpful assistant.") - s += sgl.user(question_1) - s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) - s += sgl.user(question_2) - s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) - - - -def test_sglang(): - print('\n----- Test sglang ---') - state = multi_turn_question.run( - question_1="What is the capital of the United States?", - question_2="List two local attractions.", - ) - - for m in state.messages(): - print(m["role"], ":", m["content"]) - - print("\n-- answer_1 --\n", state["answer_1"]) - - -if __name__ == "__main__": - import time - start_time = time.time() - sgl.set_default_backend(sgl.OpenAI("gpt-4")) - - test_sglang() - print(f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a") \ No newline at end of file