Pantograph/pantograph/gen_tactic.py

237 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pantograph.server import Server, ServerError, TacticFailure
from pantograph.expr import Variable, Goal, TacticCalc
import unittest
import sglang as sgl
LEAN4_INTRO = '''/-- A sequence `u` of real numbers converges to `l` if `∀ ε > 0, ∃ N, ∀ n ≥ N, |u_n - l| ≤ ε`.
This condition will be spelled `seq_limit u l`. -/
def seq_limit (u : ) (l : ) : Prop :=
∀ ε > 0, ∃ N, ∀ n ≥ N, |u n - l| ≤ ε
/- In the above definition, note that the `n`-th term of the sequence `u` is denoted
simply by `u n`.
Similarly, in the next definition, `f x` is what we would write `f(x)` on paper.
Also note that implication is denoted by a single arrow (we'll explain why later). -/
/-- A function`f : ` is continuous at `x₀` if
`∀ ε > 0, ∃ δ > 0, ∀ x, |x - x₀| ≤ δ ⇒ |f(x) - f(x₀)| ≤ ε`.
This condition will be spelled `continuous_at f x₀`.-/
def continuous_at (f : ) (x₀ : ) : Prop :=
∀ ε > 0, ∃ δ > 0, ∀ x, |x - x₀| ≤ δ → |f x - f x₀| ≤ ε
/-- Now we claim that if `f` is continuous at `x₀` then it is sequentially continuous
at `x₀`: for any sequence `u` converging to `x₀`, the sequence `f ∘ u` converges
to `f x₀`. -/
example (f : ) (u : ) (x₀ : ) (hu : seq_limit u x₀) (hf : continuous_at f x₀) :
seq_limit (f ∘ u) (f x₀) := by { -- This `by` keyword marks the beginning of the proof
-- Put your text cursor here and watch the Lean InfoView panel to the right.
-- Then move your cursor from line to line in the proof while monitoring the Infoview.
-- Our goal is to prove that, for any positive `ε`, there exists a natural
-- number `N` such that, for any natural number `n` at least `N`,
-- `|f(u_n) - f(x₀)|` is at most `ε`.
unfold seq_limit
-- Fix a positive number `ε`.
intros ε hε
-- By assumption on `f` applied to this positive `ε`, we get a positive `δ`
-- such that, for all real number `x`, if `|x - x₀| ≤ δ` then `|f(x) - f(x₀)| ≤ ε` (1).
obtain ⟨δ, δ_pos, Hf⟩ : ∃ δ > 0, ∀ x, |x - x₀| ≤ δ → |f x - f x₀| ≤ ε := hf ε hε
-- The assumption on `u` applied to this `δ` gives a natural number `N` such that
-- for every natural number `n`, if `n ≥ N` then `|u_n - x₀| ≤ δ` (2).
obtain ⟨N, Hu⟩ : ∃ N, ∀ n ≥ N, |u n - x₀| ≤ δ := hu δ δ_pos
-- Let's prove `N` is suitable.
use N
-- Fix `n` which is at least `N`. Let's prove `|f(u_n) - f(x₀)| ≤ ε`.
intros n hn
-- Thanks to (1) applied to `u_n`, it suffices to prove that `|u_n - x₀| ≤ δ`.
apply Hf
-- This follows from property (2) and our assumption on `n`.
exact Hu n hn
-- This finishes the proof!
}
/-
Now that this proof is over, you can use the file explorer to the
left of this panel to open the file `Exercises > 01Rewriting.lean`.
-/'''
LEAN4_REWRITE = '''Rewrite tactic tutorial:
example (a b c : Nat) : a + b + c = a + c + b := by
rw [Nat.add_assoc, Nat.add_comm b, ← Nat.add_assoc]
example (a b c : Nat) : a + b + c = a + c + b := by
rw [Nat.add_assoc, Nat.add_assoc, Nat.add_comm b]
example (a b c : Nat) : a + b + c = a + c + b := by
rw [Nat.add_assoc, Nat.add_assoc, Nat.add_comm _ b]
example (f : Nat → Nat) (a : Nat) (h : a + 0 = 0) : f a = f 0 := by
rw [Nat.add_zero] at h
rw [h]
def Tuple (α : Type) (n : Nat) :=
{ as : List α // as.length = n }
example (n : Nat) (h : n = 0) (t : Tuple α n) : Tuple α 0 := by
rw [h] at t
exact t
'''
@sgl.function
def multi_turn_question(s, question_1, question_2):
s += sgl.system("You are a helpful assistant.")
s += sgl.user(question_1)
s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
s += sgl.user(question_2)
s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
@sgl.function
def select_tactic(s, server, state, goal_id, feedback_turns = 5):
s += sgl.system("You are an expert in Lean. Choose the next one tactic to run given the current proof state and goals.")
s += sgl.user(LEAN4_REWRITE)
s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant("```intros a b h```")
s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
s += sgl.user("The current proof state: " + str(state))
for i in range(feedback_turns):
with s.copy() as tmp:
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
print("==tmp===")
print(tmp["tactic"])
tactic = extract_code_from_llm_output(tmp["tactic"])
s += sgl.assistant("```"+tactic+"```")
success, new_state = apply_tactic(server, state, goal_id, tactic)
print("===execute===")
print(success, new_state )
if not success:
with s.user():
s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
s += "Please try again by taking the Lean compiler feedback."
else:
return tactic, new_state
return None, None
def apply_tactic(server, state, goal_id, tactic):
try:
new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic)
except ServerError as e:
return False, e
except TacticFailure as e:
return False, e
return True, new_state
def extract_code_from_llm_output(reply):
i = reply.find("```lean")
if i != -1:
reply = reply[i + 7:]
i = reply.find("```")
reply = reply[:i]
return reply
i = reply.find("```")
if i != -1:
reply = reply[i + 3:]
i = reply.find("```")
reply = reply[:i]
return reply
return reply
class TestServerSGL(unittest.TestCase):
def test_conv_calc_sgl(self):
n_trails = 5
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
server = Server()
state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b")
print("==========state0============")
print(state0)
variables = [
Variable(name="a", t="Nat"),
Variable(name="b", t="Nat"),
Variable(name="h", t="b = 2"),
]
state1 = server.goal_tactic(state0, goal_id=0, tactic="intro a b h")
print("==========state1============")
print(state1)
state2 = server.goal_tactic(state1, goal_id=0, tactic=TacticCalc("1 + a + 1 = a + 1 + 1"))
print("==========state2============")
print(state2)
self.assertEqual(state2.goals, [
Goal(
variables,
target="1 + a + 1 = a + 1 + 1",
name='calc',
),
Goal(
variables,
target="a + 1 + 1 = a + b",
),
])
state3 = None
for i in range(n_trails):
print(f"===============trail {str(i)}============")
try:
state = select_tactic.run(server, state2, goal_id = 1)
state3 = state.ret_value
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- new state --\n", state3)
break
except ServerError as e:
print(f"server error: {e}")
continue
state3 = server.goal_tactic(state2, goal_id=1, tactic=TacticCalc("_ = a + 2"))
print("==========state3============")
print(state3)
state4 = None
for i in range(n_trails):
print(f"===============trail {str(i)}============")
try:
state = select_tactic.run(server, state3, goal_id = 0)
state4 = state.ret_value
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- new state --\n", state4)
break
except ServerError as e:
print(f"server error: {e}")
continue
state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]")
print("==========state4============")
print(state4)
self.assertTrue(state4.is_solved)
def test_sglang_openai(self):
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
print('\n----- Test sglang ---')
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
)
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- answer_1 --\n", state["answer_1"])
if __name__ == '__main__':
unittest.main()