feat: Branch unification (modulo conv/calc)

This commit is contained in:
Leni Aniva 2025-06-23 13:09:08 -07:00
parent cb7c4d2723
commit 05a8e3b13c
Signed by: aniva
GPG Key ID: D5F96287843E8DFB
2 changed files with 116 additions and 29 deletions

View File

@ -185,13 +185,17 @@ protected def GoalState.getMVarEAssignment (goalState: GoalState) (mvarId: MVarI
let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr)
return expr
deriving instance BEq for DelayedMetavarAssignment
/-- Given states `dst`, `src`, and `src'`, where `dst` and `src'` are
descendants of `src`, replay the differential `src' - src` in `dst`. Colliding
metavariable and lemma names will be automatically renamed to ensure there is no
collision. This implements branch unification. Unification might be impossible
if conflicting assignments exist. -/
if conflicting assignments exist. We also assume the monotonicity property: In a
chain of descending goal states, a mvar cannot be unassigned, and once assigned
its assignment cannot change. -/
@[export pantograph_goal_state_replay_m]
protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM (Option GoalState) := do
protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM GoalState := do
let srcNGen := src.coreState.ngen
let srcNGen' := src'.coreState.ngen
assert! srcNGen.namePrefix == srcNGen'.namePrefix
@ -200,23 +204,33 @@ protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM
assert! src.mctx.depth == dst.mctx.depth
let diffNGenIdx := dst.coreState.ngen.idx - srcNGen.idx
--let diffMVarIds := src'.mctx.decls.map λ decl => decl.index ≥ src.mctx.mvarCounter
-- True if the name is generated after `src`
let isNewName : Name → Bool
| .num pref n =>
pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx
| _ => false
let mapId : Name → Name
| id@(.num pref n) =>
if pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx then
if isNewName id then
.num pref (n + diffNGenIdx)
else
id
| id => id
let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ
| .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩
| _ => pure .continue
let rec mapLevel : Level → Level
| .succ x => .succ (mapLevel x)
| .max l1 l2 => .max (mapLevel l1) (mapLevel l2)
| .imax l1 l2 => .imax (mapLevel l1) (mapLevel l2)
| .mvar { name } => .mvar ⟨mapId name⟩
| l => l
let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ
| .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩
| _ => pure .continue
let mapDelayedAssignment (d : DelayedMetavarAssignment) : CoreM DelayedMetavarAssignment := do
let { mvarIdPending, fvars } := d
return {
mvarIdPending := ⟨mapId mvarIdPending.name⟩,
fvars := ← fvars.mapM mapExpr,
}
let mapLocalDecl (ldecl : LocalDecl) : CoreM LocalDecl := do
let ldecl := ldecl.setType (← mapExpr ldecl.type)
if let .some value := ldecl.value? then
@ -245,18 +259,44 @@ protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM
type := ← mapExpr decl.type,
}
return acc.insert mvarId decl
-- Merge mvar assignments
userNames := src'.mctx.userNames.foldl (init := mctx.userNames) λ acc userName mvarId =>
if acc.contains userName then
acc
else
acc.insert userName mvarId,
lAssignment := src'.mctx.lAssignment.foldl (init := mctx.lAssignment) λ acc lmvarId' l =>
let lmvarId := ⟨mapId lmvarId'.name⟩
if mctx.lAssignment.contains lmvarId then
-- Skip the intersecting assignments for now
acc
else
let l := mapLevel l
acc.insert lmvarId l,
eAssignment := ← src'.mctx.eAssignment.foldlM (init := mctx.eAssignment) λ acc mvarId' e => do
let mvarId := ⟨mapId mvarId'.name⟩
if mctx.eAssignment.contains mvarId then
-- Skip the intersecting assignments for now
return acc
else
let e ← mapExpr e
return acc.insert mvarId e,
dAssignment := ← src'.mctx.dAssignment.foldlM (init := mctx.dAssignment) λ acc mvarId' d => do
let mvarId := ⟨mapId mvarId'.name⟩
if mctx.dAssignment.contains mvarId then
return acc
else
let d ← mapDelayedAssignment d
return acc.insert mvarId d
}
let ngen := {
core.ngen with
idx := core.ngen.idx + (srcNGen'.idx - srcNGen.idx)
}
return .some {
dst with
savedState := {
dst.savedState with
term := {
savedTerm with
meta := {
-- Merge conflicting lmvar and mvar assignments using `isDefEq`
let savedMeta := {
savedMeta with
core := {
core with
@ -267,6 +307,47 @@ protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM
mctx,
}
}
let m : MetaM Meta.SavedState := Meta.withMCtx mctx do
savedMeta.restore
for (lmvarId, l') in src'.mctx.lAssignment do
if isNewName lmvarId.name then
continue
let .some l ← getLevelMVarAssignment? lmvarId | continue
let l' := mapLevel l'
unless ← Meta.isLevelDefEq l l' do
throwError "Conflicting assignment of level metavariable {lmvarId.name}"
for (mvarId, e') in src'.mctx.eAssignment do
if isNewName mvarId.name then
continue
if ← mvarId.isDelayedAssigned then
throwError "Conflicting assignment of expr metavariable (e != d) {mvarId.name}"
let .some e ← getExprMVarAssignment? mvarId | continue
let e' ← mapExpr e'
unless ← Meta.isDefEq e e' do
throwError "Conflicting assignment of expr metavariable (e != e) {mvarId.name}"
for (mvarId, d') in src'.mctx.dAssignment do
if isNewName mvarId.name then
continue
if ← mvarId.isAssigned then
throwError "Conflicting assignment of expr metavariable (d != e) {mvarId.name}"
let .some d ← getDelayedMVarAssignment? mvarId | continue
unless d == d' do
throwError "Conflicting assignment of expr metavariable (d != d) {mvarId.name}"
Meta.saveState
-- FIXME: Handle calc goals
let goals :=dst.savedState.tactic.goals ++
src'.savedState.tactic.goals.map (⟨mapId ·.name⟩)
return {
dst with
savedState := {
tactic := {
goals
},
term := {
savedTerm with
meta := ← m.run',
},
},
}

View File

@ -253,8 +253,8 @@ def test_partial_continuation: TestM Unit := do
addTest $ assertUnreachable $ msg
return ()
| .ok state => pure state
addTest $ LSpec.check "(continue 2)" ((← state1b.serializeGoals (options := ← read)).map (·.target.pp?) =
#[.some "2 ≤ Nat.succ ?m", .some "Nat.succ ?m ≤ 5", .some "Nat"])
checkEq "(continue 2)" ((← state1b.serializeGoals (options := ← read)).map (·.target.pp?))
#[.some "2 ≤ Nat.succ ?m", .some "Nat.succ ?m ≤ 5", .some "Nat"]
checkTrue "(2 root)" state1b.rootExpr?.get!.hasExprMVar
-- Continuation should fail if the state does not exist:
@ -268,12 +268,18 @@ def test_partial_continuation: TestM Unit := do
return ()
def test_branch_unification : TestM Unit := do
let .ok expr ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p q))) .none | unreachable!
Meta.forallTelescope expr $ λ _ rootTarget => do
let .ok rootTarget ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p q))) .none | unreachable!
let state ← GoalState.create rootTarget
let .success state1 _ ← state.tacticOn 0 "exact p" | unreachable!
let .success state2 _ ← state.tacticOn 1 "apply Or.inl" | unreachable!
let state' := state2.replay state state1
let .success state _ ← state.tacticOn' 0 (← `(tactic|intro p q h)) | fail "intro failed to run"
let .success state _ ← state.tacticOn' 0 (← `(tactic|apply And.intro)) | fail "apply And.intro failed to run"
let .success state1 _ ← state.tacticOn' 0 (← `(tactic|exact h)) | fail "exact h failed to run"
let .success state2 _ ← state.tacticOn' 1 (← `(tactic|apply Or.inl)) | fail "apply Or.inl failed to run"
assert! state2.goals.length == 1
let state' ← state2.replay state state1
assert! state'.goals.length == 1
let .success stateT _ ← state'.tacticOn' 0 (← `(tactic|exact h)) | fail "exact h failed to run"
let .some root := stateT.rootExpr? | fail "Root expression must exist"
checkEq "(root)" (toString $ ← Meta.ppExpr root) "fun p q h => ⟨h, Or.inl h⟩"
return ()
def suite (env: Environment): List (String × IO LSpec.TestSeq) :=