fix: Extracting `sorry`s from coupled goals
This commit is contained in:
parent
143cd289bb
commit
530a1a1a97
|
@ -60,12 +60,13 @@ partial def instantiateDelayedMVars (eOrig: Expr) : MetaM Expr := do
|
|||
-- nested mvars.
|
||||
mvarId.setKind .syntheticOpaque
|
||||
|
||||
mvarId.withContext do
|
||||
let lctx ← MonadLCtx.getLCtx
|
||||
if mvarDecl.lctx.any (λ decl => !lctx.contains decl.fvarId) then
|
||||
let violations := mvarDecl.lctx.decls.foldl (λ acc decl? => match decl? with
|
||||
| .some decl => if lctx.contains decl.fvarId then acc else acc ++ [decl.fvarId.name]
|
||||
| .none => acc) []
|
||||
panic! s!"Local context variable violation: {violations}"
|
||||
panic! s!"In the context of {mvarId.name}, there are local context variable violations: {violations}"
|
||||
|
||||
if let .some assign ← getExprMVarAssignment? mvarId then
|
||||
--IO.println s!"{padding}├A ?{mvarId.name}"
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/- Adapted from lean-training-data by semorrison -/
|
||||
import Pantograph.Protocol
|
||||
import Pantograph.Frontend.Basic
|
||||
import Pantograph.Frontend.Elab
|
||||
import Pantograph.Frontend.MetaTranslate
|
||||
|
|
|
@ -3,9 +3,10 @@ import Lean.Elab.Import
|
|||
import Lean.Elab.Command
|
||||
import Lean.Elab.InfoTree
|
||||
|
||||
import Pantograph.Protocol
|
||||
import Pantograph.Frontend.Basic
|
||||
import Pantograph.Frontend.MetaTranslate
|
||||
import Pantograph.Goal
|
||||
import Pantograph.Protocol
|
||||
|
||||
open Lean
|
||||
|
||||
|
@ -179,117 +180,6 @@ def collectSorrys (step: CompilationStep) : List InfoWithContext :=
|
|||
step.trees.bind collectSorrysInTree
|
||||
|
||||
|
||||
namespace MetaTranslate
|
||||
|
||||
structure Context where
|
||||
sourceMCtx : MetavarContext := {}
|
||||
sourceLCtx : LocalContext := {}
|
||||
|
||||
structure State where
|
||||
-- Stores mapping from old to new mvar/fvars
|
||||
mvarMap: HashMap MVarId MVarId := {}
|
||||
fvarMap: HashMap FVarId FVarId := {}
|
||||
|
||||
/-
|
||||
Monadic state for translating a frozen meta state. The underlying `MetaM`
|
||||
operates in the "target" context and state.
|
||||
-/
|
||||
abbrev MetaTranslateM := ReaderT Context StateRefT State MetaM
|
||||
|
||||
def getSourceLCtx : MetaTranslateM LocalContext := do pure (← read).sourceLCtx
|
||||
def getSourceMCtx : MetaTranslateM MetavarContext := do pure (← read).sourceMCtx
|
||||
def addTranslatedFVar (src dst: FVarId) : MetaTranslateM Unit := do
|
||||
let state ← get
|
||||
set { state with fvarMap := state.fvarMap.insert src dst }
|
||||
def addTranslatedMVar (src dst: MVarId) : MetaTranslateM Unit := do
|
||||
let state ← get
|
||||
set { state with mvarMap := state.mvarMap.insert src dst }
|
||||
|
||||
def resetFVarMap : MetaTranslateM Unit := do
|
||||
let state ← get
|
||||
set { state with fvarMap := {} }
|
||||
|
||||
private partial def translateExpr (srcExpr: Expr) : MetaTranslateM Expr := do
|
||||
let (srcExpr, _) := instantiateMVarsCore (mctx := ← getSourceMCtx) srcExpr
|
||||
--IO.println s!"Transform src: {srcExpr}"
|
||||
let result ← Core.transform srcExpr λ e => do
|
||||
let state ← get
|
||||
match e with
|
||||
| .fvar fvarId =>
|
||||
let .some fvarId' := state.fvarMap.find? fvarId | panic! s!"FVar id not registered: {fvarId.name}"
|
||||
return .done $ .fvar fvarId'
|
||||
| .mvar mvarId => do
|
||||
match state.mvarMap.find? mvarId with
|
||||
| .some mvarId' => do
|
||||
return .done $ .mvar mvarId'
|
||||
| .none => do
|
||||
--let t := (← getSourceMCtx).findDecl? mvarId |>.get!.type
|
||||
--let t' ← translateExpr t
|
||||
let mvar' ← Meta.mkFreshExprMVar .none
|
||||
addTranslatedMVar mvarId mvar'.mvarId!
|
||||
return .done mvar'
|
||||
| _ => return .continue
|
||||
try
|
||||
Meta.check result
|
||||
catch ex =>
|
||||
panic! s!"Check failed: {← ex.toMessageData.toString}"
|
||||
return result
|
||||
|
||||
def translateLocalDecl (srcLocalDecl: LocalDecl) : MetaTranslateM LocalDecl := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
addTranslatedFVar srcLocalDecl.fvarId fvarId
|
||||
match srcLocalDecl with
|
||||
| .cdecl index _ userName type bi kind => do
|
||||
--IO.println s!"[CD] {userName} {toString type}"
|
||||
return .cdecl index fvarId userName (← translateExpr type) bi kind
|
||||
| .ldecl index _ userName type value nonDep kind => do
|
||||
--IO.println s!"[LD] {toString type} := {toString value}"
|
||||
return .ldecl index fvarId userName (← translateExpr type) (← translateExpr value) nonDep kind
|
||||
|
||||
def translateLCtx : MetaTranslateM LocalContext := do
|
||||
resetFVarMap
|
||||
(← getSourceLCtx).foldlM (λ lctx srcLocalDecl => do
|
||||
let localDecl ← Meta.withLCtx lctx #[] do translateLocalDecl srcLocalDecl
|
||||
pure $ lctx.addDecl localDecl
|
||||
) (← MonadLCtx.getLCtx)
|
||||
|
||||
|
||||
def translateMVarId (srcMVarId: MVarId) : MetaTranslateM MVarId := do
|
||||
let srcDecl := (← getSourceMCtx).findDecl? srcMVarId |>.get!
|
||||
let mvar ← withTheReader Context (λ ctx => { ctx with sourceLCtx := srcDecl.lctx }) do
|
||||
let lctx' ← translateLCtx
|
||||
Meta.withLCtx lctx' #[] do
|
||||
let target' ← translateExpr srcDecl.type
|
||||
Meta.mkFreshExprSyntheticOpaqueMVar target'
|
||||
addTranslatedMVar srcMVarId mvar.mvarId!
|
||||
return mvar.mvarId!
|
||||
|
||||
def translateMVarFromTermInfo (termInfo : Elab.TermInfo) (context? : Option Elab.ContextInfo)
|
||||
: MetaM MVarId := do
|
||||
let trM : MetaTranslateM MVarId := do
|
||||
let type := termInfo.expectedType?.get!
|
||||
let lctx' ← translateLCtx
|
||||
let mvar ← Meta.withLCtx lctx' #[] do
|
||||
let type' ← translateExpr type
|
||||
Meta.mkFreshExprSyntheticOpaqueMVar type'
|
||||
return mvar.mvarId!
|
||||
trM.run {
|
||||
sourceMCtx := context?.map (·.mctx) |>.getD {},
|
||||
sourceLCtx := termInfo.lctx } |>.run' {}
|
||||
|
||||
|
||||
def translateMVarFromTacticInfoBefore (tacticInfo : Elab.TacticInfo) (_context? : Option Elab.ContextInfo)
|
||||
: MetaM (List MVarId) := do
|
||||
let trM : MetaTranslateM (List MVarId) := do
|
||||
tacticInfo.goalsBefore.mapM translateMVarId
|
||||
trM.run {
|
||||
sourceMCtx := tacticInfo.mctxBefore
|
||||
} |>.run' {}
|
||||
|
||||
|
||||
end MetaTranslate
|
||||
|
||||
export MetaTranslate (MetaTranslateM)
|
||||
|
||||
/--
|
||||
Since we cannot directly merge `MetavarContext`s, we have to get creative. This
|
||||
|
@ -299,7 +189,7 @@ the current `MetavarContext`.
|
|||
@[export pantograph_frontend_sorrys_to_goal_state]
|
||||
def sorrysToGoalState (sorrys : List InfoWithContext) : MetaM GoalState := do
|
||||
assert! !sorrys.isEmpty
|
||||
let goals ← sorrys.mapM λ i => do
|
||||
let goalsM := sorrys.mapM λ i => do
|
||||
match i.info with
|
||||
| .ofTermInfo termInfo => do
|
||||
let mvarId ← MetaTranslate.translateMVarFromTermInfo termInfo i.context?
|
||||
|
@ -307,7 +197,7 @@ def sorrysToGoalState (sorrys : List InfoWithContext) : MetaM GoalState := do
|
|||
| .ofTacticInfo tacticInfo => do
|
||||
MetaTranslate.translateMVarFromTacticInfoBefore tacticInfo i.context?
|
||||
| _ => panic! "Invalid info"
|
||||
let goals := goals.bind id
|
||||
let goals := (← goalsM.run {} |>.run' {}).bind id
|
||||
let root := match goals with
|
||||
| [] => panic! "This function cannot be called on an empty list"
|
||||
| [g] => g
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
import Lean.Meta
|
||||
|
||||
open Lean
|
||||
|
||||
namespace Pantograph.Frontend
|
||||
|
||||
namespace MetaTranslate
|
||||
|
||||
structure Context where
|
||||
sourceMCtx : MetavarContext := {}
|
||||
sourceLCtx : LocalContext := {}
|
||||
|
||||
abbrev FVarMap := HashMap FVarId FVarId
|
||||
|
||||
structure State where
|
||||
-- Stores mapping from old to new mvar/fvars
|
||||
mvarMap: HashMap MVarId MVarId := {}
|
||||
fvarMap: HashMap FVarId FVarId := {}
|
||||
|
||||
/-
|
||||
Monadic state for translating a frozen meta state. The underlying `MetaM`
|
||||
operates in the "target" context and state.
|
||||
-/
|
||||
abbrev MetaTranslateM := ReaderT Context StateRefT State MetaM
|
||||
|
||||
def getSourceLCtx : MetaTranslateM LocalContext := do pure (← read).sourceLCtx
|
||||
def getSourceMCtx : MetaTranslateM MetavarContext := do pure (← read).sourceMCtx
|
||||
def addTranslatedFVar (src dst: FVarId) : MetaTranslateM Unit := do
|
||||
modifyGet λ state => ((), { state with fvarMap := state.fvarMap.insert src dst })
|
||||
def addTranslatedMVar (src dst: MVarId) : MetaTranslateM Unit := do
|
||||
modifyGet λ state => ((), { state with mvarMap := state.mvarMap.insert src dst })
|
||||
|
||||
def saveFVarMap : MetaTranslateM FVarMap := do
|
||||
return (← get).fvarMap
|
||||
def restoreFVarMap (map: FVarMap) : MetaTranslateM Unit := do
|
||||
modifyGet λ state => ((), { state with fvarMap := map })
|
||||
def resetFVarMap : MetaTranslateM Unit := do
|
||||
modifyGet λ state => ((), { state with fvarMap := {} })
|
||||
|
||||
mutual
|
||||
private partial def translateExpr (srcExpr: Expr) : MetaTranslateM Expr := do
|
||||
let sourceMCtx ← getSourceMCtx
|
||||
let (srcExpr, _) := instantiateMVarsCore (mctx := sourceMCtx) srcExpr
|
||||
--IO.println s!"Transform src: {srcExpr}"
|
||||
let result ← Core.transform srcExpr λ e => do
|
||||
let state ← get
|
||||
match e with
|
||||
| .fvar fvarId =>
|
||||
let .some fvarId' := state.fvarMap.find? fvarId | panic! s!"FVar id not registered: {fvarId.name}"
|
||||
assert! (← getLCtx).contains fvarId'
|
||||
return .done $ .fvar fvarId'
|
||||
| .mvar mvarId => do
|
||||
assert! !(sourceMCtx.dAssignment.contains mvarId)
|
||||
assert! !(sourceMCtx.eAssignment.contains mvarId)
|
||||
match state.mvarMap.find? mvarId with
|
||||
| .some mvarId' => do
|
||||
return .done $ .mvar mvarId'
|
||||
| .none => do
|
||||
-- Entering another LCtx, must save the current one
|
||||
let fvarMap ← saveFVarMap
|
||||
let mvarId' ← translateMVarId mvarId
|
||||
restoreFVarMap fvarMap
|
||||
return .done $ .mvar mvarId'
|
||||
| _ => return .continue
|
||||
Meta.check result
|
||||
return result
|
||||
|
||||
partial def translateLocalInstance (srcInstance: LocalInstance) : MetaTranslateM LocalInstance := do
|
||||
return {
|
||||
className := srcInstance.className,
|
||||
fvar := ← translateExpr srcInstance.fvar
|
||||
}
|
||||
partial def translateLocalDecl (srcLocalDecl: LocalDecl) : MetaTranslateM LocalDecl := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
addTranslatedFVar srcLocalDecl.fvarId fvarId
|
||||
match srcLocalDecl with
|
||||
| .cdecl index _ userName type bi kind => do
|
||||
--IO.println s!"[CD] {userName} {toString type}"
|
||||
return .cdecl index fvarId userName (← translateExpr type) bi kind
|
||||
| .ldecl index _ userName type value nonDep kind => do
|
||||
--IO.println s!"[LD] {toString type} := {toString value}"
|
||||
return .ldecl index fvarId userName (← translateExpr type) (← translateExpr value) nonDep kind
|
||||
|
||||
partial def translateLCtx : MetaTranslateM LocalContext := do
|
||||
resetFVarMap
|
||||
(← getSourceLCtx).foldlM (λ lctx srcLocalDecl => do
|
||||
let localDecl ← Meta.withLCtx lctx #[] do translateLocalDecl srcLocalDecl
|
||||
pure $ lctx.addDecl localDecl
|
||||
) (← MonadLCtx.getLCtx)
|
||||
|
||||
partial def translateMVarId (srcMVarId: MVarId) : MetaTranslateM MVarId := do
|
||||
if let .some mvarId' := (← get).mvarMap.find? srcMVarId then
|
||||
return mvarId'
|
||||
let srcDecl := (← getSourceMCtx).findDecl? srcMVarId |>.get!
|
||||
let mvar ← withTheReader Context (λ ctx => { ctx with sourceLCtx := srcDecl.lctx }) do
|
||||
let lctx' ← translateLCtx
|
||||
let localInstances' ← srcDecl.localInstances.mapM translateLocalInstance
|
||||
Meta.withLCtx lctx' localInstances' do
|
||||
let target' ← translateExpr srcDecl.type
|
||||
Meta.mkFreshExprMVar target' srcDecl.kind srcDecl.userName
|
||||
addTranslatedMVar srcMVarId mvar.mvarId!
|
||||
return mvar.mvarId!
|
||||
end
|
||||
|
||||
def translateMVarFromTermInfo (termInfo : Elab.TermInfo) (context? : Option Elab.ContextInfo)
|
||||
: MetaTranslateM MVarId := do
|
||||
withTheReader Context (λ ctx => { ctx with
|
||||
sourceMCtx := context?.map (·.mctx) |>.getD {},
|
||||
sourceLCtx := termInfo.lctx,
|
||||
}) do
|
||||
let type := termInfo.expectedType?.get!
|
||||
let lctx' ← translateLCtx
|
||||
let mvar ← Meta.withLCtx lctx' #[] do
|
||||
let type' ← translateExpr type
|
||||
Meta.mkFreshExprSyntheticOpaqueMVar type'
|
||||
return mvar.mvarId!
|
||||
|
||||
|
||||
def translateMVarFromTacticInfoBefore (tacticInfo : Elab.TacticInfo) (_context? : Option Elab.ContextInfo)
|
||||
: MetaTranslateM (List MVarId) := do
|
||||
withTheReader Context (λ ctx => { ctx with sourceMCtx := tacticInfo.mctxBefore }) do
|
||||
tacticInfo.goalsBefore.mapM translateMVarId
|
||||
|
||||
|
||||
end MetaTranslate
|
||||
|
||||
export MetaTranslate (MetaTranslateM)
|
||||
|
||||
end Pantograph.Frontend
|
|
@ -26,8 +26,8 @@ theorem plus_n_Sm_proved_formal_sketch : ∀ n m : Nat, n + (m + 1) = (n + m) +
|
|||
sorry
|
||||
"
|
||||
let goalStates ← (collectSorrysFromSource sketch).run' {}
|
||||
let [goalState] := goalStates | panic! "Illegal number of states"
|
||||
addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
|
||||
let [goalState] := goalStates | panic! "Incorrect number of states"
|
||||
addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
|
||||
{
|
||||
target := { pp? := "∀ (n m : Nat), n = m" },
|
||||
vars := #[
|
||||
|
@ -49,8 +49,8 @@ example : ∀ (n m: Nat), n + m = m + n := by
|
|||
sorry
|
||||
"
|
||||
let goalStates ← (collectSorrysFromSource sketch).run' {}
|
||||
let [goalState] := goalStates | panic! s!"Illegal number of states: {goalStates.length}"
|
||||
addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
|
||||
let [goalState] := goalStates | panic! s!"Incorrect number of states: {goalStates.length}"
|
||||
addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
|
||||
{
|
||||
target := { pp? := "n + m = m + n" },
|
||||
vars := #[{
|
||||
|
@ -77,8 +77,8 @@ example : ∀ (n m: Nat), n + m = m + n := by
|
|||
sorry
|
||||
"
|
||||
let goalStates ← (collectSorrysFromSource sketch).run' {}
|
||||
let [goalState] := goalStates | panic! s!"Illegal number of states: {goalStates.length}"
|
||||
addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
|
||||
let [goalState] := goalStates | panic! s!"Incorrect number of states: {goalStates.length}"
|
||||
addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
|
||||
{
|
||||
target := { pp? := "0 + m = m" },
|
||||
vars := #[{
|
||||
|
@ -87,6 +87,7 @@ example : ∀ (n m: Nat), n + m = m + n := by
|
|||
}]
|
||||
},
|
||||
{
|
||||
userName? := .some "zero",
|
||||
target := { pp? := "0 + m = m + 0" },
|
||||
vars := #[{
|
||||
userName := "m",
|
||||
|
@ -110,6 +111,7 @@ example : ∀ (n m: Nat), n + m = m + n := by
|
|||
}]
|
||||
},
|
||||
{
|
||||
userName? := .some "succ",
|
||||
target := { pp? := "n + 1 + m = m + (n + 1)" },
|
||||
vars := #[{
|
||||
userName := "m",
|
||||
|
@ -127,12 +129,47 @@ example : ∀ (n m: Nat), n + m = m + n := by
|
|||
}
|
||||
])
|
||||
|
||||
def test_sorry_in_coupled: TestT MetaM Unit := do
|
||||
let sketch := "
|
||||
example : ∀ (y: Nat), ∃ (x: Nat), y + 1 = x := by
|
||||
intro y
|
||||
apply Exists.intro
|
||||
case h => sorry
|
||||
case w => sorry
|
||||
"
|
||||
let goalStates ← (collectSorrysFromSource sketch).run' {}
|
||||
let [goalState] := goalStates | panic! s!"Incorrect number of states: {goalStates.length}"
|
||||
addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
|
||||
{
|
||||
target := { pp? := "y + 1 = ?w" },
|
||||
vars := #[{
|
||||
userName := "y",
|
||||
type? := .some { pp? := "Nat" },
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
userName? := .some "w",
|
||||
target := { pp? := "Nat" },
|
||||
vars := #[{
|
||||
userName := "y✝",
|
||||
isInaccessible := true,
|
||||
type? := .some { pp? := "Nat" },
|
||||
}, {
|
||||
userName := "y",
|
||||
type? := .some { pp? := "Nat" },
|
||||
}
|
||||
],
|
||||
}
|
||||
])
|
||||
|
||||
|
||||
def suite (env : Environment): List (String × IO LSpec.TestSeq) :=
|
||||
let tests := [
|
||||
("multiple_sorrys_in_proof", test_multiple_sorrys_in_proof),
|
||||
("sorry_in_middle", test_sorry_in_middle),
|
||||
("sorry_in_induction", test_sorry_in_induction),
|
||||
("sorry_in_coupled", test_sorry_in_coupled),
|
||||
]
|
||||
tests.map (fun (name, test) => (name, runMetaMSeq env $ runTest test))
|
||||
|
||||
|
|
Loading…
Reference in New Issue