fix: Extracting `sorry`s from coupled goals

This commit is contained in:
Leni Aniva 2024-10-03 11:35:54 -07:00
parent 143cd289bb
commit 530a1a1a97
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
5 changed files with 221 additions and 164 deletions

View File

@ -60,12 +60,13 @@ partial def instantiateDelayedMVars (eOrig: Expr) : MetaM Expr := do
-- nested mvars.
mvarId.setKind .syntheticOpaque
mvarId.withContext do
let lctx ← MonadLCtx.getLCtx
if mvarDecl.lctx.any (λ decl => !lctx.contains decl.fvarId) then
let violations := mvarDecl.lctx.decls.foldl (λ acc decl? => match decl? with
| .some decl => if lctx.contains decl.fvarId then acc else acc ++ [decl.fvarId.name]
| .none => acc) []
panic! s!"Local context variable violation: {violations}"
panic! s!"In the context of {mvarId.name}, there are local context variable violations: {violations}"
if let .some assign ← getExprMVarAssignment? mvarId then
--IO.println s!"{padding}├A ?{mvarId.name}"

View File

@ -1,4 +1,4 @@
/- Adapted from lean-training-data by semorrison -/
import Pantograph.Protocol
import Pantograph.Frontend.Basic
import Pantograph.Frontend.Elab
import Pantograph.Frontend.MetaTranslate

View File

