feat(goal): Branch unification #217
|
@ -185,6 +185,92 @@ protected def GoalState.getMVarEAssignment (goalState: GoalState) (mvarId: MVarI
|
||||||
let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr)
|
let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr)
|
||||||
return expr
|
return expr
|
||||||
|
|
||||||
|
/-- Given states `dst`, `src`, and `src'`, where `dst` and `src'` are
|
||||||
|
descendants of `src`, replay the differential `src' - src` in `dst`. Colliding
|
||||||
|
metavariable and lemma names will be automatically renamed to ensure there is no
|
||||||
|
collision. This implements branch unification. Unification might be impossible
|
||||||
|
if conflicting assignments exist. -/
|
||||||
|
@[export pantograph_goal_state_replay]
|
||||||
|
protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM (Option GoalState) := do
|
||||||
|
let srcNGen := src.coreState.ngen
|
||||||
|
let srcNGen' := src'.coreState.ngen
|
||||||
|
assert! srcNGen.namePrefix == srcNGen'.namePrefix
|
||||||
|
assert! srcNGen.namePrefix == dst.coreState.ngen.namePrefix
|
||||||
|
assert! src.mctx.depth == src'.mctx.depth
|
||||||
|
assert! src.mctx.depth == dst.mctx.depth
|
||||||
|
|
||||||
|
let diffNGenIdx := dst.coreState.ngen.idx - srcNGen.idx
|
||||||
|
--let diffMVarIds := src'.mctx.decls.map λ decl => decl.index ≥ src.mctx.mvarCounter
|
||||||
|
let mapId : Name → Name
|
||||||
|
| id@(.num pref n) =>
|
||||||
|
if pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx then
|
||||||
|
.num pref (n + diffNGenIdx)
|
||||||
|
else
|
||||||
|
id
|
||||||
|
| id => id
|
||||||
|
let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ
|
||||||
|
| .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩
|
||||||
|
| _ => pure .continue
|
||||||
|
let rec mapLevel : Level → Level
|
||||||
|
| .succ x => .succ (mapLevel x)
|
||||||
|
| .max l1 l2 => .max (mapLevel l1) (mapLevel l2)
|
||||||
|
| .imax l1 l2 => .imax (mapLevel l1) (mapLevel l2)
|
||||||
|
| .mvar { name } => .mvar ⟨mapId name⟩
|
||||||
|
| l => l
|
||||||
|
let mapLocalDecl (ldecl : LocalDecl) : CoreM LocalDecl := do
|
||||||
|
let ldecl := ldecl.setType (← mapExpr ldecl.type)
|
||||||
|
if let .some value := ldecl.value? then
|
||||||
|
return ldecl.setValue (← mapExpr value)
|
||||||
|
else
|
||||||
|
return ldecl
|
||||||
|
|
||||||
|
let { term := savedTerm@{ meta := savedMeta@{ core, meta := meta@{ mctx, .. } }, .. }, .. } := dst.savedState
|
||||||
|
let mctx := {
|
||||||
|
mctx with
|
||||||
|
mvarCounter := mctx.mvarCounter + (src'.mctx.mvarCounter - src.mctx.mvarCounter),
|
||||||
|
lDepth := src'.mctx.lDepth.foldl (init := mctx.lDepth) λ acc lmvarId@{ name } depth =>
|
||||||
|
if src.mctx.lDepth.contains lmvarId then
|
||||||
|
acc
|
||||||
|
else
|
||||||
|
acc.insert ⟨mapId name⟩ depth
|
||||||
|
decls := ← src'.mctx.decls.foldlM (init := mctx.decls) λ acc _mvarId@{ name } decl => do
|
||||||
|
if decl.index < src.mctx.mvarCounter then
|
||||||
|
return acc
|
||||||
|
let mvarId := ⟨mapId name⟩
|
||||||
|
let decl := {
|
||||||
|
decl with
|
||||||
|
lctx := ← decl.lctx.foldlM (init := .empty) λ acc decl => do
|
||||||
|
let decl ← mapLocalDecl decl
|
||||||
|
return acc.addDecl decl,
|
||||||
|
type := ← mapExpr decl.type,
|
||||||
|
}
|
||||||
|
return acc.insert mvarId decl
|
||||||
|
}
|
||||||
|
let ngen := {
|
||||||
|
core.ngen with
|
||||||
|
idx := core.ngen.idx + (srcNGen'.idx - srcNGen.idx)
|
||||||
|
}
|
||||||
|
return .some {
|
||||||
|
dst with
|
||||||
|
savedState := {
|
||||||
|
dst.savedState with
|
||||||
|
term := {
|
||||||
|
savedTerm with
|
||||||
|
meta := {
|
||||||
|
savedMeta with
|
||||||
|
core := {
|
||||||
|
core with
|
||||||
|
ngen,
|
||||||
|
}
|
||||||
|
meta := {
|
||||||
|
meta with
|
||||||
|
mctx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
--- Tactic execution functions ---
|
--- Tactic execution functions ---
|
||||||
|
|
||||||
-- Mimics `Elab.Term.logUnassignedUsingErrorInfos`
|
-- Mimics `Elab.Term.logUnassignedUsingErrorInfos`
|
||||||
|
|
|
@ -260,13 +260,21 @@ def test_partial_continuation: TestM Unit := do
|
||||||
-- Continuation should fail if the state does not exist:
|
-- Continuation should fail if the state does not exist:
|
||||||
match state0.resume coupled_goals with
|
match state0.resume coupled_goals with
|
||||||
| .error error => addTest $ LSpec.check "(continuation failure message)" (error = "Goals [_uniq.44, _uniq.45, _uniq.42, _uniq.51] are not in scope")
|
| .error error => addTest $ LSpec.check "(continuation failure message)" (error = "Goals [_uniq.44, _uniq.45, _uniq.42, _uniq.51] are not in scope")
|
||||||
| .ok _ => addTest $ assertUnreachable "(continuation failure)"
|
| .ok _ => fail "(continuation should fail)"
|
||||||
-- Continuation should fail if some goals have not been solved
|
-- Continuation should fail if some goals have not been solved
|
||||||
match state2.continue state1 with
|
match state2.continue state1 with
|
||||||
| .error error => addTest $ LSpec.check "(continuation failure message)" (error = "Target state has unresolved goals")
|
| .error error => addTest $ LSpec.check "(continuation failure message)" (error = "Target state has unresolved goals")
|
||||||
| .ok _ => addTest $ assertUnreachable "(continuation failure)"
|
| .ok _ => fail "(continuation should fail)"
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
|
def test_branch_unification : TestM Unit := do
|
||||||
|
let .ok expr ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p ∨ q))) .none | unreachable!
|
||||||
|
Meta.forallTelescope expr $ λ _ rootTarget => do
|
||||||
|
let state ← GoalState.create rootTarget
|
||||||
|
let .success state1 _ ← state.tacticOn 0 "exact p" | unreachable!
|
||||||
|
let .success state2 _ ← state.tacticOn 1 "apply Or.inl" | unreachable!
|
||||||
|
let state' := state2.replay state state1
|
||||||
|
return ()
|
||||||
|
|
||||||
def suite (env: Environment): List (String × IO LSpec.TestSeq) :=
|
def suite (env: Environment): List (String × IO LSpec.TestSeq) :=
|
||||||
let tests := [
|
let tests := [
|
||||||
|
@ -274,7 +282,8 @@ def suite (env: Environment): List (String × IO LSpec.TestSeq) :=
|
||||||
("2 < 5", test_m_couple),
|
("2 < 5", test_m_couple),
|
||||||
("2 < 5", test_m_couple_simp),
|
("2 < 5", test_m_couple_simp),
|
||||||
("Proposition Generation", test_proposition_generation),
|
("Proposition Generation", test_proposition_generation),
|
||||||
("Partial Continuation", test_partial_continuation)
|
("Partial Continuation", test_partial_continuation),
|
||||||
|
("Branch Unification", test_branch_unification),
|
||||||
]
|
]
|
||||||
tests.map (fun (name, test) => (name, proofRunner env test))
|
tests.map (fun (name, test) => (name, proofRunner env test))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue