feat: Branch unification (modulo conv/calc)
This commit is contained in:
parent
cb7c4d2723
commit
05a8e3b13c
|
@ -185,13 +185,17 @@ protected def GoalState.getMVarEAssignment (goalState: GoalState) (mvarId: MVarI
|
|||
let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr)
|
||||
return expr
|
||||
|
||||
deriving instance BEq for DelayedMetavarAssignment
|
||||
|
||||
/-- Given states `dst`, `src`, and `src'`, where `dst` and `src'` are
|
||||
descendants of `src`, replay the differential `src' - src` in `dst`. Colliding
|
||||
metavariable and lemma names will be automatically renamed to ensure there is no
|
||||
collision. This implements branch unification. Unification might be impossible
|
||||
if conflicting assignments exist. -/
|
||||
if conflicting assignments exist. We also assume the monotonicity property: In a
|
||||
chain of descending goal states, a mvar cannot be unassigned, and once assigned
|
||||
its assignment cannot change. -/
|
||||
@[export pantograph_goal_state_replay_m]
|
||||
protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM (Option GoalState) := do
|
||||
protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM GoalState := do
|
||||
let srcNGen := src.coreState.ngen
|
||||
let srcNGen' := src'.coreState.ngen
|
||||
assert! srcNGen.namePrefix == srcNGen'.namePrefix
|
||||
|
@ -200,23 +204,33 @@ protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM
|
|||
assert! src.mctx.depth == dst.mctx.depth
|
||||
|
||||
let diffNGenIdx := dst.coreState.ngen.idx - srcNGen.idx
|
||||
--let diffMVarIds := src'.mctx.decls.map λ decl => decl.index ≥ src.mctx.mvarCounter
|
||||
-- True if the name is generated after `src`
|
||||
let isNewName : Name → Bool
|
||||
| .num pref n =>
|
||||
pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx
|
||||
| _ => false
|
||||
let mapId : Name → Name
|
||||
| id@(.num pref n) =>
|
||||
if pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx then
|
||||
if isNewName id then
|
||||
.num pref (n + diffNGenIdx)
|
||||
else
|
||||
id
|
||||
| id => id
|
||||
let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ
|
||||
| .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩
|
||||
| _ => pure .continue
|
||||
let rec mapLevel : Level → Level
|
||||
| .succ x => .succ (mapLevel x)
|
||||
| .max l1 l2 => .max (mapLevel l1) (mapLevel l2)
|
||||
| .imax l1 l2 => .imax (mapLevel l1) (mapLevel l2)
|
||||
| .mvar { name } => .mvar ⟨mapId name⟩
|
||||
| l => l
|
||||
let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ
|
||||
| .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩
|
||||
| _ => pure .continue
|
||||
let mapDelayedAssignment (d : DelayedMetavarAssignment) : CoreM DelayedMetavarAssignment := do
|
||||
let { mvarIdPending, fvars } := d
|
||||
return {
|
||||
mvarIdPending := ⟨mapId mvarIdPending.name⟩,
|
||||
fvars := ← fvars.mapM mapExpr,
|
||||
}
|
||||
let mapLocalDecl (ldecl : LocalDecl) : CoreM LocalDecl := do
|
||||
let ldecl := ldecl.setType (← mapExpr ldecl.type)
|
||||
if let .some value := ldecl.value? then
|
||||
|
@ -245,28 +259,95 @@ protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM
|
|||
type := ← mapExpr decl.type,
|
||||
}
|
||||
return acc.insert mvarId decl
|
||||
|
||||
-- Merge mvar assignments
|
||||
userNames := src'.mctx.userNames.foldl (init := mctx.userNames) λ acc userName mvarId =>
|
||||
if acc.contains userName then
|
||||
acc
|
||||
else
|
||||
acc.insert userName mvarId,
|
||||
lAssignment := src'.mctx.lAssignment.foldl (init := mctx.lAssignment) λ acc lmvarId' l =>
|
||||
let lmvarId := ⟨mapId lmvarId'.name⟩
|
||||
if mctx.lAssignment.contains lmvarId then
|
||||
-- Skip the intersecting assignments for now
|
||||
acc
|
||||
else
|
||||
let l := mapLevel l
|
||||
acc.insert lmvarId l,
|
||||
eAssignment := ← src'.mctx.eAssignment.foldlM (init := mctx.eAssignment) λ acc mvarId' e => do
|
||||
let mvarId := ⟨mapId mvarId'.name⟩
|
||||
if mctx.eAssignment.contains mvarId then
|
||||
-- Skip the intersecting assignments for now
|
||||
return acc
|
||||
else
|
||||
let e ← mapExpr e
|
||||
return acc.insert mvarId e,
|
||||
dAssignment := ← src'.mctx.dAssignment.foldlM (init := mctx.dAssignment) λ acc mvarId' d => do
|
||||
let mvarId := ⟨mapId mvarId'.name⟩
|
||||
if mctx.dAssignment.contains mvarId then
|
||||
return acc
|
||||
else
|
||||
let d ← mapDelayedAssignment d
|
||||
return acc.insert mvarId d
|
||||
}
|
||||
let ngen := {
|
||||
core.ngen with
|
||||
idx := core.ngen.idx + (srcNGen'.idx - srcNGen.idx)
|
||||
}
|
||||
return .some {
|
||||
-- Merge conflicting lmvar and mvar assignments using `isDefEq`
|
||||
|
||||
let savedMeta := {
|
||||
savedMeta with
|
||||
core := {
|
||||
core with
|
||||
ngen,
|
||||
}
|
||||
meta := {
|
||||
meta with
|
||||
mctx,
|
||||
}
|
||||
}
|
||||
let m : MetaM Meta.SavedState := Meta.withMCtx mctx do
|
||||
savedMeta.restore
|
||||
|
||||
for (lmvarId, l') in src'.mctx.lAssignment do
|
||||
if isNewName lmvarId.name then
|
||||
continue
|
||||
let .some l ← getLevelMVarAssignment? lmvarId | continue
|
||||
let l' := mapLevel l'
|
||||
unless ← Meta.isLevelDefEq l l' do
|
||||
throwError "Conflicting assignment of level metavariable {lmvarId.name}"
|
||||
for (mvarId, e') in src'.mctx.eAssignment do
|
||||
if isNewName mvarId.name then
|
||||
continue
|
||||
if ← mvarId.isDelayedAssigned then
|
||||
throwError "Conflicting assignment of expr metavariable (e != d) {mvarId.name}"
|
||||
let .some e ← getExprMVarAssignment? mvarId | continue
|
||||
let e' ← mapExpr e'
|
||||
unless ← Meta.isDefEq e e' do
|
||||
throwError "Conflicting assignment of expr metavariable (e != e) {mvarId.name}"
|
||||
for (mvarId, d') in src'.mctx.dAssignment do
|
||||
if isNewName mvarId.name then
|
||||
continue
|
||||
if ← mvarId.isAssigned then
|
||||
throwError "Conflicting assignment of expr metavariable (d != e) {mvarId.name}"
|
||||
let .some d ← getDelayedMVarAssignment? mvarId | continue
|
||||
unless d == d' do
|
||||
throwError "Conflicting assignment of expr metavariable (d != d) {mvarId.name}"
|
||||
|
||||
Meta.saveState
|
||||
-- FIXME: Handle calc goals
|
||||
let goals :=dst.savedState.tactic.goals ++
|
||||
src'.savedState.tactic.goals.map (⟨mapId ·.name⟩)
|
||||
return {
|
||||
dst with
|
||||
savedState := {
|
||||
dst.savedState with
|
||||
tactic := {
|
||||
goals
|
||||
},
|
||||
term := {
|
||||
savedTerm with
|
||||
meta := {
|
||||
savedMeta with
|
||||
core := {
|
||||
core with
|
||||
ngen,
|
||||
}
|
||||
meta := {
|
||||
meta with
|
||||
mctx,
|
||||
}
|
||||
}
|
||||
meta := ← m.run',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -253,8 +253,8 @@ def test_partial_continuation: TestM Unit := do
|
|||
addTest $ assertUnreachable $ msg
|
||||
return ()
|
||||
| .ok state => pure state
|
||||
addTest $ LSpec.check "(continue 2)" ((← state1b.serializeGoals (options := ← read)).map (·.target.pp?) =
|
||||
#[.some "2 ≤ Nat.succ ?m", .some "Nat.succ ?m ≤ 5", .some "Nat"])
|
||||
checkEq "(continue 2)" ((← state1b.serializeGoals (options := ← read)).map (·.target.pp?))
|
||||
#[.some "2 ≤ Nat.succ ?m", .some "Nat.succ ?m ≤ 5", .some "Nat"]
|
||||
checkTrue "(2 root)" state1b.rootExpr?.get!.hasExprMVar
|
||||
|
||||
-- Continuation should fail if the state does not exist:
|
||||
|
@ -268,13 +268,19 @@ def test_partial_continuation: TestM Unit := do
|
|||
return ()
|
||||
|
||||
def test_branch_unification : TestM Unit := do
|
||||
let .ok expr ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p ∨ q))) .none | unreachable!
|
||||
Meta.forallTelescope expr $ λ _ rootTarget => do
|
||||
let state ← GoalState.create rootTarget
|
||||
let .success state1 _ ← state.tacticOn 0 "exact p" | unreachable!
|
||||
let .success state2 _ ← state.tacticOn 1 "apply Or.inl" | unreachable!
|
||||
let state' := state2.replay state state1
|
||||
return ()
|
||||
let .ok rootTarget ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p ∨ q))) .none | unreachable!
|
||||
let state ← GoalState.create rootTarget
|
||||
let .success state _ ← state.tacticOn' 0 (← `(tactic|intro p q h)) | fail "intro failed to run"
|
||||
let .success state _ ← state.tacticOn' 0 (← `(tactic|apply And.intro)) | fail "apply And.intro failed to run"
|
||||
let .success state1 _ ← state.tacticOn' 0 (← `(tactic|exact h)) | fail "exact h failed to run"
|
||||
let .success state2 _ ← state.tacticOn' 1 (← `(tactic|apply Or.inl)) | fail "apply Or.inl failed to run"
|
||||
assert! state2.goals.length == 1
|
||||
let state' ← state2.replay state state1
|
||||
assert! state'.goals.length == 1
|
||||
let .success stateT _ ← state'.tacticOn' 0 (← `(tactic|exact h)) | fail "exact h failed to run"
|
||||
let .some root := stateT.rootExpr? | fail "Root expression must exist"
|
||||
checkEq "(root)" (toString $ ← Meta.ppExpr root) "fun p q h => ⟨h, Or.inl h⟩"
|
||||
return ()
|
||||
|
||||
def suite (env: Environment): List (String × IO LSpec.TestSeq) :=
|
||||
let tests := [
|
||||
|
|
Loading…
Reference in New Issue