feat: Conv tactic functions

This commit is contained in:
Leni Aniva 2024-04-08 12:26:22 -07:00
parent 7af24a4f0a
commit 63e64a1e9f
4 changed files with 171 additions and 50 deletions

View File

@ -140,6 +140,8 @@ def execute (command: Protocol.Command): MainM Lean.Json := do
return .ok { parseError? := .some message } return .ok { parseError? := .some message }
| .ok (.indexError goalId) => | .ok (.indexError goalId) =>
return .error $ errorIndex s!"Invalid goal id index {goalId}" return .error $ errorIndex s!"Invalid goal id index {goalId}"
| .ok (.invalidAction message) =>
return .error $ errorI "invalid" message
| .ok (.failure messages) => | .ok (.failure messages) =>
return .ok { tacticErrors? := .some messages } return .ok { tacticErrors? := .some messages }
goal_continue (args: Protocol.GoalContinue): MainM (CR Protocol.GoalContinueResult) := do goal_continue (args: Protocol.GoalContinue): MainM (CR Protocol.GoalContinueResult) := do

View File

@ -31,6 +31,9 @@ structure GoalState where
-- Parent state metavariable source -- Parent state metavariable source
parentMVar: Option MVarId parentMVar: Option MVarId
-- Existence of this field shows that we are currently in `conv` mode.
convMVar: Option (MVarId × MVarId × List MVarId) := .none
protected def GoalState.create (expr: Expr): Elab.TermElabM GoalState := do protected def GoalState.create (expr: Expr): Elab.TermElabM GoalState := do
-- May be necessary to immediately synthesise all metavariables if we need to leave the elaboration context. -- May be necessary to immediately synthesise all metavariables if we need to leave the elaboration context.
-- See https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Unknown.20universe.20metavariable/near/360130070 -- See https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Unknown.20universe.20metavariable/near/360130070
@ -100,6 +103,8 @@ inductive TacticResult where
| parseError (message: String) | parseError (message: String)
-- The goal index is out of bounds -- The goal index is out of bounds
| indexError (goalId: Nat) | indexError (goalId: Nat)
-- The given action cannot be executed in the state
| invalidAction (message: String)
/-- Execute tactic on given state -/ /-- Execute tactic on given state -/
protected def GoalState.tryTactic (state: GoalState) (goalId: Nat) (tactic: String): protected def GoalState.tryTactic (state: GoalState) (goalId: Nat) (tactic: String):
@ -122,11 +127,11 @@ protected def GoalState.tryTactic (state: GoalState) (goalId: Nat) (tactic: Stri
| .ok nextSavedState => | .ok nextSavedState =>
-- Assert that the definition of metavariables are the same -- Assert that the definition of metavariables are the same
let nextMCtx := nextSavedState.term.meta.meta.mctx let nextMCtx := nextSavedState.term.meta.meta.mctx
let prevMCtx := state.savedState.term.meta.meta.mctx let prevMCtx := state.mctx
-- Generate a list of mvarIds that exist in the parent state; Also test the -- Generate a list of mvarIds that exist in the parent state; Also test the
-- assertion that the types have not changed on any mvars. -- assertion that the types have not changed on any mvars.
return .success { return .success {
root := state.root, state with
savedState := nextSavedState savedState := nextSavedState
newMVars := newMVarSet prevMCtx nextMCtx, newMVars := newMVarSet prevMCtx nextMCtx,
parentMVar := .some goal, parentMVar := .some goal,
@ -146,7 +151,7 @@ protected def GoalState.assign (state: GoalState) (goal: MVarId) (expr: Expr):
return .some s!"{← Meta.ppExpr expr} : {← Meta.ppExpr exprType} != {← Meta.ppExpr goalType}" return .some s!"{← Meta.ppExpr expr} : {← Meta.ppExpr exprType} != {← Meta.ppExpr goalType}"
) )
if let .some error := error? then if let .some error := error? then
return .failure #["Type unification failed", error] return .parseError error
goal.checkNotAssigned `GoalState.assign goal.checkNotAssigned `GoalState.assign
goal.assign expr goal.assign expr
if (← getThe Core.State).messages.hasErrors then if (← getThe Core.State).messages.hasErrors then
@ -246,35 +251,45 @@ protected def GoalState.tryHave (state: GoalState) (goalId: Nat) (binderName: St
return .failure #[← exception.toMessageData.toString] return .failure #[← exception.toMessageData.toString]
/-- Enter conv tactic mode -/ /-- Enter conv tactic mode -/
protected def GoalState.tryConv (state: GoalState) (goalId: Nat): protected def GoalState.conv (state: GoalState) (goalId: Nat):
Elab.TermElabM TacticResult := do Elab.TermElabM TacticResult := do
if state.convMVar.isSome then
return .invalidAction "Already in conv state"
let goal ← match state.savedState.tactic.goals.get? goalId with let goal ← match state.savedState.tactic.goals.get? goalId with
| .some goal => pure goal | .some goal => pure goal
| .none => return .indexError goalId | .none => return .indexError goalId
let tacticM : Elab.Tactic.TacticM Elab.Tactic.SavedState:= do let tacticM : Elab.Tactic.TacticM (Elab.Tactic.SavedState × MVarId) := do
state.restoreTacticM goal state.restoreTacticM goal
-- TODO: Fail if this is already in conv -- TODO: Fail if this is already in conv
-- See Lean.Elab.Tactic.Conv.convTarget -- See Lean.Elab.Tactic.Conv.convTarget
Elab.Tactic.withMainContext do let convMVar ← Elab.Tactic.withMainContext do
-- TODO: Memorize this `rhs` as a conv resultant goal -- TODO: Memorize this `rhs` as a conv resultant goal
let (rhs, newGoal) ← Elab.Tactic.Conv.mkConvGoalFor (← Elab.Tactic.getMainTarget) let (rhs, newGoal) ← Elab.Tactic.Conv.mkConvGoalFor (← Elab.Tactic.getMainTarget)
Elab.Tactic.setGoals [newGoal.mvarId!] Elab.Tactic.setGoals [newGoal.mvarId!]
--Elab.Tactic.liftMetaTactic1 fun mvarId => mvarId.replaceTargetEq rhs proof pure rhs.mvarId!
MonadBacktrack.saveState return (← MonadBacktrack.saveState, convMVar)
let nextSavedState ← tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic try
let prevMCtx := state.savedState.term.meta.meta.mctx let (nextSavedState, convRhs) ← tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic
let nextMCtx := nextSavedState.term.meta.meta.mctx let prevMCtx := state.mctx
return .success { let nextMCtx := nextSavedState.term.meta.meta.mctx
root := state.root, return .success {
savedState := nextSavedState root := state.root,
newMVars := newMVarSet prevMCtx nextMCtx, savedState := nextSavedState
parentMVar := .some goal, newMVars := newMVarSet prevMCtx nextMCtx,
} parentMVar := .some goal,
convMVar := .some (convRhs, goal, state.goals),
}
catch exception =>
return .failure #[← exception.toMessageData.toString]
/-- Execute a tactic in conv mode -/
protected def GoalState.tryConvTactic (state: GoalState) (goalId: Nat) (convTactic: String): protected def GoalState.tryConvTactic (state: GoalState) (goalId: Nat) (convTactic: String):
Elab.TermElabM TacticResult := do Elab.TermElabM TacticResult := do
let _ ← match state.convMVar with
| .some mvar => pure mvar
| .none => return .invalidAction "Not in conv state"
let goal ← match state.savedState.tactic.goals.get? goalId with let goal ← match state.savedState.tactic.goals.get? goalId with
| .some goal => pure goal | .some goal => pure goal
| .none => return .indexError goalId | .none => return .indexError goalId
@ -289,15 +304,60 @@ protected def GoalState.tryConvTactic (state: GoalState) (goalId: Nat) (convTact
state.restoreTacticM goal state.restoreTacticM goal
Elab.Tactic.evalTactic convTactic Elab.Tactic.evalTactic convTactic
MonadBacktrack.saveState MonadBacktrack.saveState
let nextSavedState ← tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic try
let nextMCtx := nextSavedState.term.meta.meta.mctx let prevMCtx := state.mctx
let prevMCtx := state.savedState.term.meta.meta.mctx let nextSavedState ← tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic
return .success { let nextMCtx := nextSavedState.term.meta.meta.mctx
root := state.root, return .success {
savedState := nextSavedState state with
newMVars := newMVarSet prevMCtx nextMCtx, savedState := nextSavedState
parentMVar := .some goal, newMVars := newMVarSet prevMCtx nextMCtx,
} parentMVar := .some goal,
}
catch exception =>
return .failure #[← exception.toMessageData.toString]
protected def GoalState.convExit (state: GoalState):
Elab.TermElabM TacticResult := do
let (convRhs, convGoal, savedGoals) ← match state.convMVar with
| .some mvar => pure mvar
| .none => return .invalidAction "Not in conv state"
let tacticM : Elab.Tactic.TacticM Elab.Tactic.SavedState:= do
-- Vide `Lean.Elab.Tactic.Conv.convert`
state.savedState.restore
IO.println "Restored state"
-- Close all existing goals with `refl`
for mvarId in (← Elab.Tactic.getGoals) do
liftM <| mvarId.refl <|> mvarId.inferInstance <|> pure ()
Elab.Tactic.pruneSolvedGoals
unless (← Elab.Tactic.getGoals).isEmpty do
throwError "convert tactic failed, there are unsolved goals\n{Elab.goalsToMessageData (← Elab.Tactic.getGoals)}"
IO.println "Caching"
Elab.Tactic.setGoals savedGoals
let targetNew ← instantiateMVars (.mvar convRhs)
let proof ← instantiateMVars (.mvar convGoal)
Elab.Tactic.liftMetaTactic1 fun mvarId => mvarId.replaceTargetEq targetNew proof
MonadBacktrack.saveState
try
let nextSavedState ← tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic
IO.println "Finished caching"
let nextMCtx := nextSavedState.term.meta.meta.mctx
let prevMCtx := state.savedState.term.meta.meta.mctx
return .success {
root := state.root,
savedState := nextSavedState
newMVars := newMVarSet prevMCtx nextMCtx,
parentMVar := .some convGoal,
convMVar := .none
}
catch exception =>
return .failure #[← exception.toMessageData.toString]
/-- /--

View File

@ -37,6 +37,7 @@ def TacticResult.toString : TacticResult → String
s!".failure {messages}" s!".failure {messages}"
| .parseError error => s!".parseError {error}" | .parseError error => s!".parseError {error}"
| .indexError index => s!".indexError {index}" | .indexError index => s!".indexError {index}"
| .invalidAction error => s!".invalidAction {error}"
namespace Test namespace Test

View File

@ -361,47 +361,48 @@ def test_have: TestM Unit := do
addTest $ LSpec.check "(4 root)" state4.rootExpr?.isSome addTest $ LSpec.check "(4 root)" state4.rootExpr?.isSome
example : ∀ (a b c: Nat), (a + b) + c = (b + a) + c := by example : ∀ (a b c1 c2: Nat), (b + a) + c1 = (b + a) + c2 → (a + b) + c1 = (b + a) + c2 := by
intro a b c intro a b c1 c2 h
conv => conv =>
lhs lhs
congr congr
rw [Nat.add_comm] . rw [Nat.add_comm]
rfl . rfl
exact h
def test_conv: TestM Unit := do def test_conv: TestM Unit := do
let state? ← startProof (.expr "∀ (a b c: Nat), (a + b) + c = (b + a) + c") let state? ← startProof (.expr "∀ (a b c1 c2: Nat), (b + a) + c1 = (b + a) + c2 → (a + b) + c1 = (b + a) + c2")
let state0 ← match state? with let state0 ← match state? with
| .some state => pure state | .some state => pure state
| .none => do | .none => do
addTest $ assertUnreachable "Goal could not parse" addTest $ assertUnreachable "Goal could not parse"
return () return ()
let tactic := "intro a b c"
let tactic := "intro a b c1 c2 h"
let state1 ← match ← state0.tryTactic (goalId := 0) (tactic := tactic) with let state1 ← match ← state0.tryTactic (goalId := 0) (tactic := tactic) with
| .success state => pure state | .success state => pure state
| other => do | other => do
addTest $ assertUnreachable $ other.toString addTest $ assertUnreachable $ other.toString
return () return ()
addTest $ LSpec.check tactic ((← state1.serializeGoals (options := ← read)).map (·.devolatilize) = addTest $ LSpec.check tactic ((← state1.serializeGoals (options := ← read)).map (·.devolatilize) =
#[buildGoal [("a", "Nat"), ("b", "Nat"), ("c", "Nat")] "a + b + c = b + a + c"]) #[interiorGoal [] "a + b + c1 = b + a + c2"])
-- This solves the state in one-shot let state2 ← match ← state1.conv (goalId := 0) with
let tactic := "conv => { lhs; congr; rw [Nat.add_comm]; rfl }"
let stateT ← match ← state1.tryTactic (goalId := 0) (tactic := tactic) with
| .success state => pure state
| other => do
addTest $ assertUnreachable $ other.toString
return ()
addTest $ LSpec.check tactic ((← stateT.serializeGoals (options := ← read)).map (·.devolatilize) =
#[])
let state2 ← match ← state1.tryConv (goalId := 0) with
| .success state => pure state | .success state => pure state
| other => do | other => do
addTest $ assertUnreachable $ other.toString addTest $ assertUnreachable $ other.toString
return () return ()
addTest $ LSpec.check "conv => ..." ((← state2.serializeGoals (options := ← read)).map (·.devolatilize) = addTest $ LSpec.check "conv => ..." ((← state2.serializeGoals (options := ← read)).map (·.devolatilize) =
#[{ buildGoal [("a", "Nat"), ("b", "Nat"), ("c", "Nat")] "a + b + c = b + a + c" with isConversion := true }]) #[{ interiorGoal [] "a + b + c1 = b + a + c2" with isConversion := true }])
let convTactic := "rhs"
let state3R ← match ← state2.tryConvTactic (goalId := 0) (convTactic := convTactic) with
| .success state => pure state
| other => do
addTest $ assertUnreachable $ other.toString
return ()
addTest $ LSpec.check s!" {convTactic} (discard)" ((← state3R.serializeGoals (options := ← read)).map (·.devolatilize) =
#[{ interiorGoal [] "b + a + c2" with isConversion := true }])
let convTactic := "lhs" let convTactic := "lhs"
let state3L ← match ← state2.tryConvTactic (goalId := 0) (convTactic := convTactic) with let state3L ← match ← state2.tryConvTactic (goalId := 0) (convTactic := convTactic) with
@ -410,16 +411,73 @@ def test_conv: TestM Unit := do
addTest $ assertUnreachable $ other.toString addTest $ assertUnreachable $ other.toString
return () return ()
addTest $ LSpec.check s!" {convTactic}" ((← state3L.serializeGoals (options := ← read)).map (·.devolatilize) = addTest $ LSpec.check s!" {convTactic}" ((← state3L.serializeGoals (options := ← read)).map (·.devolatilize) =
#[{ buildGoal [("a", "Nat"), ("b", "Nat"), ("c", "Nat")] "a + b + c" with isConversion := true }]) #[{ interiorGoal [] "a + b + c1" with isConversion := true }])
let convTactic := "rhs" let convTactic := "congr"
let state3R ← match ← state2.tryConvTactic (goalId := 0) (convTactic := convTactic) with let state4 ← match ← state3L.tryConvTactic (goalId := 0) (convTactic := convTactic) with
| .success state => pure state | .success state => pure state
| other => do | other => do
addTest $ assertUnreachable $ other.toString addTest $ assertUnreachable $ other.toString
return () return ()
addTest $ LSpec.check s!" {convTactic}" ((← state3R.serializeGoals (options := ← read)).map (·.devolatilize) = addTest $ LSpec.check s!" {convTactic}" ((← state4.serializeGoals (options := ← read)).map (·.devolatilize) =
#[{ buildGoal [("a", "Nat"), ("b", "Nat"), ("c", "Nat")] "b + a + c" with isConversion := true }]) #[
{ interiorGoal [] "a + b" with isConversion := true, userName? := .some "a" },
{ interiorGoal [] "c1" with isConversion := true, userName? := .some "a" }
])
let convTactic := "rw [Nat.add_comm]"
let state5_1 ← match ← state4.tryConvTactic (goalId := 0) (convTactic := convTactic) with
| .success state => pure state
| other => do
addTest $ assertUnreachable $ other.toString
return ()
addTest $ LSpec.check s!" · {convTactic}" ((← state5_1.serializeGoals (options := ← read)).map (·.devolatilize) =
#[{ interiorGoal [] "b + a" with isConversion := true, userName? := .some "a" }])
let convTactic := "rfl"
let state6_1 ← match ← state5_1.tryConvTactic (goalId := 0) (convTactic := convTactic) with
| .success state => pure state
| other => do
addTest $ assertUnreachable $ other.toString
return ()
addTest $ LSpec.check s!" {convTactic}" ((← state6_1.serializeGoals (options := ← read)).map (·.devolatilize) =
#[])
let state4_1 ← match state6_1.continue state4 with
| .ok state => pure state
| .error e => do
addTest $ expectationFailure "continue" e
return ()
let convTactic := "rfl"
let state6 ← match ← state4_1.tryConvTactic (goalId := 0) (convTactic := convTactic) with
| .success state => pure state
| other => do
addTest $ assertUnreachable $ other.toString
return ()
addTest $ LSpec.check s!" · {convTactic}" ((← state6.serializeGoals (options := ← read)).map (·.devolatilize) =
#[])
let state1_1 ← match ← state6.convExit with
| .success state => pure state
| other => do
addTest $ assertUnreachable $ other.toString
return ()
let tactic := "exact h"
let stateF ← match ← state1_1.tryTactic (goalId := 0) (tactic := tactic) with
| .success state => pure state
| other => do
addTest $ assertUnreachable $ other.toString
return ()
addTest $ LSpec.check tactic ((← stateF.serializeGoals (options := ← read)).map (·.devolatilize) =
#[])
where
h := "b + a + c1 = b + a + c2"
interiorGoal (free: List (String × String)) (target: String) :=
let free := [("a", "Nat"), ("b", "Nat"), ("c1", "Nat"), ("c2", "Nat"), ("h", h)] ++ free
buildGoal free target
example : ∀ (a: Nat), 1 + a + 1 = a + 2 := by example : ∀ (a: Nat), 1 + a + 1 = a + 2 := by
intro a intro a