From 403d92692e056b730c98cfca615a910463ec7399 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Thu, 11 Apr 2024 14:59:55 -0700 Subject: [PATCH] feat: Calc tactic --- Pantograph/Goal.lean | 105 ++++++++++++++++++++++++++++++++++------- Pantograph/Serial.lean | 2 +- Test/Proofs.lean | 63 ++++++++++++++++++++----- 3 files changed, 138 insertions(+), 32 deletions(-) diff --git a/Pantograph/Goal.lean b/Pantograph/Goal.lean index 7e5c0c2..a6d99bc 100644 --- a/Pantograph/Goal.lean +++ b/Pantograph/Goal.lean @@ -29,10 +29,12 @@ structure GoalState where newMVars: SSet MVarId -- Parent state metavariable source - parentMVar: Option MVarId + parentMVar?: Option MVarId -- Existence of this field shows that we are currently in `conv` mode. - convMVar: Option (MVarId × MVarId) := .none + convMVar?: Option (MVarId × MVarId) := .none + -- Previous RHS for calc, so we don't have to repeat it every time + calcPrevRhs?: Option Expr := .none protected def GoalState.create (expr: Expr): Elab.TermElabM GoalState := do -- May be necessary to immediately synthesise all metavariables if we need to leave the elaboration context. @@ -48,10 +50,10 @@ protected def GoalState.create (expr: Expr): Elab.TermElabM GoalState := do savedState, root, newMVars := SSet.insert .empty root, - parentMVar := .none, + parentMVar? := .none, } protected def GoalState.isConv (state: GoalState): Bool := - state.convMVar.isSome + state.convMVar?.isSome protected def GoalState.goals (state: GoalState): List MVarId := state.savedState.tactic.goals protected def GoalState.mctx (state: GoalState): MetavarContext := @@ -136,7 +138,7 @@ protected def GoalState.tryTactic (state: GoalState) (goalId: Nat) (tactic: Stri state with savedState := nextSavedState newMVars := newMVarSet prevMCtx nextMCtx, - parentMVar := .some goal, + parentMVar? := .some goal, } /-- Assumes elabM has already been restored. Assumes expr has already typechecked -/ @@ -174,7 +176,7 @@ protected def GoalState.assign (state: GoalState) (goal: MVarId) (expr: Expr): tactic := { goals := nextGoals } }, newMVars, - parentMVar := .some goal, + parentMVar? := .some goal, } catch exception => return .failure #[← exception.toMessageData.toString] @@ -247,7 +249,7 @@ protected def GoalState.tryHave (state: GoalState) (goalId: Nat) (binderName: St tactic := { goals := nextGoals } }, newMVars := nextGoals.toSSet, - parentMVar := .some goal, + parentMVar? := .some goal, } catch exception => return .failure #[← exception.toMessageData.toString] @@ -255,7 +257,7 @@ protected def GoalState.tryHave (state: GoalState) (goalId: Nat) (binderName: St /-- Enter conv tactic mode -/ protected def GoalState.conv (state: GoalState) (goalId: Nat): Elab.TermElabM TacticResult := do - if state.convMVar.isSome then + if state.convMVar?.isSome then return .invalidAction "Already in conv state" let goal ← match state.savedState.tactic.goals.get? goalId with | .some goal => pure goal @@ -277,8 +279,8 @@ protected def GoalState.conv (state: GoalState) (goalId: Nat): root := state.root, savedState := nextSavedState newMVars := newMVarSet prevMCtx nextMCtx, - parentMVar := .some goal, - convMVar := .some (convRhs, goal), + parentMVar? := .some goal, + convMVar? := .some (convRhs, goal), } catch exception => return .failure #[← exception.toMessageData.toString] @@ -286,15 +288,13 @@ protected def GoalState.conv (state: GoalState) (goalId: Nat): /-- Exit from `conv` mode. Resumes all goals before the mode starts and applys the conv -/ protected def GoalState.convExit (state: GoalState): Elab.TermElabM TacticResult := do - let (convRhs, convGoal) ← match state.convMVar with + let (convRhs, convGoal) ← match state.convMVar? with | .some mvar => pure mvar | .none => return .invalidAction "Not in conv state" let tacticM : Elab.Tactic.TacticM Elab.Tactic.SavedState:= do -- Vide `Lean.Elab.Tactic.Conv.convert` state.savedState.restore - IO.println "Restored state" - -- Close all existing goals with `refl` for mvarId in (← Elab.Tactic.getGoals) do liftM <| mvarId.refl <|> mvarId.inferInstance <|> pure () @@ -302,7 +302,6 @@ protected def GoalState.convExit (state: GoalState): unless (← Elab.Tactic.getGoals).isEmpty do throwError "convert tactic failed, there are unsolved goals\n{Elab.goalsToMessageData (← Elab.Tactic.getGoals)}" - IO.println "Caching" Elab.Tactic.setGoals [convGoal] let targetNew ← instantiateMVars (.mvar convRhs) @@ -312,19 +311,89 @@ protected def GoalState.convExit (state: GoalState): MonadBacktrack.saveState try let nextSavedState ← tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic - IO.println "Finished caching" let nextMCtx := nextSavedState.term.meta.meta.mctx let prevMCtx := state.savedState.term.meta.meta.mctx return .success { root := state.root, savedState := nextSavedState newMVars := newMVarSet prevMCtx nextMCtx, - parentMVar := .some convGoal, - convMVar := .none + parentMVar? := .some convGoal, + convMVar? := .none } catch exception => return .failure #[← exception.toMessageData.toString] +protected def GoalState.tryCalc (state: GoalState) (goalId: Nat) (pred: String): + Elab.TermElabM TacticResult := do + state.restoreElabM + if state.convMVar?.isSome then + return .invalidAction "Cannot initiate `calc` while in `conv` state" + let goal ← match state.savedState.tactic.goals.get? goalId with + | .some goal => pure goal + | .none => return .indexError goalId + let `(term|$pred) ← match Parser.runParserCategory + (env := state.env) + (catName := `term) + (input := pred) + (fileName := filename) with + | .ok syn => pure syn + | .error error => return .parseError error + try + goal.withContext do + let target ← instantiateMVars (← goal.getDecl).type + let tag := (← goal.getDecl).userName + + let mut step ← Elab.Term.elabType <| ← do + if let some prevRhs := state.calcPrevRhs? then + Elab.Term.annotateFirstHoleWithType pred (← Meta.inferType prevRhs) + else + pure pred + + let some (_, lhs, rhs) ← Elab.Term.getCalcRelation? step | + throwErrorAt pred "invalid 'calc' step, relation expected{indentExpr step}" + if let some prevRhs := state.calcPrevRhs? then + unless (← Meta.isDefEqGuarded lhs prevRhs) do + throwErrorAt pred "invalid 'calc' step, left-hand-side is{indentD m!"{lhs} : {← Meta.inferType lhs}"}\nprevious right-hand-side is{indentD m!"{prevRhs} : {← Meta.inferType prevRhs}"}" -- " + + -- Creates a mvar to represent the proof that the calc tactic solves the + -- current branch + -- In the Lean `calc` tactic this is gobbled up by + -- `withCollectingNewGoalsFrom` + let mut proof ← Meta.mkFreshExprMVarAt (← getLCtx) (← Meta.getLocalInstances) step + (userName := tag ++ `calc) + let mvarBranch := proof.mvarId! + + let calcPrevRhs? := Option.some rhs + let mut proofType ← Meta.inferType proof + let mut remainder := Option.none + + -- The calc tactic either solves the main goal or leaves another relation. + -- Replace the main goal, and save the new goal if necessary + if ¬(← Meta.isDefEq proofType target) then + let rec throwFailed := + throwError "'calc' tactic failed, has type{indentExpr proofType}\nbut it is expected to have type{indentExpr target}" + let some (_, _, rhs) ← Elab.Term.getCalcRelation? proofType | throwFailed + let some (r, _, rhs') ← Elab.Term.getCalcRelation? target | throwFailed + let lastStep := mkApp2 r rhs rhs' + let lastStepGoal ← Meta.mkFreshExprSyntheticOpaqueMVar lastStep tag + (proof, proofType) ← Elab.Term.mkCalcTrans proof proofType lastStepGoal lastStep + unless (← Meta.isDefEq proofType target) do throwFailed + remainder := .some lastStepGoal.mvarId! + goal.assign proof + + let goals := [ mvarBranch ] ++ remainder.toList + return .success { + root := state.root, + savedState := { + term := ← MonadBacktrack.saveState, + tactic := { goals }, + }, + newMVars := goals.toSSet, + parentMVar? := .some goal, + calcPrevRhs? + } + catch exception => + return .failure #[← exception.toMessageData.toString] protected def GoalState.focus (state: GoalState) (goalId: Nat): Option GoalState := do @@ -377,7 +446,7 @@ protected def GoalState.rootExpr? (goalState: GoalState): Option Expr := do assert! goalState.goals.isEmpty return expr protected def GoalState.parentExpr? (goalState: GoalState): Option Expr := do - let parent ← goalState.parentMVar + let parent ← goalState.parentMVar? let expr := goalState.mctx.eAssignment.find! parent let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr) return expr diff --git a/Pantograph/Serial.lean b/Pantograph/Serial.lean index 57df5de..f975f76 100644 --- a/Pantograph/Serial.lean +++ b/Pantograph/Serial.lean @@ -249,7 +249,7 @@ protected def GoalState.serializeGoals MetaM (Array Protocol.Goal):= do state.restoreMetaM let goals := state.goals.toArray - let parentDecl? := parent.bind (λ parentState => parentState.mctx.findDecl? state.parentMVar.get!) + let parentDecl? := parent.bind (λ parentState => parentState.mctx.findDecl? state.parentMVar?.get!) goals.mapM fun goal => do match state.mctx.findDecl? goal with | .some mvarDecl => diff --git a/Test/Proofs.lean b/Test/Proofs.lean index 7a23290..9ede63e 100644 --- a/Test/Proofs.lean +++ b/Test/Proofs.lean @@ -479,36 +479,73 @@ def test_conv: TestM Unit := do let free := [("a", "Nat"), ("b", "Nat"), ("c1", "Nat"), ("c2", "Nat"), ("h", h)] ++ free buildGoal free target -example : ∀ (a: Nat), 1 + a + 1 = a + 2 := by - intro a - calc 1 + a + 1 = a + 1 + 1 := by conv => - rhs - rw [Nat.add_comm] - _ = a + 2 := by rw [Nat.add_assoc] +example : ∀ (a b c d: Nat), a + b = b + c → b + c = c + d → a + b = c + d := by + intro a b c d h1 h2 + calc a + b = b + c := by apply h1 + _ = c + d := by apply h2 def test_calc: TestM Unit := do - let state? ← startProof (.expr "∀ (a: Nat), 1 + a + 1 = a + 2") + let state? ← startProof (.expr "∀ (a b c d: Nat), a + b = b + c → b + c = c + d → a + b = c + d") let state0 ← match state? with | .some state => pure state | .none => do addTest $ assertUnreachable "Goal could not parse" return () - let tactic := "intro a" + let tactic := "intro a b c d h1 h2" let state1 ← match ← state0.tryTactic (goalId := 0) (tactic := tactic) with | .success state => pure state | other => do addTest $ assertUnreachable $ other.toString return () addTest $ LSpec.check tactic ((← state1.serializeGoals (options := ← read)).map (·.devolatilize) = - #[buildGoal [("a", "Nat")] "1 + a + 1 = a + 2"]) - let tactic := "calc" - let state2 ← match ← state1.tryTactic (goalId := 0) (tactic := tactic) with + #[interiorGoal [] "a + b = c + d"]) + let pred := "a + b = b + c" + let state2 ← match ← state1.tryCalc (goalId := 0) (pred := pred) with | .success state => pure state | other => do addTest $ assertUnreachable $ other.toString return () - addTest $ LSpec.check tactic ((← state1.serializeGoals (options := ← read)).map (·.devolatilize) = - #[buildGoal [("a", "Nat")] "1 + a + 1 = a + 2"]) + addTest $ LSpec.check s!"calc {pred} := _" ((← state2.serializeGoals (options := ← read)).map (·.devolatilize) = + #[ + interiorGoal [] "a + b = b + c" (.some "calc"), + interiorGoal [] "b + c = c + d" + ]) + + let tactic := "apply h1" + let state2m ← match ← state2.tryTactic (goalId := 0) (tactic := tactic) with + | .success state => pure state + | other => do + addTest $ assertUnreachable $ other.toString + return () + let state3 ← match state2m.continue state2 with + | .ok state => pure state + | .error e => do + addTest $ expectationFailure "continue" e + return () + let pred := "_ = c + d" + let state4 ← match ← state3.tryCalc (goalId := 0) (pred := pred) with + | .success state => pure state + | other => do + addTest $ assertUnreachable $ other.toString + return () + addTest $ LSpec.check s!"calc {pred} := _" ((← state4.serializeGoals (options := ← read)).map (·.devolatilize) = + #[ + interiorGoal [] "b + c = c + d" (.some "calc") + ]) + let tactic := "apply h2" + let state4m ← match ← state4.tryTactic (goalId := 0) (tactic := tactic) with + | .success state => pure state + | other => do + addTest $ assertUnreachable $ other.toString + return () + addTest $ LSpec.test "(4m root)" state4m.rootExpr?.isSome + + + where + interiorGoal (free: List (String × String)) (target: String) (userName?: Option String := .none) := + let free := [("a", "Nat"), ("b", "Nat"), ("c", "Nat"), ("d", "Nat"), + ("h1", "a + b = b + c"), ("h2", "b + c = c + d")] ++ free + buildGoal free target userName? def suite (env: Environment): List (String × IO LSpec.TestSeq) := let tests := [