Merge remote-tracking branch 'origin/main' into brando

This commit is contained in:
Brando Miranda 2024-06-03 11:11:06 -07:00
commit c73e56630e
1 changed files with 133 additions and 42 deletions

View File

@ -1,10 +1,82 @@
from pantograph.server import Server from pantograph.server import Server, ServerError
from pantograph.expr import Variable, Goal, TacticCalc from pantograph.expr import Variable, Goal, TacticCalc
import unittest import unittest
import sglang as sgl import sglang as sgl
LEAN4_INTRO = '''/-- A sequence `u` of real numbers converges to `l` if `∀ ε > 0, ∃ N, ∀ n ≥ N, |u_n - l| ≤ ε`.
This condition will be spelled `seq_limit u l`. -/
def seq_limit (u : ) (l : ) : Prop :=
ε > 0, N, n N, |u n - l| ε
/- In the above definition, note that the `n`-th term of the sequence `u` is denoted
simply by `u n`.
Similarly, in the next definition, `f x` is what we would write `f(x)` on paper.
Also note that implication is denoted by a single arrow (we'll explain why later). -/
/-- A function`f : ` is continuous at `x₀` if
` ε > 0, δ > 0, x, |x - x₀| δ |f(x) - f(x₀)| ε`.
This condition will be spelled `continuous_at f x₀`.-/
def continuous_at (f : ) (x₀ : ) : Prop :=
ε > 0, δ > 0, x, |x - x₀| δ |f x - f x₀| ε
/-- Now we claim that if `f` is continuous at `x₀` then it is sequentially continuous
at `x₀`: for any sequence `u` converging to `x₀`, the sequence `f u` converges
to `f x₀`. -/
example (f : ) (u : ) (x₀ : ) (hu : seq_limit u x₀) (hf : continuous_at f x₀) :
seq_limit (f u) (f x₀) := by { -- This `by` keyword marks the beginning of the proof
-- Put your text cursor here and watch the Lean InfoView panel to the right.
-- Then move your cursor from line to line in the proof while monitoring the Infoview.
-- Our goal is to prove that, for any positive `ε`, there exists a natural
-- number `N` such that, for any natural number `n` at least `N`,
-- `|f(u_n) - f(x₀)|` is at most `ε`.
unfold seq_limit
-- Fix a positive number `ε`.
intros ε
-- By assumption on `f` applied to this positive `ε`, we get a positive `δ`
-- such that, for all real number `x`, if `|x - x₀| δ` then `|f(x) - f(x₀)| ε` (1).
obtain δ, δ_pos, Hf : δ > 0, x, |x - x₀| δ |f x - f x₀| ε := hf ε
-- The assumption on `u` applied to this `δ` gives a natural number `N` such that
-- for every natural number `n`, if `n N` then `|u_n - x₀| δ` (2).
obtain N, Hu : N, n N, |u n - x₀| δ := hu δ δ_pos
-- Let's prove `N` is suitable.
use N
-- Fix `n` which is at least `N`. Let's prove `|f(u_n) - f(x₀)| ≤ ε`.
intros n hn
-- Thanks to (1) applied to `u_n`, it suffices to prove that `|u_n - x₀| δ`.
apply Hf
-- This follows from property (2) and our assumption on `n`.
exact Hu n hn
-- This finishes the proof!
}
/-
Now that this proof is over, you can use the file explorer to the
left of this panel to open the file `Exercises > 01Rewriting.lean`.
-/'''
LEAN4_REWRITE = '''Rewrite tactic tutorial:
example (a b c : Nat) : a + b + c = a + c + b := by
rw [Nat.add_assoc, Nat.add_comm b, Nat.add_assoc]
example (a b c : Nat) : a + b + c = a + c + b := by
rw [Nat.add_assoc, Nat.add_assoc, Nat.add_comm b]
example (a b c : Nat) : a + b + c = a + c + b := by
rw [Nat.add_assoc, Nat.add_assoc, Nat.add_comm _ b]
example (f : Nat Nat) (a : Nat) (h : a + 0 = 0) : f a = f 0 := by
rw [Nat.add_zero] at h
rw [h]
def Tuple (α : Type) (n : Nat) :=
{ as : List α // as.length = n }
example (n : Nat) (h : n = 0) (t : Tuple α n) : Tuple α 0 := by
rw [h] at t
exact t
'''
@sgl.function @sgl.function
def multi_turn_question(s, question_1, question_2): def multi_turn_question(s, question_1, question_2):
@ -16,22 +88,37 @@ def multi_turn_question(s, question_1, question_2):
@sgl.function @sgl.function
def select_tactic(s, state): def select_tactic(s, server, state, goal_id, feedback_turns = 5):
s += sgl.system("You are an expert in Lean. Choose the next one tactic to run given the current proof state and goals.") s += sgl.system("You are an expert in Lean. Choose the next one tactic to run given the current proof state and goals.")
s += sgl.user(LEAN4_REWRITE)
s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])") s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant("```intros a b h```") s += sgl.assistant("```intros a b h```")
s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])") s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")') s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
s += sgl.user("The current proof state: " + str(state)) s += sgl.user("The current proof state: " + str(state))
with s.copy() as tmp: for i in range(feedback_turns):
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) with s.copy() as tmp:
print("==tmp===") tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
print(tmp["tactic"]) print("==tmp===")
tactic = extract_code_from_llm_output(tmp["tactic"]) print(tmp["tactic"])
s += sgl.assistant("```"+tactic+"```") tactic = extract_code_from_llm_output(tmp["tactic"])
return tactic s += sgl.assistant("```"+tactic+"```")
success, new_state = apply_tactic(server, state, goal_id, tactic)
if not success:
with s.user():
s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
s += "Please try again by taking the Lean compiler feedback."
else:
return new_state
def apply_tactic(server, state, goal_id, tactic):
try:
new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic)
except ServerError as e:
return False, e
return True, new_state
def extract_code_from_llm_output(reply): def extract_code_from_llm_output(reply):
i = reply.find("```lean") i = reply.find("```lean")
@ -51,6 +138,7 @@ def extract_code_from_llm_output(reply):
class TestServerSGL(unittest.TestCase): class TestServerSGL(unittest.TestCase):
def test_conv_calc_sgl(self): def test_conv_calc_sgl(self):
n_trails = 5
sgl.set_default_backend(sgl.OpenAI("gpt-4")) sgl.set_default_backend(sgl.OpenAI("gpt-4"))
server = Server() server = Server()
@ -80,43 +168,46 @@ class TestServerSGL(unittest.TestCase):
target="a + 1 + 1 = a + b", target="a + 1 + 1 = a + b",
), ),
]) ])
state = select_tactic.run(str(state2)) state3 = None
tactic = state.ret_value for i in range(n_trails):
for m in state.messages(): print(f"===============trail {str(i)}============")
print(m["role"], ":", m["content"]) try:
state = select_tactic.run(server, state2, goal_id = 1)
state3 = state.ret_value
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- new state --\n", state3)
break
except ServerError as e:
print(f"server error: {e}")
continue
state3 = server.goal_tactic(state2, goal_id=1, tactic=TacticCalc("_ = a + 2"))
print("\n-- tactic --\n", tactic)
state3 = server.goal_tactic(state2, goal_id=1, tactic=tactic)
print("==========state3============") print("==========state3============")
print(state3) print(state3)
# state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") state4 = None
# print("==========state4============") for i in range(n_trails):
# print(state4) print(f"===============trail {str(i)}============")
# self.assertTrue(state4.is_solved) try:
state = select_tactic.run(server, state3, goal_id = 0)
state4 = state.ret_value
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- new state --\n", state4)
break
except ServerError as e:
print(f"server error: {e}")
continue
# print("==========state2============") state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]")
# print(state2) print("==========state4============")
# state_c1 = server.goal_conv_begin(state2, goal_id=0) print(state4)
# print("==========state c1============") self.assertTrue(state4.is_solved)
# print(state_c1)
# state_c2 = server.goal_tactic(state_c1, goal_id=0, tactic="rhs")
# print("==========state c2============")
# print(state_c2)
# state_c3 = server.goal_tactic(state_c2, goal_id=0, tactic="rw [Nat.add_comm]")
# print("==========state c3============")
# print(state_c3)
# state_c4 = server.goal_conv_end(state_c3)
# print("==========state c4============")
# print(state_c4)
# state_c5 = server.goal_tactic(state_c4, goal_id=0, tactic="rfl")
# print("==========state c5============")
# print(state_c5)
# self.assertTrue(state_c5.is_solved)
# print()
def test_sglang_openai(self): def test_sglang_openai(self):