From 63e64a1e9f62efd6837f9222b60fe06857346117 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Mon, 8 Apr 2024 12:26:22 -0700 Subject: [PATCH] feat: Conv tactic functions --- Pantograph.lean | 2 + Pantograph/Goal.lean | 112 +++++++++++++++++++++++++++++++++---------- Test/Common.lean | 1 + Test/Proofs.lean | 106 ++++++++++++++++++++++++++++++---------- 4 files changed, 171 insertions(+), 50 deletions(-) diff --git a/Pantograph.lean b/Pantograph.lean index 70d64b9..626afae 100644 --- a/Pantograph.lean +++ b/Pantograph.lean @@ -140,6 +140,8 @@ def execute (command: Protocol.Command): MainM Lean.Json := do return .ok { parseError? := .some message } | .ok (.indexError goalId) => return .error $ errorIndex s!"Invalid goal id index {goalId}" + | .ok (.invalidAction message) => + return .error $ errorI "invalid" message | .ok (.failure messages) => return .ok { tacticErrors? := .some messages } goal_continue (args: Protocol.GoalContinue): MainM (CR Protocol.GoalContinueResult) := do diff --git a/Pantograph/Goal.lean b/Pantograph/Goal.lean index b238332..78affd7 100644 --- a/Pantograph/Goal.lean +++ b/Pantograph/Goal.lean @@ -31,6 +31,9 @@ structure GoalState where -- Parent state metavariable source parentMVar: Option MVarId + -- Existence of this field shows that we are currently in `conv` mode. + convMVar: Option (MVarId × MVarId × List MVarId) := .none + protected def GoalState.create (expr: Expr): Elab.TermElabM GoalState := do -- May be necessary to immediately synthesise all metavariables if we need to leave the elaboration context. -- See https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Unknown.20universe.20metavariable/near/360130070 @@ -100,6 +103,8 @@ inductive TacticResult where | parseError (message: String) -- The goal index is out of bounds | indexError (goalId: Nat) + -- The given action cannot be executed in the state + | invalidAction (message: String) /-- Execute tactic on given state -/ protected def GoalState.tryTactic (state: GoalState) (goalId: Nat) (tactic: String): @@ -122,11 +127,11 @@ protected def GoalState.tryTactic (state: GoalState) (goalId: Nat) (tactic: Stri | .ok nextSavedState => -- Assert that the definition of metavariables are the same let nextMCtx := nextSavedState.term.meta.meta.mctx - let prevMCtx := state.savedState.term.meta.meta.mctx + let prevMCtx := state.mctx -- Generate a list of mvarIds that exist in the parent state; Also test the -- assertion that the types have not changed on any mvars. return .success { - root := state.root, + state with savedState := nextSavedState newMVars := newMVarSet prevMCtx nextMCtx, parentMVar := .some goal, @@ -146,7 +151,7 @@ protected def GoalState.assign (state: GoalState) (goal: MVarId) (expr: Expr): return .some s!"{← Meta.ppExpr expr} : {← Meta.ppExpr exprType} != {← Meta.ppExpr goalType}" ) if let .some error := error? then - return .failure #["Type unification failed", error] + return .parseError error goal.checkNotAssigned `GoalState.assign goal.assign expr if (← getThe Core.State).messages.hasErrors then @@ -246,35 +251,45 @@ protected def GoalState.tryHave (state: GoalState) (goalId: Nat) (binderName: St return .failure #[← exception.toMessageData.toString] /-- Enter conv tactic mode -/ -protected def GoalState.tryConv (state: GoalState) (goalId: Nat): +protected def GoalState.conv (state: GoalState) (goalId: Nat): Elab.TermElabM TacticResult := do + if state.convMVar.isSome then + return .invalidAction "Already in conv state" let goal ← match state.savedState.tactic.goals.get? goalId with | .some goal => pure goal | .none => return .indexError goalId - let tacticM : Elab.Tactic.TacticM Elab.Tactic.SavedState:= do + let tacticM : Elab.Tactic.TacticM (Elab.Tactic.SavedState × MVarId) := do state.restoreTacticM goal -- TODO: Fail if this is already in conv -- See Lean.Elab.Tactic.Conv.convTarget - Elab.Tactic.withMainContext do + let convMVar ← Elab.Tactic.withMainContext do -- TODO: Memorize this `rhs` as a conv resultant goal let (rhs, newGoal) ← Elab.Tactic.Conv.mkConvGoalFor (← Elab.Tactic.getMainTarget) Elab.Tactic.setGoals [newGoal.mvarId!] - --Elab.Tactic.liftMetaTactic1 fun mvarId => mvarId.replaceTargetEq rhs proof - MonadBacktrack.saveState - let nextSavedState ← tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic - let prevMCtx := state.savedState.term.meta.meta.mctx - let nextMCtx := nextSavedState.term.meta.meta.mctx - return .success { - root := state.root, - savedState := nextSavedState - newMVars := newMVarSet prevMCtx nextMCtx, - parentMVar := .some goal, - } + pure rhs.mvarId! + return (← MonadBacktrack.saveState, convMVar) + try + let (nextSavedState, convRhs) ← tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic + let prevMCtx := state.mctx + let nextMCtx := nextSavedState.term.meta.meta.mctx + return .success { + root := state.root, + savedState := nextSavedState + newMVars := newMVarSet prevMCtx nextMCtx, + parentMVar := .some goal, + convMVar := .some (convRhs, goal, state.goals), + } + catch exception => + return .failure #[← exception.toMessageData.toString] +/-- Execute a tactic in conv mode -/ protected def GoalState.tryConvTactic (state: GoalState) (goalId: Nat) (convTactic: String): Elab.TermElabM TacticResult := do + let _ ← match state.convMVar with + | .some mvar => pure mvar + | .none => return .invalidAction "Not in conv state" let goal ← match state.savedState.tactic.goals.get? goalId with | .some goal => pure goal | .none => return .indexError goalId @@ -289,15 +304,60 @@ protected def GoalState.tryConvTactic (state: GoalState) (goalId: Nat) (convTact state.restoreTacticM goal Elab.Tactic.evalTactic convTactic MonadBacktrack.saveState - let nextSavedState ← tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic - let nextMCtx := nextSavedState.term.meta.meta.mctx - let prevMCtx := state.savedState.term.meta.meta.mctx - return .success { - root := state.root, - savedState := nextSavedState - newMVars := newMVarSet prevMCtx nextMCtx, - parentMVar := .some goal, - } + try + let prevMCtx := state.mctx + let nextSavedState ← tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic + let nextMCtx := nextSavedState.term.meta.meta.mctx + return .success { + state with + savedState := nextSavedState + newMVars := newMVarSet prevMCtx nextMCtx, + parentMVar := .some goal, + } + catch exception => + return .failure #[← exception.toMessageData.toString] + +protected def GoalState.convExit (state: GoalState): + Elab.TermElabM TacticResult := do + let (convRhs, convGoal, savedGoals) ← match state.convMVar with + | .some mvar => pure mvar + | .none => return .invalidAction "Not in conv state" + let tacticM : Elab.Tactic.TacticM Elab.Tactic.SavedState:= do + -- Vide `Lean.Elab.Tactic.Conv.convert` + state.savedState.restore + + IO.println "Restored state" + + -- Close all existing goals with `refl` + for mvarId in (← Elab.Tactic.getGoals) do + liftM <| mvarId.refl <|> mvarId.inferInstance <|> pure () + Elab.Tactic.pruneSolvedGoals + unless (← Elab.Tactic.getGoals).isEmpty do + throwError "convert tactic failed, there are unsolved goals\n{Elab.goalsToMessageData (← Elab.Tactic.getGoals)}" + + IO.println "Caching" + Elab.Tactic.setGoals savedGoals + + let targetNew ← instantiateMVars (.mvar convRhs) + let proof ← instantiateMVars (.mvar convGoal) + + Elab.Tactic.liftMetaTactic1 fun mvarId => mvarId.replaceTargetEq targetNew proof + MonadBacktrack.saveState + try + let nextSavedState ← tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic + IO.println "Finished caching" + let nextMCtx := nextSavedState.term.meta.meta.mctx + let prevMCtx := state.savedState.term.meta.meta.mctx + return .success { + root := state.root, + savedState := nextSavedState + newMVars := newMVarSet prevMCtx nextMCtx, + parentMVar := .some convGoal, + convMVar := .none + } + catch exception => + return .failure #[← exception.toMessageData.toString] + /-- diff --git a/Test/Common.lean b/Test/Common.lean index 6fa858b..8719ebd 100644 --- a/Test/Common.lean +++ b/Test/Common.lean @@ -37,6 +37,7 @@ def TacticResult.toString : TacticResult → String s!".failure {messages}" | .parseError error => s!".parseError {error}" | .indexError index => s!".indexError {index}" + | .invalidAction error => s!".invalidAction {error}" namespace Test diff --git a/Test/Proofs.lean b/Test/Proofs.lean index 4b2b57e..c8ceeee 100644 --- a/Test/Proofs.lean +++ b/Test/Proofs.lean @@ -361,47 +361,48 @@ def test_have: TestM Unit := do addTest $ LSpec.check "(4 root)" state4.rootExpr?.isSome -example : ∀ (a b c: Nat), (a + b) + c = (b + a) + c := by - intro a b c +example : ∀ (a b c1 c2: Nat), (b + a) + c1 = (b + a) + c2 → (a + b) + c1 = (b + a) + c2 := by + intro a b c1 c2 h conv => lhs congr - rw [Nat.add_comm] - rfl + . rw [Nat.add_comm] + . rfl + exact h def test_conv: TestM Unit := do - let state? ← startProof (.expr "∀ (a b c: Nat), (a + b) + c = (b + a) + c") + let state? ← startProof (.expr "∀ (a b c1 c2: Nat), (b + a) + c1 = (b + a) + c2 → (a + b) + c1 = (b + a) + c2") let state0 ← match state? with | .some state => pure state | .none => do addTest $ assertUnreachable "Goal could not parse" return () - let tactic := "intro a b c" + + let tactic := "intro a b c1 c2 h" let state1 ← match ← state0.tryTactic (goalId := 0) (tactic := tactic) with | .success state => pure state | other => do addTest $ assertUnreachable $ other.toString return () addTest $ LSpec.check tactic ((← state1.serializeGoals (options := ← read)).map (·.devolatilize) = - #[buildGoal [("a", "Nat"), ("b", "Nat"), ("c", "Nat")] "a + b + c = b + a + c"]) + #[interiorGoal [] "a + b + c1 = b + a + c2"]) - -- This solves the state in one-shot - let tactic := "conv => { lhs; congr; rw [Nat.add_comm]; rfl }" - let stateT ← match ← state1.tryTactic (goalId := 0) (tactic := tactic) with - | .success state => pure state - | other => do - addTest $ assertUnreachable $ other.toString - return () - addTest $ LSpec.check tactic ((← stateT.serializeGoals (options := ← read)).map (·.devolatilize) = - #[]) - - let state2 ← match ← state1.tryConv (goalId := 0) with + let state2 ← match ← state1.conv (goalId := 0) with | .success state => pure state | other => do addTest $ assertUnreachable $ other.toString return () addTest $ LSpec.check "conv => ..." ((← state2.serializeGoals (options := ← read)).map (·.devolatilize) = - #[{ buildGoal [("a", "Nat"), ("b", "Nat"), ("c", "Nat")] "a + b + c = b + a + c" with isConversion := true }]) + #[{ interiorGoal [] "a + b + c1 = b + a + c2" with isConversion := true }]) + + let convTactic := "rhs" + let state3R ← match ← state2.tryConvTactic (goalId := 0) (convTactic := convTactic) with + | .success state => pure state + | other => do + addTest $ assertUnreachable $ other.toString + return () + addTest $ LSpec.check s!" {convTactic} (discard)" ((← state3R.serializeGoals (options := ← read)).map (·.devolatilize) = + #[{ interiorGoal [] "b + a + c2" with isConversion := true }]) let convTactic := "lhs" let state3L ← match ← state2.tryConvTactic (goalId := 0) (convTactic := convTactic) with @@ -410,16 +411,73 @@ def test_conv: TestM Unit := do addTest $ assertUnreachable $ other.toString return () addTest $ LSpec.check s!" {convTactic}" ((← state3L.serializeGoals (options := ← read)).map (·.devolatilize) = - #[{ buildGoal [("a", "Nat"), ("b", "Nat"), ("c", "Nat")] "a + b + c" with isConversion := true }]) + #[{ interiorGoal [] "a + b + c1" with isConversion := true }]) - let convTactic := "rhs" - let state3R ← match ← state2.tryConvTactic (goalId := 0) (convTactic := convTactic) with + let convTactic := "congr" + let state4 ← match ← state3L.tryConvTactic (goalId := 0) (convTactic := convTactic) with | .success state => pure state | other => do addTest $ assertUnreachable $ other.toString return () - addTest $ LSpec.check s!" {convTactic}" ((← state3R.serializeGoals (options := ← read)).map (·.devolatilize) = - #[{ buildGoal [("a", "Nat"), ("b", "Nat"), ("c", "Nat")] "b + a + c" with isConversion := true }]) + addTest $ LSpec.check s!" {convTactic}" ((← state4.serializeGoals (options := ← read)).map (·.devolatilize) = + #[ + { interiorGoal [] "a + b" with isConversion := true, userName? := .some "a" }, + { interiorGoal [] "c1" with isConversion := true, userName? := .some "a" } + ]) + + let convTactic := "rw [Nat.add_comm]" + let state5_1 ← match ← state4.tryConvTactic (goalId := 0) (convTactic := convTactic) with + | .success state => pure state + | other => do + addTest $ assertUnreachable $ other.toString + return () + addTest $ LSpec.check s!" · {convTactic}" ((← state5_1.serializeGoals (options := ← read)).map (·.devolatilize) = + #[{ interiorGoal [] "b + a" with isConversion := true, userName? := .some "a" }]) + + let convTactic := "rfl" + let state6_1 ← match ← state5_1.tryConvTactic (goalId := 0) (convTactic := convTactic) with + | .success state => pure state + | other => do + addTest $ assertUnreachable $ other.toString + return () + addTest $ LSpec.check s!" {convTactic}" ((← state6_1.serializeGoals (options := ← read)).map (·.devolatilize) = + #[]) + + let state4_1 ← match state6_1.continue state4 with + | .ok state => pure state + | .error e => do + addTest $ expectationFailure "continue" e + return () + + let convTactic := "rfl" + let state6 ← match ← state4_1.tryConvTactic (goalId := 0) (convTactic := convTactic) with + | .success state => pure state + | other => do + addTest $ assertUnreachable $ other.toString + return () + addTest $ LSpec.check s!" · {convTactic}" ((← state6.serializeGoals (options := ← read)).map (·.devolatilize) = + #[]) + + let state1_1 ← match ← state6.convExit with + | .success state => pure state + | other => do + addTest $ assertUnreachable $ other.toString + return () + + let tactic := "exact h" + let stateF ← match ← state1_1.tryTactic (goalId := 0) (tactic := tactic) with + | .success state => pure state + | other => do + addTest $ assertUnreachable $ other.toString + return () + addTest $ LSpec.check tactic ((← stateF.serializeGoals (options := ← read)).map (·.devolatilize) = + #[]) + + where + h := "b + a + c1 = b + a + c2" + interiorGoal (free: List (String × String)) (target: String) := + let free := [("a", "Nat"), ("b", "Nat"), ("c1", "Nat"), ("c2", "Nat"), ("h", h)] ++ free + buildGoal free target example : ∀ (a: Nat), 1 + a + 1 = a + 2 := by intro a