feat(goal): Branch unification #217
|
@ -185,6 +185,92 @@ protected def GoalState.getMVarEAssignment (goalState: GoalState) (mvarId: MVarI
|
|||
let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr)
|
||||
return expr
|
||||
|
||||
/-- Given states `dst`, `src`, and `src'`, where `dst` and `src'` are
|
||||
descendants of `src`, replay the differential `src' - src` in `dst`. Colliding
|
||||
metavariable and lemma names will be automatically renamed to ensure there is no
|
||||
collision. This implements branch unification. Unification might be impossible
|
||||
if conflicting assignments exist. -/
|
||||
@[export pantograph_goal_state_replay]
|
||||
protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM (Option GoalState) := do
|
||||
let srcNGen := src.coreState.ngen
|
||||
let srcNGen' := src'.coreState.ngen
|
||||
assert! srcNGen.namePrefix == srcNGen'.namePrefix
|
||||
assert! srcNGen.namePrefix == dst.coreState.ngen.namePrefix
|
||||
assert! src.mctx.depth == src'.mctx.depth
|
||||
assert! src.mctx.depth == dst.mctx.depth
|
||||
|
||||
let diffNGenIdx := dst.coreState.ngen.idx - srcNGen.idx
|
||||
--let diffMVarIds := src'.mctx.decls.map λ decl => decl.index ≥ src.mctx.mvarCounter
|
||||
let mapId : Name → Name
|
||||
| id@(.num pref n) =>
|
||||
if pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx then
|
||||
.num pref (n + diffNGenIdx)
|
||||
else
|
||||
id
|
||||
| id => id
|
||||
let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ
|
||||
| .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩
|
||||
| _ => pure .continue
|
||||
let rec mapLevel : Level → Level
|
||||
| .succ x => .succ (mapLevel x)
|
||||
| .max l1 l2 => .max (mapLevel l1) (mapLevel l2)
|
||||
| .imax l1 l2 => .imax (mapLevel l1) (mapLevel l2)
|
||||
| .mvar { name } => .mvar ⟨mapId name⟩
|
||||
| l => l
|
||||
let mapLocalDecl (ldecl : LocalDecl) : CoreM LocalDecl := do
|
||||
let ldecl := ldecl.setType (← mapExpr ldecl.type)
|
||||
if let .some value := ldecl.value? then
|
||||
return ldecl.setValue (← mapExpr value)
|
||||
else
|
||||
return ldecl
|
||||
|
||||
let { term := savedTerm@{ meta := savedMeta@{ core, meta := meta@{ mctx, .. } }, .. }, .. } := dst.savedState
|
||||
let mctx := {
|
||||
mctx with
|
||||
mvarCounter := mctx.mvarCounter + (src'.mctx.mvarCounter - src.mctx.mvarCounter),
|
||||
lDepth := src'.mctx.lDepth.foldl (init := mctx.lDepth) λ acc lmvarId@{ name } depth =>
|
||||
if src.mctx.lDepth.contains lmvarId then
|
||||
acc
|
||||
else
|
||||
acc.insert ⟨mapId name⟩ depth
|
||||
decls := ← src'.mctx.decls.foldlM (init := mctx.decls) λ acc _mvarId@{ name } decl => do
|
||||
if decl.index < src.mctx.mvarCounter then
|
||||
return acc
|
||||
let mvarId := ⟨mapId name⟩
|
||||
let decl := {
|
||||
decl with
|
||||
lctx := ← decl.lctx.foldlM (init := .empty) λ acc decl => do
|
||||
let decl ← mapLocalDecl decl
|
||||
return acc.addDecl decl,
|
||||
type := ← mapExpr decl.type,
|
||||
}
|
||||
return acc.insert mvarId decl
|
||||
}
|
||||
let ngen := {
|
||||
core.ngen with
|
||||
idx := core.ngen.idx + (srcNGen'.idx - srcNGen.idx)
|
||||
}
|
||||
return .some {
|
||||
dst with
|
||||
savedState := {
|
||||
dst.savedState with
|
||||
term := {
|
||||
savedTerm with
|
||||
meta := {
|
||||
savedMeta with
|
||||
core := {
|
||||
core with
|
||||
ngen,
|
||||
}
|
||||
meta := {
|
||||
meta with
|
||||
mctx,
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
--- Tactic execution functions ---
|
||||
|
||||
-- Mimics `Elab.Term.logUnassignedUsingErrorInfos`
|
||||
|
|
|
@ -260,13 +260,21 @@ def test_partial_continuation: TestM Unit := do
|
|||
-- Continuation should fail if the state does not exist:
|
||||
match state0.resume coupled_goals with
|
||||
| .error error => addTest $ LSpec.check "(continuation failure message)" (error = "Goals [_uniq.44, _uniq.45, _uniq.42, _uniq.51] are not in scope")
|
||||
| .ok _ => addTest $ assertUnreachable "(continuation failure)"
|
||||
| .ok _ => fail "(continuation should fail)"
|
||||
-- Continuation should fail if some goals have not been solved
|
||||
match state2.continue state1 with
|
||||
| .error error => addTest $ LSpec.check "(continuation failure message)" (error = "Target state has unresolved goals")
|
||||
| .ok _ => addTest $ assertUnreachable "(continuation failure)"
|
||||
| .ok _ => fail "(continuation should fail)"
|
||||
return ()
|
||||
|
||||
def test_branch_unification : TestM Unit := do
|
||||
let .ok expr ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p ∨ q))) .none | unreachable!
|
||||
Meta.forallTelescope expr $ λ _ rootTarget => do
|
||||
let state ← GoalState.create rootTarget
|
||||
let .success state1 _ ← state.tacticOn 0 "exact p" | unreachable!
|
||||
let .success state2 _ ← state.tacticOn 1 "apply Or.inl" | unreachable!
|
||||
let state' := state2.replay state state1
|
||||
return ()
|
||||
|
||||
def suite (env: Environment): List (String × IO LSpec.TestSeq) :=
|
||||
let tests := [
|
||||
|
@ -274,7 +282,8 @@ def suite (env: Environment): List (String × IO LSpec.TestSeq) :=
|
|||
("2 < 5", test_m_couple),
|
||||
("2 < 5", test_m_couple_simp),
|
||||
("Proposition Generation", test_proposition_generation),
|
||||
("Partial Continuation", test_partial_continuation)
|
||||
("Partial Continuation", test_partial_continuation),
|
||||
("Branch Unification", test_branch_unification),
|
||||
]
|
||||
tests.map (fun (name, test) => (name, proofRunner env test))
|
||||
|
||||
|
|
Loading…
Reference in New Issue