fix: Decoupling of mvars during instantiation

This commit is contained in:
Leni Aniva 2024-05-19 15:43:10 -07:00
parent 6ad24b72d4
commit 2f951c8fef
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
5 changed files with 122 additions and 65 deletions

View File

@ -12,7 +12,8 @@ def unfoldAuxLemmas (e : Expr) : CoreM Expr := do
/--
Force the instantiation of delayed metavariables even if they cannot be fully
instantiated. This is used during resumption.
instantiated. This is used during resumption to provide diagnostic data about
the current goal.
Since Lean 4 does not have an `Expr` constructor corresponding to delayed
metavariables, any delayed metavariables must be recursively handled by this
@ -25,59 +26,72 @@ This function ensures any metavariable in the result is either
2. Not assigned (delay or not)
-/
partial def instantiateDelayedMVars (eOrig: Expr) : MetaM Expr := do
--let padding := String.join $ List.replicate level " "
--let padding := String.join $ List.replicate level " "
--IO.println s!"{padding}Starting {toString eOrig}"
let result ← Meta.transform (← instantiateMVars eOrig)
let mut result ← Meta.transform (← instantiateMVars eOrig)
(pre := fun e => e.withApp fun f args => do
--IO.println s!"{padding} V {toString e}"
if let .mvar mvarId := f then
if ← mvarId.isAssigned then
--IO.println s!"{padding} A ?{mvarId.name}"
return .continue <| .some (← self e)
if let some { fvars, mvarIdPending } ← getDelayedMVarAssignment? mvarId then
-- No progress can be made on this
if !(← mvarIdPending.isAssigned) then
--IO.println s!"{padding} D/T1: {toString e}"
let .mvar mvarId := f | return .continue
--IO.println s!"{padding}├V {e}"
let mvarDecl ← mvarId.getDecl
-- This is critical to maintaining the interdependency of metavariables.
-- Without setting `.syntheticOpaque`, Lean's metavariable elimination
-- system will not make the necessary delayed assigned mvars in case of
-- nested mvars.
mvarId.setKind .syntheticOpaque
let lctx ← MonadLCtx.getLCtx
if mvarDecl.lctx.any (λ decl => !lctx.contains decl.fvarId) then
let violations := mvarDecl.lctx.decls.foldl (λ acc decl? => match decl? with
| .some decl => if lctx.contains decl.fvarId then acc else acc ++ [decl.fvarId.name]
| .none => acc) []
panic! s!"Local context variable violation: {violations}"
if let .some assign ← getExprMVarAssignment? mvarId then
--IO.println s!"{padding}├A ?{mvarId.name}"
assert! !(← mvarId.isDelayedAssigned)
return .visit (mkAppN assign args)
else if let some { fvars, mvarIdPending } ← getDelayedMVarAssignment? mvarId then
--let substTableStr := String.intercalate ", " $ Array.zipWith fvars args (λ fvar assign => s!"{fvar.fvarId!.name} := {assign}") |>.toList
--IO.println s!"{padding}├MD ?{mvarId.name} := ?{mvarIdPending.name} [{substTableStr}]"
if args.size < fvars.size then
throwError "Not enough arguments to instantiate a delay assigned mvar. This is due to bad implementations of a tactic: {args.size} < {fvars.size}. Expr: {toString e}; Origin: {toString eOrig}"
--if !args.isEmpty then
--IO.println s!"{padding}├── Arguments Begin"
let args ← args.mapM self
--if !args.isEmpty then
--IO.println s!"{padding}├── Arguments End"
if !(← mvarIdPending.isAssignedOrDelayedAssigned) then
--IO.println s!"{padding}├T1"
let result := mkAppN f args
return .done result
--IO.println s!"{padding} D ?{mvarId.name} := [{fvars.size}] ?{mvarIdPending.name}"
-- This asstion fails when a tactic or elaboration function is
-- implemented incorrectly. See `instantiateExprMVars`
if args.size < fvars.size then
--IO.println s!"{padding} Illegal callsite: {args.size} < {fvars.size}"
throwError "Not enough arguments to instantiate a delay assigned mvar. This is due to bad implementations of a tactic: {args.size} < {fvars.size}. Expr: {toString e}; Origin: {toString eOrig}"
assert! !(← mvarIdPending.isDelayedAssigned)
let pending ← self (.mvar mvarIdPending)
if pending == .mvar mvarIdPending then
-- No progress could be made on this case
--IO.println s!"{padding}D/N {toString e}"
assert! !(← mvarIdPending.isAssigned)
assert! !(← mvarIdPending.isDelayedAssigned)
--else if pending.isMVar then
-- assert! !(← pending.mvarId!.isAssigned)
-- assert! !(← pending.mvarId!.isDelayedAssigned)
-- -- Progress made, but this is now another delayed assigned mvar
-- let nextMVarId ← mkFreshMVarId
-- assignDelayedMVar nextMVarId fvars pending.mvarId!
-- let args ← args.mapM self
-- let result := mkAppN (.mvar nextMVarId) args
-- return .done result
else
-- Progress has been made on this mvar
let pending := pending.abstract fvars
let args ← args.mapM self
-- Craete the function call structure
let subst := pending.instantiateRevRange 0 fvars.size args
let result := mkAppRange subst fvars.size args.size args
--IO.println s!"{padding}D/T2 {toString result}"
let pending ← mvarIdPending.withContext do
let inner ← instantiateDelayedMVars (.mvar mvarIdPending) --(level := level + 1)
--IO.println s!"{padding}├Pre: {inner}"
let r := (← Expr.abstractM inner fvars).instantiateRev args
pure r
-- Tail arguments
let result := mkAppN pending (List.drop fvars.size args.toList |>.toArray)
--IO.println s!"{padding}├MD {result}"
return .done result
return .continue)
--IO.println s!"{padding}Result {toString result}"
else
assert! !(← mvarId.isAssigned)
assert! !(← mvarId.isDelayedAssigned)
--if !args.isEmpty then
-- IO.println s!"{padding}├── Arguments Begin"
let args ← args.mapM self
--if !args.isEmpty then
-- IO.println s!"{padding}├── Arguments End"
--IO.println s!"{padding}├M ?{mvarId.name}"
return .done (mkAppN f args))
--IO.println s!"{padding}└Result {result}"
return result
where
self e := instantiateDelayedMVars e
self e := instantiateDelayedMVars e --(level := level + 1)
/--
Convert an expression to an equiavlent form with

