From f9fe626aa8a24498f809b553927701fd18663f56 Mon Sep 17 00:00:00 2001 From: Chuyue Sun Date: Sun, 2 Jun 2024 19:51:20 -0700 Subject: [PATCH] update llm gen tactics tests --- pantograph/gen_tactic.py | 85 +++++++++++++++++++++------------------- 1 file changed, 45 insertions(+), 40 deletions(-) diff --git a/pantograph/gen_tactic.py b/pantograph/gen_tactic.py index 6db1fb8..24442ac 100644 --- a/pantograph/gen_tactic.py +++ b/pantograph/gen_tactic.py @@ -56,7 +56,7 @@ Now that this proof is over, you can use the file explorer to the left of this panel to open the file `Exercises > 01Rewriting.lean`. -/''' -LEAN4_REWRITE = ''' +LEAN4_REWRITE = '''Rewrite tactic tutorial: example (a b c : Nat) : a + b + c = a + c + b := by rw [Nat.add_assoc, Nat.add_comm b, ← Nat.add_assoc] @@ -65,6 +65,17 @@ example (a b c : Nat) : a + b + c = a + c + b := by example (a b c : Nat) : a + b + c = a + c + b := by rw [Nat.add_assoc, Nat.add_assoc, Nat.add_comm _ b] + +example (f : Nat → Nat) (a : Nat) (h : a + 0 = 0) : f a = f 0 := by + rw [Nat.add_zero] at h + rw [h] + +def Tuple (α : Type) (n : Nat) := + { as : List α // as.length = n } + +example (n : Nat) (h : n = 0) (t : Tuple α n) : Tuple α 0 := by + rw [h] at t + exact t ''' @sgl.function @@ -161,63 +172,57 @@ class TestServerSGL(unittest.TestCase): for i in range(n_trails): print(f"===============trail {str(i)}============") try: - state = select_tactic.run(server, state2, goal_id = 0) + state = select_tactic.run(server, state2, goal_id = 1) state3 = state.ret_value for m in state.messages(): print(m["role"], ":", m["content"]) print("\n-- new state --\n", state3) + break except ServerError as e: print(f"server error: {e}") continue - state3 = server.goal_tactic(state2, goal_id=0, tactic="rw [Nat.add_assoc]") + state3 = server.goal_tactic(state2, goal_id=1, tactic=TacticCalc("_ = a + 2")) print("==========state3============") print(state3) - # state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") - # print("==========state4============") - # print(state4) - # self.assertTrue(state4.is_solved) + state4 = None + for i in range(n_trails): + print(f"===============trail {str(i)}============") + try: + state = select_tactic.run(server, state3, goal_id = 0) + state4 = state.ret_value + for m in state.messages(): + print(m["role"], ":", m["content"]) + + print("\n-- new state --\n", state4) + break + + except ServerError as e: + print(f"server error: {e}") + continue + + state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") + print("==========state4============") + print(state4) + self.assertTrue(state4.is_solved) - # print("==========state2============") - # print(state2) - # state_c1 = server.goal_conv_begin(state2, goal_id=0) - # print("==========state c1============") - # print(state_c1) - # state_c2 = server.goal_tactic(state_c1, goal_id=0, tactic="rhs") - # print("==========state c2============") - # print(state_c2) - # state_c3 = server.goal_tactic(state_c2, goal_id=0, tactic="rw [Nat.add_comm]") - # print("==========state c3============") - # print(state_c3) - # state_c4 = server.goal_conv_end(state_c3) - # print("==========state c4============") - # print(state_c4) + def test_sglang_openai(self): + sgl.set_default_backend(sgl.OpenAI("gpt-4")) - # state_c5 = server.goal_tactic(state_c4, goal_id=0, tactic="rfl") - # print("==========state c5============") - # print(state_c5) - # self.assertTrue(state_c5.is_solved) + print('\n----- Test sglang ---') + state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", + ) - # print() + for m in state.messages(): + print(m["role"], ":", m["content"]) - - # def test_sglang_openai(self): - # sgl.set_default_backend(sgl.OpenAI("gpt-4")) - - # print('\n----- Test sglang ---') - # state = multi_turn_question.run( - # question_1="What is the capital of the United States?", - # question_2="List two local attractions.", - # ) - - # for m in state.messages(): - # print(m["role"], ":", m["content"]) - - # print("\n-- answer_1 --\n", state["answer_1"]) + print("\n-- answer_1 --\n", state["answer_1"]) if __name__ == '__main__':