diff --git a/Pantograph/Goal.lean b/Pantograph/Goal.lean index d0122f5..caf0267 100644 --- a/Pantograph/Goal.lean +++ b/Pantograph/Goal.lean @@ -185,6 +185,92 @@ protected def GoalState.getMVarEAssignment (goalState: GoalState) (mvarId: MVarI let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr) return expr +/-- Given states `dst`, `src`, and `src'`, where `dst` and `src'` are +descendants of `src`, replay the differential `src' - src` in `dst`. Colliding +metavariable and lemma names will be automatically renamed to ensure there is no +collision. This implements branch unification. Unification might be impossible +if conflicting assignments exist. -/ +@[export pantograph_goal_state_replay] +protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM (Option GoalState) := do + let srcNGen := src.coreState.ngen + let srcNGen' := src'.coreState.ngen + assert! srcNGen.namePrefix == srcNGen'.namePrefix + assert! srcNGen.namePrefix == dst.coreState.ngen.namePrefix + assert! src.mctx.depth == src'.mctx.depth + assert! src.mctx.depth == dst.mctx.depth + + let diffNGenIdx := dst.coreState.ngen.idx - srcNGen.idx + --let diffMVarIds := src'.mctx.decls.map λ decl => decl.index ≥ src.mctx.mvarCounter + let mapId : Name → Name + | id@(.num pref n) => + if pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx then + .num pref (n + diffNGenIdx) + else + id + | id => id + let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ + | .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩ + | _ => pure .continue + let rec mapLevel : Level → Level + | .succ x => .succ (mapLevel x) + | .max l1 l2 => .max (mapLevel l1) (mapLevel l2) + | .imax l1 l2 => .imax (mapLevel l1) (mapLevel l2) + | .mvar { name } => .mvar ⟨mapId name⟩ + | l => l + let mapLocalDecl (ldecl : LocalDecl) : CoreM LocalDecl := do + let ldecl := ldecl.setType (← mapExpr ldecl.type) + if let .some value := ldecl.value? then + return ldecl.setValue (← mapExpr value) + else + return ldecl + + let { term := savedTerm@{ meta := savedMeta@{ core, meta := meta@{ mctx, .. } }, .. }, .. } := dst.savedState + let mctx := { + mctx with + mvarCounter := mctx.mvarCounter + (src'.mctx.mvarCounter - src.mctx.mvarCounter), + lDepth := src'.mctx.lDepth.foldl (init := mctx.lDepth) λ acc lmvarId@{ name } depth => + if src.mctx.lDepth.contains lmvarId then + acc + else + acc.insert ⟨mapId name⟩ depth + decls := ← src'.mctx.decls.foldlM (init := mctx.decls) λ acc _mvarId@{ name } decl => do + if decl.index < src.mctx.mvarCounter then + return acc + let mvarId := ⟨mapId name⟩ + let decl := { + decl with + lctx := ← decl.lctx.foldlM (init := .empty) λ acc decl => do + let decl ← mapLocalDecl decl + return acc.addDecl decl, + type := ← mapExpr decl.type, + } + return acc.insert mvarId decl + } + let ngen := { + core.ngen with + idx := core.ngen.idx + (srcNGen'.idx - srcNGen.idx) + } + return .some { + dst with + savedState := { + dst.savedState with + term := { + savedTerm with + meta := { + savedMeta with + core := { + core with + ngen, + } + meta := { + meta with + mctx, + } + } + }, + }, + } + --- Tactic execution functions --- -- Mimics `Elab.Term.logUnassignedUsingErrorInfos` diff --git a/Test/Metavar.lean b/Test/Metavar.lean index ddd6e56..00621e6 100644 --- a/Test/Metavar.lean +++ b/Test/Metavar.lean @@ -260,13 +260,21 @@ def test_partial_continuation: TestM Unit := do -- Continuation should fail if the state does not exist: match state0.resume coupled_goals with | .error error => addTest $ LSpec.check "(continuation failure message)" (error = "Goals [_uniq.44, _uniq.45, _uniq.42, _uniq.51] are not in scope") - | .ok _ => addTest $ assertUnreachable "(continuation failure)" + | .ok _ => fail "(continuation should fail)" -- Continuation should fail if some goals have not been solved match state2.continue state1 with | .error error => addTest $ LSpec.check "(continuation failure message)" (error = "Target state has unresolved goals") - | .ok _ => addTest $ assertUnreachable "(continuation failure)" + | .ok _ => fail "(continuation should fail)" return () +def test_branch_unification : TestM Unit := do + let .ok expr ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p ∨ q))) .none | unreachable! + Meta.forallTelescope expr $ λ _ rootTarget => do + let state ← GoalState.create rootTarget + let .success state1 _ ← state.tacticOn 0 "exact p" | unreachable! + let .success state2 _ ← state.tacticOn 1 "apply Or.inl" | unreachable! + let state' := state2.replay state state1 + return () def suite (env: Environment): List (String × IO LSpec.TestSeq) := let tests := [ @@ -274,7 +282,8 @@ def suite (env: Environment): List (String × IO LSpec.TestSeq) := ("2 < 5", test_m_couple), ("2 < 5", test_m_couple_simp), ("Proposition Generation", test_proposition_generation), - ("Partial Continuation", test_partial_continuation) + ("Partial Continuation", test_partial_continuation), + ("Branch Unification", test_branch_unification), ] tests.map (fun (name, test) => (name, proofRunner env test))