feat: Branch unification (modulo conv/calc)

This commit is contained in:
Leni Aniva 2025-06-23 13:09:08 -07:00
parent cb7c4d2723
commit 05a8e3b13c
Signed by: aniva
GPG Key ID: D5F96287843E8DFB
2 changed files with 116 additions and 29 deletions

View File

@ -185,13 +185,17 @@ protected def GoalState.getMVarEAssignment (goalState: GoalState) (mvarId: MVarI
let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr) let (expr, _) := instantiateMVarsCore (mctx := goalState.mctx) (e := expr)
return expr return expr
deriving instance BEq for DelayedMetavarAssignment
/-- Given states `dst`, `src`, and `src'`, where `dst` and `src'` are /-- Given states `dst`, `src`, and `src'`, where `dst` and `src'` are
descendants of `src`, replay the differential `src' - src` in `dst`. Colliding descendants of `src`, replay the differential `src' - src` in `dst`. Colliding
metavariable and lemma names will be automatically renamed to ensure there is no metavariable and lemma names will be automatically renamed to ensure there is no
collision. This implements branch unification. Unification might be impossible collision. This implements branch unification. Unification might be impossible
if conflicting assignments exist. -/ if conflicting assignments exist. We also assume the monotonicity property: In a
chain of descending goal states, a mvar cannot be unassigned, and once assigned
its assignment cannot change. -/
@[export pantograph_goal_state_replay_m] @[export pantograph_goal_state_replay_m]
protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM (Option GoalState) := do protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM GoalState := do
let srcNGen := src.coreState.ngen let srcNGen := src.coreState.ngen
let srcNGen' := src'.coreState.ngen let srcNGen' := src'.coreState.ngen
assert! srcNGen.namePrefix == srcNGen'.namePrefix assert! srcNGen.namePrefix == srcNGen'.namePrefix
@ -200,23 +204,33 @@ protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM
assert! src.mctx.depth == dst.mctx.depth assert! src.mctx.depth == dst.mctx.depth
let diffNGenIdx := dst.coreState.ngen.idx - srcNGen.idx let diffNGenIdx := dst.coreState.ngen.idx - srcNGen.idx
--let diffMVarIds := src'.mctx.decls.map λ decl => decl.index ≥ src.mctx.mvarCounter -- True if the name is generated after `src`
let isNewName : Name → Bool
| .num pref n =>
pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx
| _ => false
let mapId : Name → Name let mapId : Name → Name
| id@(.num pref n) => | id@(.num pref n) =>
if pref == srcNGen.namePrefix ∧ n ≥ srcNGen.idx then if isNewName id then
.num pref (n + diffNGenIdx) .num pref (n + diffNGenIdx)
else else
id id
| id => id | id => id
let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ
| .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩
| _ => pure .continue
let rec mapLevel : Level → Level let rec mapLevel : Level → Level
| .succ x => .succ (mapLevel x) | .succ x => .succ (mapLevel x)
| .max l1 l2 => .max (mapLevel l1) (mapLevel l2) | .max l1 l2 => .max (mapLevel l1) (mapLevel l2)
| .imax l1 l2 => .imax (mapLevel l1) (mapLevel l2) | .imax l1 l2 => .imax (mapLevel l1) (mapLevel l2)
| .mvar { name } => .mvar ⟨mapId name⟩ | .mvar { name } => .mvar ⟨mapId name⟩
| l => l | l => l
let mapExpr (e : Expr) : CoreM Expr := Core.transform e λ
| .mvar { name } => pure $ .done $ .mvar ⟨mapId name⟩
| _ => pure .continue
let mapDelayedAssignment (d : DelayedMetavarAssignment) : CoreM DelayedMetavarAssignment := do
let { mvarIdPending, fvars } := d
return {
mvarIdPending := ⟨mapId mvarIdPending.name⟩,
fvars := ← fvars.mapM mapExpr,
}
let mapLocalDecl (ldecl : LocalDecl) : CoreM LocalDecl := do let mapLocalDecl (ldecl : LocalDecl) : CoreM LocalDecl := do
let ldecl := ldecl.setType (← mapExpr ldecl.type) let ldecl := ldecl.setType (← mapExpr ldecl.type)
if let .some value := ldecl.value? then if let .some value := ldecl.value? then
@ -245,28 +259,95 @@ protected def GoalState.replay (dst : GoalState) (src src' : GoalState) : CoreM
type := ← mapExpr decl.type, type := ← mapExpr decl.type,
} }
return acc.insert mvarId decl return acc.insert mvarId decl
-- Merge mvar assignments
userNames := src'.mctx.userNames.foldl (init := mctx.userNames) λ acc userName mvarId =>
if acc.contains userName then
acc
else
acc.insert userName mvarId,
lAssignment := src'.mctx.lAssignment.foldl (init := mctx.lAssignment) λ acc lmvarId' l =>
let lmvarId := ⟨mapId lmvarId'.name⟩
if mctx.lAssignment.contains lmvarId then
-- Skip the intersecting assignments for now
acc
else
let l := mapLevel l
acc.insert lmvarId l,
eAssignment := ← src'.mctx.eAssignment.foldlM (init := mctx.eAssignment) λ acc mvarId' e => do
let mvarId := ⟨mapId mvarId'.name⟩
if mctx.eAssignment.contains mvarId then
-- Skip the intersecting assignments for now
return acc
else
let e ← mapExpr e
return acc.insert mvarId e,
dAssignment := ← src'.mctx.dAssignment.foldlM (init := mctx.dAssignment) λ acc mvarId' d => do
let mvarId := ⟨mapId mvarId'.name⟩
if mctx.dAssignment.contains mvarId then
return acc
else
let d ← mapDelayedAssignment d
return acc.insert mvarId d
} }
let ngen := { let ngen := {
core.ngen with core.ngen with
idx := core.ngen.idx + (srcNGen'.idx - srcNGen.idx) idx := core.ngen.idx + (srcNGen'.idx - srcNGen.idx)
} }
return .some { -- Merge conflicting lmvar and mvar assignments using `isDefEq`
let savedMeta := {
savedMeta with
core := {
core with
ngen,
}
meta := {
meta with
mctx,
}
}
let m : MetaM Meta.SavedState := Meta.withMCtx mctx do
savedMeta.restore
for (lmvarId, l') in src'.mctx.lAssignment do
if isNewName lmvarId.name then
continue
let .some l ← getLevelMVarAssignment? lmvarId | continue
let l' := mapLevel l'
unless ← Meta.isLevelDefEq l l' do
throwError "Conflicting assignment of level metavariable {lmvarId.name}"
for (mvarId, e') in src'.mctx.eAssignment do
if isNewName mvarId.name then
continue
if ← mvarId.isDelayedAssigned then
throwError "Conflicting assignment of expr metavariable (e != d) {mvarId.name}"
let .some e ← getExprMVarAssignment? mvarId | continue
let e' ← mapExpr e'
unless ← Meta.isDefEq e e' do
throwError "Conflicting assignment of expr metavariable (e != e) {mvarId.name}"
for (mvarId, d') in src'.mctx.dAssignment do
if isNewName mvarId.name then
continue
if ← mvarId.isAssigned then
throwError "Conflicting assignment of expr metavariable (d != e) {mvarId.name}"
let .some d ← getDelayedMVarAssignment? mvarId | continue
unless d == d' do
throwError "Conflicting assignment of expr metavariable (d != d) {mvarId.name}"
Meta.saveState
-- FIXME: Handle calc goals
let goals :=dst.savedState.tactic.goals ++
src'.savedState.tactic.goals.map (⟨mapId ·.name⟩)
return {
dst with dst with
savedState := { savedState := {
dst.savedState with tactic := {
goals
},
term := { term := {
savedTerm with savedTerm with
meta := { meta := ← m.run',
savedMeta with
core := {
core with
ngen,
}
meta := {
meta with
mctx,
}
}
}, },
}, },
} }

View File

@ -253,8 +253,8 @@ def test_partial_continuation: TestM Unit := do
addTest $ assertUnreachable $ msg addTest $ assertUnreachable $ msg
return () return ()
| .ok state => pure state | .ok state => pure state
addTest $ LSpec.check "(continue 2)" ((← state1b.serializeGoals (options := ← read)).map (·.target.pp?) = checkEq "(continue 2)" ((← state1b.serializeGoals (options := ← read)).map (·.target.pp?))
#[.some "2 ≤ Nat.succ ?m", .some "Nat.succ ?m ≤ 5", .some "Nat"]) #[.some "2 ≤ Nat.succ ?m", .some "Nat.succ ?m ≤ 5", .some "Nat"]
checkTrue "(2 root)" state1b.rootExpr?.get!.hasExprMVar checkTrue "(2 root)" state1b.rootExpr?.get!.hasExprMVar
-- Continuation should fail if the state does not exist: -- Continuation should fail if the state does not exist:
@ -268,13 +268,19 @@ def test_partial_continuation: TestM Unit := do
return () return ()
def test_branch_unification : TestM Unit := do def test_branch_unification : TestM Unit := do
let .ok expr ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p q))) .none | unreachable! let .ok rootTarget ← elabTerm (← `(term|∀ (p q : Prop), p → p ∧ (p q))) .none | unreachable!
Meta.forallTelescope expr $ λ _ rootTarget => do let state ← GoalState.create rootTarget
let state ← GoalState.create rootTarget let .success state _ ← state.tacticOn' 0 (← `(tactic|intro p q h)) | fail "intro failed to run"
let .success state1 _ ← state.tacticOn 0 "exact p" | unreachable! let .success state _ ← state.tacticOn' 0 (← `(tactic|apply And.intro)) | fail "apply And.intro failed to run"
let .success state2 _ ← state.tacticOn 1 "apply Or.inl" | unreachable! let .success state1 _ ← state.tacticOn' 0 (← `(tactic|exact h)) | fail "exact h failed to run"
let state' := state2.replay state state1 let .success state2 _ ← state.tacticOn' 1 (← `(tactic|apply Or.inl)) | fail "apply Or.inl failed to run"
return () assert! state2.goals.length == 1
let state' ← state2.replay state state1
assert! state'.goals.length == 1
let .success stateT _ ← state'.tacticOn' 0 (← `(tactic|exact h)) | fail "exact h failed to run"
let .some root := stateT.rootExpr? | fail "Root expression must exist"
checkEq "(root)" (toString $ ← Meta.ppExpr root) "fun p q h => ⟨h, Or.inl h⟩"
return ()
def suite (env: Environment): List (String × IO LSpec.TestSeq) := def suite (env: Environment): List (String × IO LSpec.TestSeq) :=
let tests := [ let tests := [