feat: Calc tactic
This commit is contained in:
parent
30c1fd894f
commit
535770bbd7
|
@ -29,10 +29,12 @@ structure GoalState where
|
|||
newMVars: SSet MVarId
|
||||
|
||||
-- Parent state metavariable source
|
||||
parentMVar: Option MVarId
|
||||
parentMVar?: Option MVarId
|
||||
|
||||
-- Existence of this field shows that we are currently in `conv` mode.
|
||||
convMVar: Option (MVarId × MVarId) := .none
|
||||
convMVar?: Option (MVarId × MVarId) := .none
|
||||
-- Previous RHS for calc, so we don't have to repeat it every time
|
||||
calcPrevRhs?: Option Expr := .none
|
||||
|
||||
protected def GoalState.create (expr: Expr): Elab.TermElabM GoalState := do
|
||||
-- May be necessary to immediately synthesise all metavariables if we need to leave the elaboration context.
|
||||
|
@ -48,10 +50,10 @@ protected def GoalState.create (expr: Expr): Elab.TermElabM GoalState := do
|
|||
savedState,
|
||||
root,
|
||||
newMVars := SSet.insert .empty root,
|
||||
parentMVar := .none,
|
||||
parentMVar? := .none,
|
||||
}
|
||||
protected def GoalState.isConv (state: GoalState): Bool :=
|
||||
state.convMVar.isSome
|
||||
state.convMVar?.isSome
|
||||
protected def GoalState.goals (state: GoalState): List MVarId :=
|
||||
state.savedState.tactic.goals
|
||||
protected def GoalState.mctx (state: GoalState): MetavarContext :=
|
||||
|
@ -136,7 +138,7 @@ protected def GoalState.tryTactic (state: GoalState) (goalId: Nat) (tactic: Stri
|
|||
state with
|
||||
savedState := nextSavedState
|
||||
newMVars := newMVarSet prevMCtx nextMCtx,
|
||||
parentMVar := .some goal,
|
||||
parentMVar? := .some goal,
|
||||
}
|
||||
|
||||
/-- Assumes elabM has already been restored. Assumes expr has already typechecked -/
|
||||
|
@ -174,7 +176,7 @@ protected def GoalState.assign (state: GoalState) (goal: MVarId) (expr: Expr):
|
|||
tactic := { goals := nextGoals }
|
||||
},
|
||||
newMVars,
|
||||
parentMVar := .some goal,
|
||||
parentMVar? := .some goal,
|
||||
}
|
||||
catch exception =>
|
||||
return .failure #[← exception.toMessageData.toString]
|
||||
|
@ -247,7 +249,7 @@ protected def GoalState.tryHave (state: GoalState) (goalId: Nat) (binderName: St
|
|||
tactic := { goals := nextGoals }
|
||||
},
|
||||
newMVars := nextGoals.toSSet,
|
||||
parentMVar := .some goal,
|
||||
parentMVar? := .some goal,
|
||||
}
|
||||
catch exception =>
|
||||
return .failure #[← exception.toMessageData.toString]
|
||||
|
@ -255,7 +257,7 @@ protected def GoalState.tryHave (state: GoalState) (goalId: Nat) (binderName: St
|
|||
/-- Enter conv tactic mode -/
|
||||
protected def GoalState.conv (state: GoalState) (goalId: Nat):
|
||||
Elab.TermElabM TacticResult := do
|
||||
if state.convMVar.isSome then
|
||||
if state.convMVar?.isSome then
|
||||
return .invalidAction "Already in conv state"
|
||||
let goal ← match state.savedState.tactic.goals.get? goalId with
|
||||
| .some goal => pure goal
|
||||
|
@ -277,8 +279,8 @@ protected def GoalState.conv (state: GoalState) (goalId: Nat):
|
|||
root := state.root,
|
||||
savedState := nextSavedState
|
||||
newMVars := newMVarSet prevMCtx nextMCtx,
|
||||
parentMVar := .some goal,
|
||||
convMVar := .some (convRhs, goal),
|
||||
parentMVar? := .some goal,
|
||||
convMVar? := .some (convRhs, goal),
|
||||
}
|
||||
catch exception =>
|
||||
return .failure #[← exception.toMessageData.toString]
|
||||
|
@ -286,15 +288,13 @@ protected def GoalState.conv (state: GoalState) (goalId: Nat):
|
|||
/-- Exit from `conv` mode. Resumes all goals before the mode starts and applys the conv -/
|
||||
protected def GoalState.convExit (state: GoalState):
|
||||
Elab.TermElabM TacticResult := do
|
||||
let (convRhs, convGoal) ← match state.convMVar with
|
||||
let (convRhs, convGoal) ← match state.convMVar? with
|
||||
| .some mvar => pure mvar
|
||||
| .none => return .invalidAction "Not in conv state"
|
||||
let tacticM : Elab.Tactic.TacticM Elab.Tactic.SavedState:= do
|
||||
-- Vide `Lean.Elab.Tactic.Conv.convert`
|
||||
state.savedState.restore
|
||||
|
||||
IO.println "Restored state"
|
||||
|
||||
-- Close all existing goals with `refl`
|
||||
for mvarId in (← Elab.Tactic.getGoals) do
|
||||
liftM <| mvarId.refl <|> mvarId.inferInstance <|> pure ()
|
||||
|
@ -302,7 +302,6 @@ protected def GoalState.convExit (state: GoalState):
|
|||
unless (← Elab.Tactic.getGoals).isEmpty do
|
||||
throwError "convert tactic failed, there are unsolved goals\n{Elab.goalsToMessageData (← Elab.Tactic.getGoals)}"
|
||||
|
||||
IO.println "Caching"
|
||||
Elab.Tactic.setGoals [convGoal]
|
||||
|
||||
let targetNew ← instantiateMVars (.mvar convRhs)
|
||||
|
@ -312,19 +311,89 @@ protected def GoalState.convExit (state: GoalState):
|
|||
MonadBacktrack.saveState
|
||||
try
|
||||
let nextSavedState ← tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic
|
||||
IO.println "Finished caching"
|
||||
let nextMCtx := nextSavedState.term.meta.meta.mctx
|
||||
let prevMCtx := state.savedState.term.meta.meta.mctx
|
||||
return .success {
|
||||
root := state.root,
|
||||
savedState := nextSavedState
|
||||
newMVars := newMVarSet prevMCtx nextMCtx,
|
||||
parentMVar := .some convGoal,
|
||||
convMVar := .none
|
||||
parentMVar? := .some convGoal,
|
||||
convMVar? := .none
|
||||
}
|
||||
catch exception =>
|
||||
return .failure #[← exception.toMessageData.toString]
|
||||
|
||||
protected def GoalState.tryCalc (state: GoalState) (goalId: Nat) (pred: String):
|
||||
Elab.TermElabM TacticResult := do
|
||||
state.restoreElabM
|
||||
if state.convMVar?.isSome then
|
||||
return .invalidAction "Cannot initiate `calc` while in `conv` state"
|
||||
let goal ← match state.savedState.tactic.goals.get? goalId with
|
||||
| .some goal => pure goal
|
||||
| .none => return .indexError goalId
|
||||
let `(term|$pred) ← match Parser.runParserCategory
|
||||
(env := state.env)
|
||||
(catName := `term)
|
||||
(input := pred)
|
||||
(fileName := filename) with
|
||||
| .ok syn => pure syn
|
||||
| .error error => return .parseError error
|
||||
try
|
||||
goal.withContext do
|
||||
let target ← instantiateMVars (← goal.getDecl).type
|
||||
let tag := (← goal.getDecl).userName
|
||||
|
||||
let mut step ← Elab.Term.elabType <| ← do
|
||||
if let some prevRhs := state.calcPrevRhs? then
|
||||
Elab.Term.annotateFirstHoleWithType pred (← Meta.inferType prevRhs)
|
||||
else
|
||||
pure pred
|
||||
|
||||
let some (_, lhs, rhs) ← Elab.Term.getCalcRelation? step |
|
||||
throwErrorAt pred "invalid 'calc' step, relation expected{indentExpr step}"
|
||||
if let some prevRhs := state.calcPrevRhs? then
|
||||
unless (← Meta.isDefEqGuarded lhs prevRhs) do
|
||||
throwErrorAt pred "invalid 'calc' step, left-hand-side is{indentD m!"{lhs} : {← Meta.inferType lhs}"}\nprevious right-hand-side is{indentD m!"{prevRhs} : {← Meta.inferType prevRhs}"}" -- "
|
||||
|
||||
-- Creates a mvar to represent the proof that the calc tactic solves the
|
||||
-- current branch
|
||||
-- In the Lean `calc` tactic this is gobbled up by
|
||||
-- `withCollectingNewGoalsFrom`
|
||||
let mut proof ← Meta.mkFreshExprMVarAt (← getLCtx) (← Meta.getLocalInstances) step
|
||||
(userName := tag ++ `calc)
|
||||
let mvarBranch := proof.mvarId!
|
||||
|
||||
let calcPrevRhs? := Option.some rhs
|
||||
let mut proofType ← Meta.inferType proof
|
||||
let mut remainder := Option.none
|
||||
|
||||
-- The calc tactic either solves the main goal or leaves another relation.
|
||||
-- Replace the main goal, and save the new goal if necessary
|
||||
if ¬(← Meta.isDefEq proofType target) then
|
||||
let rec throwFailed :=
|
||||
throwError "'calc' tactic failed, has type{indentExpr proofType}\nbut it is expected to have type{indentExpr target}"
|
||||
let some (_, _, rhs) ← Elab.Term.getCalcRelation? proofType | throwFailed
|
||||
let some (r, _, rhs') ← Elab.Term.getCalcRelation? target | throwFailed
|
||||
let lastStep := mkApp2 r rhs rhs'
|
||||
let lastStepGoal ← Meta.mkFreshExprSyntheticOpaqueMVar lastStep tag
|
||||
(proof, proofType) ← Elab.Term.mkCalcTrans proof proofType lastStepGoal lastStep
|
||||
unless (← Meta.isDefEq proofType target) do throwFailed
|
||||
remainder := .some lastStepGoal.mvarId!
|
||||
goal.assign proof
|
||||
|
||||
let goals := [ mvarBranch ] ++ remainder.toList
|
||||
return .success {
|
||||
root := state.root,
|
||||
savedState := {
|
||||
term := ← MonadBacktrack.saveState,
|
||||
tactic := { goals },
|
||||
},
|
||||
newMVars := goals.toSSet,
|
||||
parentMVar? := .some goal,
|
||||
calcPrevRhs?
|
||||
}
|
||||
catch exception =>
|
||||
return .failure #[← exception.toMessageData.toString]
|
||||
|
||||
|
||||
protected def GoalState.focus (state: GoalState) (goalId: Nat): Option GoalState := do
|
||||
|
@ -377,7 +446,7 @@ protected def GoalState.rootExpr? (goalState: GoalState): Option Expr := do
|
|||
assert! goalState.goals.isEmpty
|
||||
return expr
|
||||
protected def GoalState.parentExpr? (goalState: GoalState): Option Expr := do
|
||||
let parent ← goalState.parentMVar
|
||||
let parent ← goalState.parentMVar?
|
||||
let expr := goalState.mctx.eAssignment.find! parent
|
||||
let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr)
|
||||
return expr
|
||||
|
|
|
@ -249,7 +249,7 @@ protected def GoalState.serializeGoals
|
|||
MetaM (Array Protocol.Goal):= do
|
||||
state.restoreMetaM
|
||||
let goals := state.goals.toArray
|
||||
let parentDecl? := parent.bind (λ parentState => parentState.mctx.findDecl? state.parentMVar.get!)
|
||||
let parentDecl? := parent.bind (λ parentState => parentState.mctx.findDecl? state.parentMVar?.get!)
|
||||
goals.mapM fun goal => do
|
||||
match state.mctx.findDecl? goal with
|
||||
| .some mvarDecl =>
|
||||
|
|
|
@ -479,36 +479,73 @@ def test_conv: TestM Unit := do
|
|||
let free := [("a", "Nat"), ("b", "Nat"), ("c1", "Nat"), ("c2", "Nat"), ("h", h)] ++ free
|
||||
buildGoal free target
|
||||
|
||||
example : ∀ (a: Nat), 1 + a + 1 = a + 2 := by
|
||||
intro a
|
||||
calc 1 + a + 1 = a + 1 + 1 := by conv =>
|
||||
rhs
|
||||
rw [Nat.add_comm]
|
||||
_ = a + 2 := by rw [Nat.add_assoc]
|
||||
example : ∀ (a b c d: Nat), a + b = b + c → b + c = c + d → a + b = c + d := by
|
||||
intro a b c d h1 h2
|
||||
calc a + b = b + c := by apply h1
|
||||
_ = c + d := by apply h2
|
||||
|
||||
def test_calc: TestM Unit := do
|
||||
let state? ← startProof (.expr "∀ (a: Nat), 1 + a + 1 = a + 2")
|
||||
let state? ← startProof (.expr "∀ (a b c d: Nat), a + b = b + c → b + c = c + d → a + b = c + d")
|
||||
let state0 ← match state? with
|
||||
| .some state => pure state
|
||||
| .none => do
|
||||
addTest $ assertUnreachable "Goal could not parse"
|
||||
return ()
|
||||
let tactic := "intro a"
|
||||
let tactic := "intro a b c d h1 h2"
|
||||
let state1 ← match ← state0.tryTactic (goalId := 0) (tactic := tactic) with
|
||||
| .success state => pure state
|
||||
| other => do
|
||||
addTest $ assertUnreachable $ other.toString
|
||||
return ()
|
||||
addTest $ LSpec.check tactic ((← state1.serializeGoals (options := ← read)).map (·.devolatilize) =
|
||||
#[buildGoal [("a", "Nat")] "1 + a + 1 = a + 2"])
|
||||
let tactic := "calc"
|
||||
let state2 ← match ← state1.tryTactic (goalId := 0) (tactic := tactic) with
|
||||
#[interiorGoal [] "a + b = c + d"])
|
||||
let pred := "a + b = b + c"
|
||||
let state2 ← match ← state1.tryCalc (goalId := 0) (pred := pred) with
|
||||
| .success state => pure state
|
||||
| other => do
|
||||
addTest $ assertUnreachable $ other.toString
|
||||
return ()
|
||||
addTest $ LSpec.check tactic ((← state1.serializeGoals (options := ← read)).map (·.devolatilize) =
|
||||
#[buildGoal [("a", "Nat")] "1 + a + 1 = a + 2"])
|
||||
addTest $ LSpec.check s!"calc {pred} := _" ((← state2.serializeGoals (options := ← read)).map (·.devolatilize) =
|
||||
#[
|
||||
interiorGoal [] "a + b = b + c" (.some "calc"),
|
||||
interiorGoal [] "b + c = c + d"
|
||||
])
|
||||
|
||||
let tactic := "apply h1"
|
||||
let state2m ← match ← state2.tryTactic (goalId := 0) (tactic := tactic) with
|
||||
| .success state => pure state
|
||||
| other => do
|
||||
addTest $ assertUnreachable $ other.toString
|
||||
return ()
|
||||
let state3 ← match state2m.continue state2 with
|
||||
| .ok state => pure state
|
||||
| .error e => do
|
||||
addTest $ expectationFailure "continue" e
|
||||
return ()
|
||||
let pred := "_ = c + d"
|
||||
let state4 ← match ← state3.tryCalc (goalId := 0) (pred := pred) with
|
||||
| .success state => pure state
|
||||
| other => do
|
||||
addTest $ assertUnreachable $ other.toString
|
||||
return ()
|
||||
addTest $ LSpec.check s!"calc {pred} := _" ((← state4.serializeGoals (options := ← read)).map (·.devolatilize) =
|
||||
#[
|
||||
interiorGoal [] "b + c = c + d" (.some "calc")
|
||||
])
|
||||
let tactic := "apply h2"
|
||||
let state4m ← match ← state4.tryTactic (goalId := 0) (tactic := tactic) with
|
||||
| .success state => pure state
|
||||
| other => do
|
||||
addTest $ assertUnreachable $ other.toString
|
||||
return ()
|
||||
addTest $ LSpec.test "(4m root)" state4m.rootExpr?.isSome
|
||||
|
||||
|
||||
where
|
||||
interiorGoal (free: List (String × String)) (target: String) (userName?: Option String := .none) :=
|
||||
let free := [("a", "Nat"), ("b", "Nat"), ("c", "Nat"), ("d", "Nat"),
|
||||
("h1", "a + b = b + c"), ("h2", "b + c = c + d")] ++ free
|
||||
buildGoal free target userName?
|
||||
|
||||
def suite (env: Environment): List (String × IO LSpec.TestSeq) :=
|
||||
let tests := [
|
||||
|
|
Loading…
Reference in New Issue