feat(goal): Branch unification #217
|
@ -185,13 +185,17 @@ protected def GoalState.getMVarEAssignment (goalState: GoalState) (mvarId: MVarI
|
||||||
let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr)
|
let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr)
|
||||||
return expr
|
return expr
|
||||||
|
|
||||||
|
deriving instance BEq for DelayedMetavarAssignment
|
||||||
|
|
||||||
/-- Given states `dst`, `src`, and `src'`, where `dst` and `src'` are
|
/-- Given states `dst`, `src`, and `src'`, where `dst` and `src'` are
|
||||||
descendants of `src`, replay the differential `src' - src` in `dst`. Colliding
|
descendants of `src`, replay the differential `src' - src` in `dst`. Colliding
|
||||||
metavariable and lemma names will be automatically renamed to ensure there is no
|
metavariable and lemma names will be automatically renamed to ensure there is no
|
||||||
collision. This implements branch unification. Unification might be impossible
|
collision. This implements branch unification. Unification might be impossible
|
||||||
if conflicting assignments exist. -/
|
if conflicting assignments exist. We also assume the monotonicity property: In a
|
||||||
|
chain of descending goal states, a mvar cannot be unassigned, and once assigned
|
||||||
|
its assignment cannot change. -/
|
||||||
@[export pantograph_goal_state_replay_m]
|
@[export pantograph_goal_state_replay_m]
|
||||||
protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM (Option GoalState) := do
|
protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM GoalState := do
|
||||||
let srcNGen := src.coreState.ngen
|
let srcNGen := src.coreState.ngen
|
||||||
let srcNGen' := src'.coreState.ngen
|
let srcNGen' := src'.coreState.ngen
|
||||||
assert! srcNGen.namePrefix == srcNGen'.namePrefix
|
assert! srcNGen.namePrefix == srcNGen'.namePrefix
|
||||||
|
@ -200,23 +204,33 @@ protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM
|
||||||
assert! src.mctx.depth == dst.mctx.depth
|
assert! src.mctx.depth == dst.mctx.depth
|
||||||
|
|
||||||
let diffNGenIdx := dst.coreState.ngen.idx - srcNGen.idx
|
let diffNGenIdx := dst.coreState.ngen.idx - srcNGen.idx
|
||||||
--let diffMVarIds := src'.mctx.decls.map λ decl => decl.index ≥ src.mctx.mvarCounter
|
-- True if the name is generated after `src`
|
||||||
|
let isNewName : Name → Bool
|
||||||
|
| .num pref n =>
|
||||||
|
pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx
|
||||||
|
| _ => false
|
||||||
let mapId : Name → Name
|
let mapId : Name → Name
|
||||||
| id@(.num pref n) =>
|
| id@(.num pref n) =>
|
||||||
if pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx then
|
if isNewName id then
|
||||||
.num pref (n + diffNGenIdx)
|
.num pref (n + diffNGenIdx)
|
||||||
else
|
else
|
||||||
id
|
id
|
||||||
| id => id
|
| id => id
|
||||||
let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ
|
|
||||||
| .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩
|
|
||||||
| _ => pure .continue
|
|
||||||
let rec mapLevel : Level → Level
|
let rec mapLevel : Level → Level
|
||||||
| .succ x => .succ (mapLevel x)
|
| .succ x => .succ (mapLevel x)
|
||||||
| .max l1 l2 => .max (mapLevel l1) (mapLevel l2)
|
| .max l1 l2 => .max (mapLevel l1) (mapLevel l2)
|
||||||
| .imax l1 l2 => .imax (mapLevel l1) (mapLevel l2)
|
| .imax l1 l2 => .imax (mapLevel l1) (mapLevel l2)
|
||||||
| .mvar { name } => .mvar ⟨mapId name⟩
|
| .mvar { name } => .mvar ⟨mapId name⟩
|
||||||
| l => l
|
| l => l
|
||||||
|
let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ
|
||||||
|
| .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩
|
||||||
|
| _ => pure .continue
|
||||||
|
let mapDelayedAssignment (d : DelayedMetavarAssignment) : CoreM DelayedMetavarAssignment := do
|
||||||
|
let { mvarIdPending, fvars } := d
|
||||||
|
return {
|
||||||
|
mvarIdPending := ⟨mapId mvarIdPending.name⟩,
|
||||||
|
fvars := ← fvars.mapM mapExpr,
|
||||||
|
}
|
||||||
let mapLocalDecl (ldecl : LocalDecl) : CoreM LocalDecl := do
|
let mapLocalDecl (ldecl : LocalDecl) : CoreM LocalDecl := do
|
||||||
let ldecl := ldecl.setType (← mapExpr ldecl.type)
|
let ldecl := ldecl.setType (← mapExpr ldecl.type)
|
||||||
if let .some value := ldecl.value? then
|
if let .some value := ldecl.value? then
|
||||||
|
@ -245,18 +259,44 @@ protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM
|
||||||
type := ← mapExpr decl.type,
|
type := ← mapExpr decl.type,
|
||||||
}
|
}
|
||||||
return acc.insert mvarId decl
|
return acc.insert mvarId decl
|
||||||
|
|
||||||
|
-- Merge mvar assignments
|
||||||
|
userNames := src'.mctx.userNames.foldl (init := mctx.userNames) λ acc userName mvarId =>
|
||||||
|
if acc.contains userName then
|
||||||
|
acc
|
||||||
|
else
|
||||||
|
acc.insert userName mvarId,
|
||||||
|
lAssignment := src'.mctx.lAssignment.foldl (init := mctx.lAssignment) λ acc lmvarId' l =>
|
||||||
|
let lmvarId := ⟨mapId lmvarId'.name⟩
|
||||||
|
if mctx.lAssignment.contains lmvarId then
|
||||||
|
-- Skip the intersecting assignments for now
|
||||||
|
acc
|
||||||
|
else
|
||||||
|
let l := mapLevel l
|
||||||
|
acc.insert lmvarId l,
|
||||||
|
eAssignment := ← src'.mctx.eAssignment.foldlM (init := mctx.eAssignment) λ acc mvarId' e => do
|
||||||
|
let mvarId := ⟨mapId mvarId'.name⟩
|
||||||
|
if mctx.eAssignment.contains mvarId then
|
||||||
|
-- Skip the intersecting assignments for now
|
||||||
|
return acc
|
||||||
|
else
|
||||||
|
let e ← mapExpr e
|
||||||
|
return acc.insert mvarId e,
|
||||||
|
dAssignment := ← src'.mctx.dAssignment.foldlM (init := mctx.dAssignment) λ acc mvarId' d => do
|
||||||
|
let mvarId := ⟨mapId mvarId'.name⟩
|
||||||
|
if mctx.dAssignment.contains mvarId then
|
||||||
|
return acc
|
||||||
|
else
|
||||||
|
let d ← mapDelayedAssignment d
|
||||||
|
return acc.insert mvarId d
|
||||||
}
|
}
|
||||||
let ngen := {
|
let ngen := {
|
||||||
core.ngen with
|
core.ngen with
|
||||||
idx := core.ngen.idx + (srcNGen'.idx - srcNGen.idx)
|
idx := core.ngen.idx + (srcNGen'.idx - srcNGen.idx)
|
||||||
}
|
}
|
||||||
return .some {
|
-- Merge conflicting lmvar and mvar assignments using `isDefEq`
|
||||||
dst with
|
|
||||||
savedState := {
|
let savedMeta := {
|
||||||
dst.savedState with
|
|
||||||
term := {
|
|
||||||
savedTerm with
|
|
||||||
meta := {
|
|
||||||
savedMeta with
|
savedMeta with
|
||||||
core := {
|
core := {
|
||||||
core with
|
core with
|
||||||
|
@ -267,6 +307,47 @@ protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM
|
||||||
mctx,
|
mctx,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
let m : MetaM Meta.SavedState := Meta.withMCtx mctx do
|
||||||
|
savedMeta.restore
|
||||||
|
|
||||||
|
for (lmvarId, l') in src'.mctx.lAssignment do
|
||||||
|
if isNewName lmvarId.name then
|
||||||
|
continue
|
||||||
|
let .some l ← getLevelMVarAssignment? lmvarId | continue
|
||||||
|
let l' := mapLevel l'
|
||||||
|
unless ← Meta.isLevelDefEq l l' do
|
||||||
|
throwError "Conflicting assignment of level metavariable {lmvarId.name}"
|
||||||
|
for (mvarId, e') in src'.mctx.eAssignment do
|
||||||
|
if isNewName mvarId.name then
|
||||||
|
continue
|
||||||
|
if ← mvarId.isDelayedAssigned then
|
||||||
|
throwError "Conflicting assignment of expr metavariable (e != d) {mvarId.name}"
|
||||||
|
let .some e ← getExprMVarAssignment? mvarId | continue
|
||||||
|
let e' ← mapExpr e'
|
||||||
|
unless ← Meta.isDefEq e e' do
|
||||||
|
throwError "Conflicting assignment of expr metavariable (e != e) {mvarId.name}"
|
||||||
|
for (mvarId, d') in src'.mctx.dAssignment do
|
||||||
|
if isNewName mvarId.name then
|
||||||
|
continue
|
||||||
|
if ← mvarId.isAssigned then
|
||||||
|
throwError "Conflicting assignment of expr metavariable (d != e) {mvarId.name}"
|
||||||
|
let .some d ← getDelayedMVarAssignment? mvarId | continue
|
||||||
|
unless d == d' do
|
||||||
|
throwError "Conflicting assignment of expr metavariable (d != d) {mvarId.name}"
|
||||||
|
|
||||||
|
Meta.saveState
|
||||||
|
-- FIXME: Handle calc goals
|
||||||
|
let goals :=dst.savedState.tactic.goals ++
|
||||||
|
src'.savedState.tactic.goals.map (⟨mapId ·.name⟩)
|
||||||
|
return {
|
||||||
|
dst with
|
||||||
|
savedState := {
|
||||||
|
tactic := {
|
||||||
|
goals
|
||||||
|
},
|
||||||
|
term := {
|
||||||
|
savedTerm with
|
||||||
|
meta := ← m.run',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -253,8 +253,8 @@ def test_partial_continuation: TestM Unit := do
|
||||||
addTest $ assertUnreachable $ msg
|
addTest $ assertUnreachable $ msg
|
||||||
return ()
|
return ()
|
||||||
| .ok state => pure state
|
| .ok state => pure state
|
||||||
addTest $ LSpec.check "(continue 2)" ((← state1b.serializeGoals (options := ← read)).map (·.target.pp?) =
|
checkEq "(continue 2)" ((← state1b.serializeGoals (options := ← read)).map (·.target.pp?))
|
||||||
#[.some "2 ≤ Nat.succ ?m", .some "Nat.succ ?m ≤ 5", .some "Nat"])
|
#[.some "2 ≤ Nat.succ ?m", .some "Nat.succ ?m ≤ 5", .some "Nat"]
|
||||||
checkTrue "(2 root)" state1b.rootExpr?.get!.hasExprMVar
|
checkTrue "(2 root)" state1b.rootExpr?.get!.hasExprMVar
|
||||||
|
|
||||||
-- Continuation should fail if the state does not exist:
|
-- Continuation should fail if the state does not exist:
|
||||||
|
@ -268,12 +268,18 @@ def test_partial_continuation: TestM Unit := do
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
def test_branch_unification : TestM Unit := do
|
def test_branch_unification : TestM Unit := do
|
||||||
let .ok expr ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p ∨ q))) .none | unreachable!
|
let .ok rootTarget ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p ∨ q))) .none | unreachable!
|
||||||
Meta.forallTelescope expr $ λ _ rootTarget => do
|
|
||||||
let state ← GoalState.create rootTarget
|
let state ← GoalState.create rootTarget
|
||||||
let .success state1 _ ← state.tacticOn 0 "exact p" | unreachable!
|
let .success state _ ← state.tacticOn' 0 (← `(tactic|intro p q h)) | fail "intro failed to run"
|
||||||
let .success state2 _ ← state.tacticOn 1 "apply Or.inl" | unreachable!
|
let .success state _ ← state.tacticOn' 0 (← `(tactic|apply And.intro)) | fail "apply And.intro failed to run"
|
||||||
let state' := state2.replay state state1
|
let .success state1 _ ← state.tacticOn' 0 (← `(tactic|exact h)) | fail "exact h failed to run"
|
||||||
|
let .success state2 _ ← state.tacticOn' 1 (← `(tactic|apply Or.inl)) | fail "apply Or.inl failed to run"
|
||||||
|
assert! state2.goals.length == 1
|
||||||
|
let state' ← state2.replay state state1
|
||||||
|
assert! state'.goals.length == 1
|
||||||
|
let .success stateT _ ← state'.tacticOn' 0 (← `(tactic|exact h)) | fail "exact h failed to run"
|
||||||
|
let .some root := stateT.rootExpr? | fail "Root expression must exist"
|
||||||
|
checkEq "(root)" (toString $ ← Meta.ppExpr root) "fun p q h => ⟨h, Or.inl h⟩"
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
def suite (env: Environment): List (String × IO LSpec.TestSeq) :=
|
def suite (env: Environment): List (String × IO LSpec.TestSeq) :=
|
||||||
|
|
Loading…
Reference in New Issue