@ -3,9 +3,10 @@ import Lean.Elab.Import
import Lean.Elab.Command
import Lean.Elab.InfoTree
import Pantograph.Protocol
import Pantograph.Frontend.Basic
import Pantograph.Frontend.MetaTranslate
import Pantograph.Goal
import Pantograph.Protocol
open Lean
@ -179,117 +180,6 @@ def collectSorrys (step: CompilationStep) : List InfoWithContext :=
step.trees.bind collectSorrysInTree
namespace MetaTranslate
structure Context where
sourceMCtx : MetavarContext := {}
sourceLCtx : LocalContext := {}
structure State where
-- Stores mapping from old to new mvar/fvars
mvarMap: HashMap MVarId MVarId := {}
fvarMap: HashMap FVarId FVarId := {}
/-
Monadic state for translating a frozen meta state. The underlying `MetaM`
operates in the "target" context and state.
-/
abbrev MetaTranslateM := ReaderT Context StateRefT State MetaM
def getSourceLCtx : MetaTranslateM LocalContext := do pure (← read).sourceLCtx
def getSourceMCtx : MetaTranslateM MetavarContext := do pure (← read).sourceMCtx
def addTranslatedFVar (src dst: FVarId) : MetaTranslateM Unit := do
let state ← get
set { state with fvarMap := state.fvarMap.insert src dst }
def addTranslatedMVar (src dst: MVarId) : MetaTranslateM Unit := do
let state ← get
set { state with mvarMap := state.mvarMap.insert src dst }
def resetFVarMap : MetaTranslateM Unit := do
let state ← get
set { state with fvarMap := {} }
private partial def translateExpr (srcExpr: Expr) : MetaTranslateM Expr := do
let (srcExpr, _) := instantiateMVarsCore (mctx := ← getSourceMCtx) srcExpr
--IO.println s!"Transform src: {srcExpr}"
let result ← Core.transform srcExpr λ e => do
let state ← get
match e with
| .fvar fvarId =>
let .some fvarId' := state.fvarMap.find? fvarId | panic! s!"FVar id not registered: {fvarId.name}"
return .done $ .fvar fvarId'
| .mvar mvarId => do
match state.mvarMap.find? mvarId with
| .some mvarId' => do
return .done $ .mvar mvarId'
| .none => do
--let t := (← getSourceMCtx).findDecl? mvarId |>.get!.type
--let t' ← translateExpr t
let mvar' ← Meta.mkFreshExprMVar .none
addTranslatedMVar mvarId mvar'.mvarId!
return .done mvar'
| _ => return .continue
try
Meta.check result
catch ex =>
panic! s!"Check failed: {← ex.toMessageData.toString}"
return result
def translateLocalDecl (srcLocalDecl: LocalDecl) : MetaTranslateM LocalDecl := do
let fvarId ← mkFreshFVarId
addTranslatedFVar srcLocalDecl.fvarId fvarId
match srcLocalDecl with
| .cdecl index _ userName type bi kind => do
--IO.println s!"[CD] {userName} {toString type}"
return .cdecl index fvarId userName (← translateExpr type) bi kind
| .ldecl index _ userName type value nonDep kind => do
--IO.println s!"[LD] {toString type} := {toString value}"
return .ldecl index fvarId userName (← translateExpr type) (← translateExpr value) nonDep kind
def translateLCtx : MetaTranslateM LocalContext := do
resetFVarMap
(← getSourceLCtx).foldlM (λ lctx srcLocalDecl => do
let localDecl ← Meta.withLCtx lctx #[] do translateLocalDecl srcLocalDecl
pure $ lctx.addDecl localDecl
) (← MonadLCtx.getLCtx)
def translateMVarId (srcMVarId: MVarId) : MetaTranslateM MVarId := do
let srcDecl := (← getSourceMCtx).findDecl? srcMVarId |>.get!
let mvar ← withTheReader Context (λ ctx => { ctx with sourceLCtx := srcDecl.lctx }) do
let lctx' ← translateLCtx
Meta.withLCtx lctx' #[] do
let target' ← translateExpr srcDecl.type
Meta.mkFreshExprSyntheticOpaqueMVar target'
addTranslatedMVar srcMVarId mvar.mvarId!
return mvar.mvarId!
def translateMVarFromTermInfo (termInfo : Elab.TermInfo) (context? : Option Elab.ContextInfo)
: MetaM MVarId := do
let trM : MetaTranslateM MVarId := do
let type := termInfo.expectedType?.get!
let lctx' ← translateLCtx
let mvar ← Meta.withLCtx lctx' #[] do
let type' ← translateExpr type
Meta.mkFreshExprSyntheticOpaqueMVar type'
return mvar.mvarId!
trM.run {
sourceMCtx := context?.map (·.mctx) |>.getD {},
sourceLCtx := termInfo.lctx } |>.run' {}
def translateMVarFromTacticInfoBefore (tacticInfo : Elab.TacticInfo) (_context? : Option Elab.ContextInfo)
: MetaM (List MVarId) := do
let trM : MetaTranslateM (List MVarId) := do
tacticInfo.goalsBefore.mapM translateMVarId
trM.run {
sourceMCtx := tacticInfo.mctxBefore
} |>.run' {}
end MetaTranslate
export MetaTranslate (MetaTranslateM)
/--
Since we cannot directly merge `MetavarContext`s, we have to get creative. This
@ -299,7 +189,7 @@ the current `MetavarContext`.
@[export pantograph_frontend_sorrys_to_goal_state]
def sorrysToGoalState (sorrys : List InfoWithContext) : MetaM GoalState := do
assert! !sorrys.isEmpty
let goals sorrys.mapM λ i => do
let goalsM := sorrys.mapM λ i => do
match i.info with
| .ofTermInfo termInfo => do
let mvarId ← MetaTranslate.translateMVarFromTermInfo termInfo i.context?
@ -307,7 +197,7 @@ def sorrysToGoalState (sorrys : List InfoWithContext) : MetaM GoalState := do
| .ofTacticInfo tacticInfo => do
MetaTranslate.translateMVarFromTacticInfoBefore tacticInfo i.context?
| _ => panic! "Invalid info"
let goals := goals.bind id
let goals := (← goalsM.run {} |>.run' {}).bind id
let root := match goals with
| [] => panic! "This function cannot be called on an empty list"
| [g] => g

View File

@ -0,0 +1,129 @@
import Lean.Meta
open Lean
namespace Pantograph.Frontend
namespace MetaTranslate
structure Context where
sourceMCtx : MetavarContext := {}
sourceLCtx : LocalContext := {}
abbrev FVarMap := HashMap FVarId FVarId
structure State where
-- Stores mapping from old to new mvar/fvars
mvarMap: HashMap MVarId MVarId := {}
fvarMap: HashMap FVarId FVarId := {}
/-
Monadic state for translating a frozen meta state. The underlying `MetaM`
operates in the "target" context and state.
-/
abbrev MetaTranslateM := ReaderT Context StateRefT State MetaM
def getSourceLCtx : MetaTranslateM LocalContext := do pure (← read).sourceLCtx
def getSourceMCtx : MetaTranslateM MetavarContext := do pure (← read).sourceMCtx
def addTranslatedFVar (src dst: FVarId) : MetaTranslateM Unit := do
modifyGet λ state => ((), { state with fvarMap := state.fvarMap.insert src dst })
def addTranslatedMVar (src dst: MVarId) : MetaTranslateM Unit := do
modifyGet λ state => ((), { state with mvarMap := state.mvarMap.insert src dst })
def saveFVarMap : MetaTranslateM FVarMap := do
return (← get).fvarMap
def restoreFVarMap (map: FVarMap) : MetaTranslateM Unit := do
modifyGet λ state => ((), { state with fvarMap := map })
def resetFVarMap : MetaTranslateM Unit := do
modifyGet λ state => ((), { state with fvarMap := {} })
mutual
private partial def translateExpr (srcExpr: Expr) : MetaTranslateM Expr := do
let sourceMCtx ← getSourceMCtx
let (srcExpr, _) := instantiateMVarsCore (mctx := sourceMCtx) srcExpr
--IO.println s!"Transform src: {srcExpr}"
let result ← Core.transform srcExpr λ e => do
let state ← get
match e with
| .fvar fvarId =>
let .some fvarId' := state.fvarMap.find? fvarId | panic! s!"FVar id not registered: {fvarId.name}"
assert! (← getLCtx).contains fvarId'
return .done $ .fvar fvarId'
| .mvar mvarId => do
assert! !(sourceMCtx.dAssignment.contains mvarId)
assert! !(sourceMCtx.eAssignment.contains mvarId)
match state.mvarMap.find? mvarId with
| .some mvarId' => do
return .done $ .mvar mvarId'
| .none => do
-- Entering another LCtx, must save the current one
let fvarMap ← saveFVarMap
let mvarId' ← translateMVarId mvarId
restoreFVarMap fvarMap
return .done $ .mvar mvarId'
| _ => return .continue
Meta.check result
return result
partial def translateLocalInstance (srcInstance: LocalInstance) : MetaTranslateM LocalInstance := do
return {
className := srcInstance.className,
fvar := ← translateExpr srcInstance.fvar
}
partial def translateLocalDecl (srcLocalDecl: LocalDecl) : MetaTranslateM LocalDecl := do
let fvarId ← mkFreshFVarId
addTranslatedFVar srcLocalDecl.fvarId fvarId
match srcLocalDecl with
| .cdecl index _ userName type bi kind => do
--IO.println s!"[CD] {userName} {toString type}"
return .cdecl index fvarId userName (← translateExpr type) bi kind
| .ldecl index _ userName type value nonDep kind => do
--IO.println s!"[LD] {toString type} := {toString value}"
return .ldecl index fvarId userName (← translateExpr type) (← translateExpr value) nonDep kind
partial def translateLCtx : MetaTranslateM LocalContext := do
resetFVarMap
(← getSourceLCtx).foldlM (λ lctx srcLocalDecl => do
let localDecl ← Meta.withLCtx lctx #[] do translateLocalDecl srcLocalDecl
pure $ lctx.addDecl localDecl
) (← MonadLCtx.getLCtx)
partial def translateMVarId (srcMVarId: MVarId) : MetaTranslateM MVarId := do
if let .some mvarId' := (← get).mvarMap.find? srcMVarId then
return mvarId'
let srcDecl := (← getSourceMCtx).findDecl? srcMVarId |>.get!
let mvar ← withTheReader Context (λ ctx => { ctx with sourceLCtx := srcDecl.lctx }) do
let lctx' ← translateLCtx
let localInstances' ← srcDecl.localInstances.mapM translateLocalInstance
Meta.withLCtx lctx' localInstances' do
let target' ← translateExpr srcDecl.type
Meta.mkFreshExprMVar target' srcDecl.kind srcDecl.userName
addTranslatedMVar srcMVarId mvar.mvarId!
return mvar.mvarId!
end
def translateMVarFromTermInfo (termInfo : Elab.TermInfo) (context? : Option Elab.ContextInfo)
: MetaTranslateM MVarId := do
withTheReader Context (λ ctx => { ctx with
sourceMCtx := context?.map (·.mctx) |>.getD {},
sourceLCtx := termInfo.lctx,
}) do
let type := termInfo.expectedType?.get!
let lctx' ← translateLCtx
let mvar ← Meta.withLCtx lctx' #[] do
let type' ← translateExpr type
Meta.mkFreshExprSyntheticOpaqueMVar type'
return mvar.mvarId!
def translateMVarFromTacticInfoBefore (tacticInfo : Elab.TacticInfo) (_context? : Option Elab.ContextInfo)
: MetaTranslateM (List MVarId) := do
withTheReader Context (λ ctx => { ctx with sourceMCtx := tacticInfo.mctxBefore }) do
tacticInfo.goalsBefore.mapM translateMVarId
end MetaTranslate
export MetaTranslate (MetaTranslateM)
end Pantograph.Frontend

View File

@ -26,8 +26,8 @@ theorem plus_n_Sm_proved_formal_sketch : ∀ n m : Nat, n + (m + 1) = (n + m) +
sorry
"
let goalStates ← (collectSorrysFromSource sketch).run' {}
let [goalState] := goalStates | panic! "Illegal number of states"
addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
let [goalState] := goalStates | panic! "Incorrect number of states"
addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
{
target := { pp? := "∀ (n m : Nat), n = m" },
vars := #[
@ -49,8 +49,8 @@ example : ∀ (n m: Nat), n + m = m + n := by
sorry
"
let goalStates ← (collectSorrysFromSource sketch).run' {}
let [goalState] := goalStates | panic! s!"Illegal number of states: {goalStates.length}"
addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
let [goalState] := goalStates | panic! s!"Incorrect number of states: {goalStates.length}"
addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
{
target := { pp? := "n + m = m + n" },
vars := #[{
@ -77,8 +77,8 @@ example : ∀ (n m: Nat), n + m = m + n := by
sorry
"
let goalStates ← (collectSorrysFromSource sketch).run' {}
let [goalState] := goalStates | panic! s!"Illegal number of states: {goalStates.length}"
addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
let [goalState] := goalStates | panic! s!"Incorrect number of states: {goalStates.length}"
addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
{
target := { pp? := "0 + m = m" },
vars := #[{
@ -87,6 +87,7 @@ example : ∀ (n m: Nat), n + m = m + n := by
}]
},
{
userName? := .some "zero",
target := { pp? := "0 + m = m + 0" },
vars := #[{
userName := "m",
@ -110,6 +111,7 @@ example : ∀ (n m: Nat), n + m = m + n := by
}]
},
{
userName? := .some "succ",
target := { pp? := "n + 1 + m = m + (n + 1)" },
vars := #[{
userName := "m",
@ -127,12 +129,47 @@ example : ∀ (n m: Nat), n + m = m + n := by
}
])
def test_sorry_in_coupled: TestT MetaM Unit := do
let sketch := "
example : ∀ (y: Nat), ∃ (x: Nat), y + 1 = x := by
intro y
apply Exists.intro
case h => sorry
case w => sorry
"
let goalStates ← (collectSorrysFromSource sketch).run' {}
let [goalState] := goalStates | panic! s!"Incorrect number of states: {goalStates.length}"
addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
{
target := { pp? := "y + 1 = ?w" },
vars := #[{
userName := "y",
type? := .some { pp? := "Nat" },
}
],
},
{
userName? := .some "w",
target := { pp? := "Nat" },
vars := #[{
userName := "y✝",
isInaccessible := true,
type? := .some { pp? := "Nat" },
}, {
userName := "y",
type? := .some { pp? := "Nat" },
}
],
}
])
def suite (env : Environment): List (String × IO LSpec.TestSeq) :=
let tests := [
("multiple_sorrys_in_proof", test_multiple_sorrys_in_proof),
("sorry_in_middle", test_sorry_in_middle),
("sorry_in_induction", test_sorry_in_induction),
("sorry_in_coupled", test_sorry_in_coupled),
]
tests.map (fun (name, test) => (name, runMetaMSeq env $ runTest test))