View File

@ -62,6 +62,16 @@ protected def GoalState.mctx (state: GoalState): MetavarContext :=
state.savedState.term.meta.meta.mctx
protected def GoalState.env (state: GoalState): Environment :=
state.savedState.term.meta.core.env
protected def GoalState.withContext (state: GoalState) (mvarId: MVarId) (m: MetaM α): MetaM α := do
let metaM := mvarId.withContext m
metaM.run' (← read) state.savedState.term.meta.meta
protected def GoalState.withParentContext (state: GoalState) (m: MetaM α): MetaM α := do
state.withContext state.parentMVar?.get! m
protected def GoalState.withRootContext (state: GoalState) (m: MetaM α): MetaM α := do
state.withContext state.root m
private def GoalState.mvars (state: GoalState): SSet MVarId :=
state.mctx.decls.foldl (init := .empty) fun acc k _ => acc.insert k
protected def GoalState.restoreMetaM (state: GoalState): MetaM Unit :=

View File

@ -179,9 +179,11 @@ def goalPrint (state: GoalState) (options: @&Protocol.Options): Lean.CoreM Proto
runMetaM do
state.restoreMetaM
return {
root? := ← state.rootExpr?.mapM (λ expr => do
root? := ← state.rootExpr?.mapM (λ expr =>
state.withRootContext do
serializeExpression options (← instantiateAll expr)),
parent? := ← state.parentExpr?.mapM (λ expr => do
parent? := ← state.parentExpr?.mapM (λ expr =>
state.withParentContext do
serializeExpression options (← instantiateAll expr)),
}
@[export pantograph_goal_diag_m]

View File

@ -264,7 +264,7 @@ def serializeGoal (options: @&Protocol.Options) (goal: MVarId) (mvarDecl: Metava
name := ofName goal.name,
userName? := if mvarDecl.userName == .anonymous then .none else .some (ofName mvarDecl.userName),
isConversion := isLHSGoal? mvarDecl.type |>.isSome,
target := (← serializeExpression options (← instantiate mvarDecl.type)),
target := (← serializeExpression options (← instantiateMVars (← instantiate mvarDecl.type))),
vars := vars.reverse.toArray
}
where

View File

@ -84,6 +84,27 @@ def proofRunner (env: Lean.Environment) (tests: TestM Unit): IO LSpec.TestSeq :=
| .ok (_, a) =>
return a
def test_identity: TestM Unit := do
let state? ← startProof (.expr "∀ (p: Prop), p → p")
let state0 ← match state? with
| .some state => pure state
| .none => do
addTest $ assertUnreachable "Goal could not parse"
return ()
let tactic := "intro p h"
let state1 ← match ← state0.tryTactic (goalId := 0) (tactic := tactic) with
| .success state => pure state
| other => do
addTest $ assertUnreachable $ other.toString
return ()
let inner := "_uniq.12"
addTest $ LSpec.check tactic ((← state1.serializeGoals (options := ← read)).map (·.name) =
#[inner])
let state1parent ← state1.withParentContext do
serializeExpressionSexp (← instantiateAll state1.parentExpr?.get!) (sanitize := false)
addTest $ LSpec.test "(1 parent)" (state1parent == s!"(:lambda p (:sort 0) (:lambda h 0 (:subst (:mv {inner}) 1 0)))")
-- Individual test cases
example: ∀ (a b: Nat), a + b = b + a := by
intro n m
@ -243,8 +264,9 @@ def test_or_comm: TestM Unit := do
addTest $ LSpec.check "(1 parent)" state1.parentExpr?.isSome
addTest $ LSpec.check "(1 root)" state1.rootExpr?.isNone
let state1parent ← serializeExpressionSexp (← instantiateAll state1.parentExpr?.get!) (sanitize := false)
addTest $ LSpec.test "(1 parent)" (state1parent == s!"(:lambda p (:sort 0) (:lambda q (:sort 0) (:lambda h ((:c Or) 1 0) (:subst (:mv {state1g0}) (:fv {fvP}) (:fv {fvQ}) 0))))")
let state1parent ← state1.withParentContext do
serializeExpressionSexp (← instantiateAll state1.parentExpr?.get!) (sanitize := false)
addTest $ LSpec.test "(1 parent)" (state1parent == s!"(:lambda p (:sort 0) (:lambda q (:sort 0) (:lambda h ((:c Or) 1 0) (:subst (:mv {state1g0}) 2 1 0))))")
let tactic := "cases h"
let state2 ← match ← state1.tryTactic (goalId := 0) (tactic := tactic) with
| .success state => pure state
@ -253,25 +275,31 @@ def test_or_comm: TestM Unit := do
return ()
addTest $ LSpec.check tactic ((← state2.serializeGoals (options := ← read)).map (·.devolatilize) =
#[branchGoal "inl" "p", branchGoal "inr" "q"])
let (caseL, caseR) := ("_uniq.62", "_uniq.75")
let (caseL, caseR) := ("_uniq.64", "_uniq.77")
addTest $ LSpec.check tactic ((← state2.serializeGoals (options := ← read)).map (·.name) =
#[caseL, caseR])
addTest $ LSpec.check "(2 parent)" state2.parentExpr?.isSome
addTest $ LSpec.check "(2 parent exists)" state2.parentExpr?.isSome
addTest $ LSpec.check "(2 root)" state2.rootExpr?.isNone
let state2parent ← serializeExpressionSexp (← instantiateAll state2.parentExpr?.get!) (sanitize := false)
let state2parent ← state2.withParentContext do
serializeExpressionSexp (← instantiateAll state2.parentExpr?.get!) (sanitize := false)
let orPQ := s!"((:c Or) (:fv {fvP}) (:fv {fvQ}))"
let orQP := s!"((:c Or) (:fv {fvQ}) (:fv {fvP}))"
let motive := s!"(:lambda t._@._hyg.26 {orPQ} (:forall h ((:c Eq) ((:c Or) (:fv {fvP}) (:fv {fvQ})) (:fv {fvH}) 0) {orQP}))"
let caseL := s!"(:lambda h._@._hyg.27 (:fv {fvP}) (:lambda h._@._hyg.28 ((:c Eq) {orPQ} (:fv {fvH}) ((:c Or.inl) (:fv {fvP}) (:fv {fvQ}) 0)) (:subst (:mv {caseL}) (:fv {fvP}) (:fv {fvQ}) 1)))"
let caseR := s!"(:lambda h._@._hyg.29 (:fv {fvQ}) (:lambda h._@._hyg.30 ((:c Eq) {orPQ} (:fv {fvH}) ((:c Or.inr) (:fv {fvP}) (:fv {fvQ}) 0)) (:subst (:mv {caseR}) (:fv {fvP}) (:fv {fvQ}) 1)))"
let conduit := s!"((:c Eq.refl) {orPQ} (:fv {fvH}))"
addTest $ LSpec.test "(2 parent)" (state2parent ==
s!"((:c Or.casesOn) (:fv {fvP}) (:fv {fvQ}) (:lambda t._@._hyg.26 {orPQ} (:forall h ((:c Eq) {orPQ} (:fv {fvH}) 0) {orQP})) (:fv {fvH}) (:lambda h._@._hyg.27 (:fv {fvP}) (:lambda h._@._hyg.28 ((:c Eq) {orPQ} (:fv {fvH}) ((:c Or.inl) (:fv {fvP}) (:fv {fvQ}) 0)) (:mv {caseL}))) (:lambda h._@._hyg.29 (:fv {fvQ}) (:lambda h._@._hyg.30 ((:c Eq) {orPQ} (:fv {fvH}) ((:c Or.inr) (:fv {fvP}) (:fv {fvQ}) 0)) (:mv {caseR}))) ((:c Eq.refl) {orPQ} (:fv {fvH})))")
s!"((:c Or.casesOn) (:fv {fvP}) (:fv {fvQ}) {motive} (:fv {fvH}) {caseL} {caseR} {conduit})")
let state3_1 ← match ← state2.tryTactic (goalId := 0) (tactic := "apply Or.inr") with
| .success state => pure state
| other => do
addTest $ assertUnreachable $ other.toString
return ()
let state3_1parent ← serializeExpressionSexp (← instantiateAll state3_1.parentExpr?.get!) (sanitize := false)
addTest $ LSpec.test "(3_1 parent)" (state3_1parent == s!"((:c Or.inr) (:fv {fvQ}) (:fv {fvP}) (:mv _uniq.87))")
let state3_1parent ← state3_1.withParentContext do
serializeExpressionSexp (← instantiateAll state3_1.parentExpr?.get!) (sanitize := false)
addTest $ LSpec.test "(3_1 parent)" (state3_1parent == s!"((:c Or.inr) (:fv {fvQ}) (:fv {fvP}) (:mv _uniq.91))")
addTest $ LSpec.check "· apply Or.inr" (state3_1.goals.length = 1)
let state4_1 ← match ← state3_1.tryTactic (goalId := 0) (tactic := "assumption") with
| .success state => pure state
@ -800,14 +828,16 @@ def test_nat_zero_add_alt: TestM Unit := do
let cNatAdd := "(:c HAdd.hAdd) (:c Nat) (:c Nat) (:c Nat) ((:c instHAdd) (:c Nat) (:c instAddNat))"
let cNat0 := "((:c OfNat.ofNat) (:c Nat) (:lit 0) ((:c instOfNatNat) (:lit 0)))"
let fvN := "_uniq.63"
let conduitRight := s!"((:c Eq) (:c Nat) ({cNatAdd} (:fv {fvN}) {cNat0}) (:fv {fvN}))"
let substOf (mv: String) := s!"(:subst (:mv {mv}) (:fv {fvN}) (:mv {major}))"
addTest $ LSpec.check "resume" ((← state2b.serializeGoals (options := { ← read with printExprAST := true })) =
#[
{
name := "_uniq.70",
userName? := .some "conduit",
target := {
pp? := .some "(?motive.a = ?motive.a) = (n + 0 = n)",
sexp? := .some s!"((:c Eq) (:sort 0) ((:c Eq) (:mv {eqT}) (:mv {eqL}) (:mv {eqR})) ((:c Eq) (:c Nat) ({cNatAdd} (:fv {fvN}) {cNat0}) (:fv {fvN})))",
pp? := .some "(?m.92 ?m.68 = ?m.94 ?m.68) = (n + 0 = n)",
sexp? := .some s!"((:c Eq) (:sort 0) ((:c Eq) {substOf eqT} {substOf eqL} {substOf eqR}) {conduitRight})",
},
vars := #[{
name := fvN,
@ -820,6 +850,7 @@ def test_nat_zero_add_alt: TestM Unit := do
def suite (env: Environment): List (String × IO LSpec.TestSeq) :=
let tests := [
("identity", test_identity),
("Nat.add_comm", test_nat_add_comm false),
("Nat.add_comm manual", test_nat_add_comm true),
("Nat.add_comm delta", test_delta_variable),