diff --git a/Pantograph/Goal.lean b/Pantograph/Goal.lean index d2ba412..e20deb3 100644 --- a/Pantograph/Goal.lean +++ b/Pantograph/Goal.lean @@ -185,13 +185,17 @@ protected def GoalState.getMVarEAssignment (goalState: GoalState) (mvarId: MVarI let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr) return expr +deriving instance BEq for DelayedMetavarAssignment + /-- Given states `dst`, `src`, and `src'`, where `dst` and `src'` are descendants of `src`, replay the differential `src' - src` in `dst`. Colliding metavariable and lemma names will be automatically renamed to ensure there is no collision. This implements branch unification. Unification might be impossible -if conflicting assignments exist. -/ +if conflicting assignments exist. We also assume the monotonicity property: In a +chain of descending goal states, a mvar cannot be unassigned, and once assigned +its assignment cannot change. -/ @[export pantograph_goal_state_replay_m] -protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM (Option GoalState) := do +protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM GoalState := do let srcNGen := src.coreState.ngen let srcNGen' := src'.coreState.ngen assert! srcNGen.namePrefix == srcNGen'.namePrefix @@ -200,23 +204,33 @@ protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM assert! src.mctx.depth == dst.mctx.depth let diffNGenIdx := dst.coreState.ngen.idx - srcNGen.idx - --let diffMVarIds := src'.mctx.decls.map λ decl => decl.index ≥ src.mctx.mvarCounter + -- True if the name is generated after `src` + let isNewName : Name → Bool + | .num pref n => + pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx + | _ => false let mapId : Name → Name | id@(.num pref n) => - if pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx then + if isNewName id then .num pref (n + diffNGenIdx) else id | id => id - let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ - | .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩ - | _ => pure .continue let rec mapLevel : Level → Level | .succ x => .succ (mapLevel x) | .max l1 l2 => .max (mapLevel l1) (mapLevel l2) | .imax l1 l2 => .imax (mapLevel l1) (mapLevel l2) | .mvar { name } => .mvar ⟨mapId name⟩ | l => l + let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ + | .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩ + | _ => pure .continue + let mapDelayedAssignment (d : DelayedMetavarAssignment) : CoreM DelayedMetavarAssignment := do + let { mvarIdPending, fvars } := d + return { + mvarIdPending := ⟨mapId mvarIdPending.name⟩, + fvars := ← fvars.mapM mapExpr, + } let mapLocalDecl (ldecl : LocalDecl) : CoreM LocalDecl := do let ldecl := ldecl.setType (← mapExpr ldecl.type) if let .some value := ldecl.value? then @@ -245,28 +259,95 @@ protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM type := ← mapExpr decl.type, } return acc.insert mvarId decl + + -- Merge mvar assignments + userNames := src'.mctx.userNames.foldl (init := mctx.userNames) λ acc userName mvarId => + if acc.contains userName then + acc + else + acc.insert userName mvarId, + lAssignment := src'.mctx.lAssignment.foldl (init := mctx.lAssignment) λ acc lmvarId' l => + let lmvarId := ⟨mapId lmvarId'.name⟩ + if mctx.lAssignment.contains lmvarId then + -- Skip the intersecting assignments for now + acc + else + let l := mapLevel l + acc.insert lmvarId l, + eAssignment := ← src'.mctx.eAssignment.foldlM (init := mctx.eAssignment) λ acc mvarId' e => do + let mvarId := ⟨mapId mvarId'.name⟩ + if mctx.eAssignment.contains mvarId then + -- Skip the intersecting assignments for now + return acc + else + let e ← mapExpr e + return acc.insert mvarId e, + dAssignment := ← src'.mctx.dAssignment.foldlM (init := mctx.dAssignment) λ acc mvarId' d => do + let mvarId := ⟨mapId mvarId'.name⟩ + if mctx.dAssignment.contains mvarId then + return acc + else + let d ← mapDelayedAssignment d + return acc.insert mvarId d } let ngen := { core.ngen with idx := core.ngen.idx + (srcNGen'.idx - srcNGen.idx) } - return .some { + -- Merge conflicting lmvar and mvar assignments using `isDefEq` + + let savedMeta := { + savedMeta with + core := { + core with + ngen, + } + meta := { + meta with + mctx, + } + } + let m : MetaM Meta.SavedState := Meta.withMCtx mctx do + savedMeta.restore + + for (lmvarId, l') in src'.mctx.lAssignment do + if isNewName lmvarId.name then + continue + let .some l ← getLevelMVarAssignment? lmvarId | continue + let l' := mapLevel l' + unless ← Meta.isLevelDefEq l l' do + throwError "Conflicting assignment of level metavariable {lmvarId.name}" + for (mvarId, e') in src'.mctx.eAssignment do + if isNewName mvarId.name then + continue + if ← mvarId.isDelayedAssigned then + throwError "Conflicting assignment of expr metavariable (e != d) {mvarId.name}" + let .some e ← getExprMVarAssignment? mvarId | continue + let e' ← mapExpr e' + unless ← Meta.isDefEq e e' do + throwError "Conflicting assignment of expr metavariable (e != e) {mvarId.name}" + for (mvarId, d') in src'.mctx.dAssignment do + if isNewName mvarId.name then + continue + if ← mvarId.isAssigned then + throwError "Conflicting assignment of expr metavariable (d != e) {mvarId.name}" + let .some d ← getDelayedMVarAssignment? mvarId | continue + unless d == d' do + throwError "Conflicting assignment of expr metavariable (d != d) {mvarId.name}" + + Meta.saveState + -- FIXME: Handle calc goals + let goals :=dst.savedState.tactic.goals ++ + src'.savedState.tactic.goals.map (⟨mapId ·.name⟩) + return { dst with savedState := { - dst.savedState with + tactic := { + goals + }, term := { savedTerm with - meta := { - savedMeta with - core := { - core with - ngen, - } - meta := { - meta with - mctx, - } - } + meta := ← m.run', }, }, } diff --git a/Test/Metavar.lean b/Test/Metavar.lean index 00621e6..c2259ed 100644 --- a/Test/Metavar.lean +++ b/Test/Metavar.lean @@ -253,8 +253,8 @@ def test_partial_continuation: TestM Unit := do addTest $ assertUnreachable $ msg return () | .ok state => pure state - addTest $ LSpec.check "(continue 2)" ((← state1b.serializeGoals (options := ← read)).map (·.target.pp?) = - #[.some "2 ≤ Nat.succ ?m", .some "Nat.succ ?m ≤ 5", .some "Nat"]) + checkEq "(continue 2)" ((← state1b.serializeGoals (options := ← read)).map (·.target.pp?)) + #[.some "2 ≤ Nat.succ ?m", .some "Nat.succ ?m ≤ 5", .some "Nat"] checkTrue "(2 root)" state1b.rootExpr?.get!.hasExprMVar -- Continuation should fail if the state does not exist: @@ -268,13 +268,19 @@ def test_partial_continuation: TestM Unit := do return () def test_branch_unification : TestM Unit := do - let .ok expr ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p ∨ q))) .none | unreachable! - Meta.forallTelescope expr $ λ _ rootTarget => do - let state ← GoalState.create rootTarget - let .success state1 _ ← state.tacticOn 0 "exact p" | unreachable! - let .success state2 _ ← state.tacticOn 1 "apply Or.inl" | unreachable! - let state' := state2.replay state state1 - return () + let .ok rootTarget ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p ∨ q))) .none | unreachable! + let state ← GoalState.create rootTarget + let .success state _ ← state.tacticOn' 0 (← `(tactic|intro p q h)) | fail "intro failed to run" + let .success state _ ← state.tacticOn' 0 (← `(tactic|apply And.intro)) | fail "apply And.intro failed to run" + let .success state1 _ ← state.tacticOn' 0 (← `(tactic|exact h)) | fail "exact h failed to run" + let .success state2 _ ← state.tacticOn' 1 (← `(tactic|apply Or.inl)) | fail "apply Or.inl failed to run" + assert! state2.goals.length == 1 + let state' ← state2.replay state state1 + assert! state'.goals.length == 1 + let .success stateT _ ← state'.tacticOn' 0 (← `(tactic|exact h)) | fail "exact h failed to run" + let .some root := stateT.rootExpr? | fail "Root expression must exist" + checkEq "(root)" (toString $ ← Meta.ppExpr root) "fun p q h => ⟨h, Or.inl h⟩" + return () def suite (env: Environment): List (String × IO LSpec.TestSeq) := let tests := [