fix: Extracting `sorry`s from coupled goals

This commit is contained in:
Leni Aniva 2024-10-03 11:35:54 -07:00
parent 143cd289bb
commit 530a1a1a97
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
5 changed files with 221 additions and 164 deletions

View File

@ -60,53 +60,54 @@ partial def instantiateDelayedMVars (eOrig: Expr) : MetaM Expr := do
-- nested mvars. -- nested mvars.
mvarId.setKind .syntheticOpaque mvarId.setKind .syntheticOpaque
let lctx ← MonadLCtx.getLCtx mvarId.withContext do
if mvarDecl.lctx.any (λ decl => !lctx.contains decl.fvarId) then let lctx ← MonadLCtx.getLCtx
let violations := mvarDecl.lctx.decls.foldl (λ acc decl? => match decl? with if mvarDecl.lctx.any (λ decl => !lctx.contains decl.fvarId) then
| .some decl => if lctx.contains decl.fvarId then acc else acc ++ [decl.fvarId.name] let violations := mvarDecl.lctx.decls.foldl (λ acc decl? => match decl? with
| .none => acc) [] | .some decl => if lctx.contains decl.fvarId then acc else acc ++ [decl.fvarId.name]
panic! s!"Local context variable violation: {violations}" | .none => acc) []
panic! s!"In the context of {mvarId.name}, there are local context variable violations: {violations}"
if let .some assign ← getExprMVarAssignment? mvarId then if let .some assign ← getExprMVarAssignment? mvarId then
--IO.println s!"{padding}├A ?{mvarId.name}" --IO.println s!"{padding}├A ?{mvarId.name}"
assert! !(← mvarId.isDelayedAssigned) assert! !(← mvarId.isDelayedAssigned)
return .visit (mkAppN assign args) return .visit (mkAppN assign args)
else if let some { fvars, mvarIdPending } ← getDelayedMVarAssignment? mvarId then else if let some { fvars, mvarIdPending } ← getDelayedMVarAssignment? mvarId then
--let substTableStr := String.intercalate ", " $ Array.zipWith fvars args (λ fvar assign => s!"{fvar.fvarId!.name} := {assign}") |>.toList --let substTableStr := String.intercalate ", " $ Array.zipWith fvars args (λ fvar assign => s!"{fvar.fvarId!.name} := {assign}") |>.toList
--IO.println s!"{padding}├MD ?{mvarId.name} := ?{mvarIdPending.name} [{substTableStr}]" --IO.println s!"{padding}├MD ?{mvarId.name} := ?{mvarIdPending.name} [{substTableStr}]"
if args.size < fvars.size then if args.size < fvars.size then
throwError "Not enough arguments to instantiate a delay assigned mvar. This is due to bad implementations of a tactic: {args.size} < {fvars.size}. Expr: {toString e}; Origin: {toString eOrig}" throwError "Not enough arguments to instantiate a delay assigned mvar. This is due to bad implementations of a tactic: {args.size} < {fvars.size}. Expr: {toString e}; Origin: {toString eOrig}"
--if !args.isEmpty then --if !args.isEmpty then
--IO.println s!"{padding}├── Arguments Begin" --IO.println s!"{padding}├── Arguments Begin"
let args ← args.mapM self let args ← args.mapM self
--if !args.isEmpty then --if !args.isEmpty then
--IO.println s!"{padding}├── Arguments End" --IO.println s!"{padding}├── Arguments End"
if !(← mvarIdPending.isAssignedOrDelayedAssigned) then if !(← mvarIdPending.isAssignedOrDelayedAssigned) then
--IO.println s!"{padding}├T1" --IO.println s!"{padding}├T1"
let result := mkAppN f args let result := mkAppN f args
return .done result
let pending ← mvarIdPending.withContext do
let inner ← instantiateDelayedMVars (.mvar mvarIdPending) --(level := level + 1)
--IO.println s!"{padding}├Pre: {inner}"
pure <| (← inner.abstractM fvars).instantiateRev args
-- Tail arguments
let result := mkAppRange pending fvars.size args.size args
--IO.println s!"{padding}├MD {result}"
return .done result return .done result
else
assert! !(← mvarId.isAssigned)
assert! !(← mvarId.isDelayedAssigned)
--if !args.isEmpty then
-- IO.println s!"{padding}├── Arguments Begin"
let args ← args.mapM self
--if !args.isEmpty then
-- IO.println s!"{padding}├── Arguments End"
let pending ← mvarIdPending.withContext do --IO.println s!"{padding}├M ?{mvarId.name}"
let inner ← instantiateDelayedMVars (.mvar mvarIdPending) --(level := level + 1) return .done (mkAppN f args))
--IO.println s!"{padding}├Pre: {inner}"
pure <| (← inner.abstractM fvars).instantiateRev args
-- Tail arguments
let result := mkAppRange pending fvars.size args.size args
--IO.println s!"{padding}├MD {result}"
return .done result
else
assert! !(← mvarId.isAssigned)
assert! !(← mvarId.isDelayedAssigned)
--if !args.isEmpty then
-- IO.println s!"{padding}├── Arguments Begin"
let args ← args.mapM self
--if !args.isEmpty then
-- IO.println s!"{padding}├── Arguments End"
--IO.println s!"{padding}├M ?{mvarId.name}"
return .done (mkAppN f args))
--IO.println s!"{padding}└Result {result}" --IO.println s!"{padding}└Result {result}"
return result return result
where where

View File

@ -1,4 +1,4 @@
/- Adapted from lean-training-data by semorrison -/ /- Adapted from lean-training-data by semorrison -/
import Pantograph.Protocol
import Pantograph.Frontend.Basic import Pantograph.Frontend.Basic
import Pantograph.Frontend.Elab import Pantograph.Frontend.Elab
import Pantograph.Frontend.MetaTranslate

View File

@ -3,9 +3,10 @@ import Lean.Elab.Import
import Lean.Elab.Command import Lean.Elab.Command
import Lean.Elab.InfoTree import Lean.Elab.InfoTree
import Pantograph.Protocol
import Pantograph.Frontend.Basic import Pantograph.Frontend.Basic
import Pantograph.Frontend.MetaTranslate
import Pantograph.Goal import Pantograph.Goal
import Pantograph.Protocol
open Lean open Lean
@ -179,117 +180,6 @@ def collectSorrys (step: CompilationStep) : List InfoWithContext :=
step.trees.bind collectSorrysInTree step.trees.bind collectSorrysInTree
namespace MetaTranslate
structure Context where
sourceMCtx : MetavarContext := {}
sourceLCtx : LocalContext := {}
structure State where
-- Stores mapping from old to new mvar/fvars
mvarMap: HashMap MVarId MVarId := {}
fvarMap: HashMap FVarId FVarId := {}
/-
Monadic state for translating a frozen meta state. The underlying `MetaM`
operates in the "target" context and state.
-/
abbrev MetaTranslateM := ReaderT Context StateRefT State MetaM
def getSourceLCtx : MetaTranslateM LocalContext := do pure (← read).sourceLCtx
def getSourceMCtx : MetaTranslateM MetavarContext := do pure (← read).sourceMCtx
def addTranslatedFVar (src dst: FVarId) : MetaTranslateM Unit := do
let state ← get
set { state with fvarMap := state.fvarMap.insert src dst }
def addTranslatedMVar (src dst: MVarId) : MetaTranslateM Unit := do
let state ← get
set { state with mvarMap := state.mvarMap.insert src dst }
def resetFVarMap : MetaTranslateM Unit := do
let state ← get
set { state with fvarMap := {} }
private partial def translateExpr (srcExpr: Expr) : MetaTranslateM Expr := do
let (srcExpr, _) := instantiateMVarsCore (mctx := ← getSourceMCtx) srcExpr
--IO.println s!"Transform src: {srcExpr}"
let result ← Core.transform srcExpr λ e => do
let state ← get
match e with
| .fvar fvarId =>
let .some fvarId' := state.fvarMap.find? fvarId | panic! s!"FVar id not registered: {fvarId.name}"
return .done $ .fvar fvarId'
| .mvar mvarId => do
match state.mvarMap.find? mvarId with
| .some mvarId' => do
return .done $ .mvar mvarId'
| .none => do
--let t := (← getSourceMCtx).findDecl? mvarId |>.get!.type
--let t' ← translateExpr t
let mvar' ← Meta.mkFreshExprMVar .none
addTranslatedMVar mvarId mvar'.mvarId!
return .done mvar'
| _ => return .continue
try
Meta.check result
catch ex =>
panic! s!"Check failed: {← ex.toMessageData.toString}"
return result
def translateLocalDecl (srcLocalDecl: LocalDecl) : MetaTranslateM LocalDecl := do
let fvarId ← mkFreshFVarId
addTranslatedFVar srcLocalDecl.fvarId fvarId
match srcLocalDecl with
| .cdecl index _ userName type bi kind => do
--IO.println s!"[CD] {userName} {toString type}"
return .cdecl index fvarId userName (← translateExpr type) bi kind
| .ldecl index _ userName type value nonDep kind => do
--IO.println s!"[LD] {toString type} := {toString value}"
return .ldecl index fvarId userName (← translateExpr type) (← translateExpr value) nonDep kind
def translateLCtx : MetaTranslateM LocalContext := do
resetFVarMap
(← getSourceLCtx).foldlM (λ lctx srcLocalDecl => do
let localDecl ← Meta.withLCtx lctx #[] do translateLocalDecl srcLocalDecl
pure $ lctx.addDecl localDecl
) (← MonadLCtx.getLCtx)
def translateMVarId (srcMVarId: MVarId) : MetaTranslateM MVarId := do
let srcDecl := (← getSourceMCtx).findDecl? srcMVarId |>.get!
let mvar ← withTheReader Context (λ ctx => { ctx with sourceLCtx := srcDecl.lctx }) do
let lctx' ← translateLCtx
Meta.withLCtx lctx' #[] do
let target' ← translateExpr srcDecl.type
Meta.mkFreshExprSyntheticOpaqueMVar target'
addTranslatedMVar srcMVarId mvar.mvarId!
return mvar.mvarId!
def translateMVarFromTermInfo (termInfo : Elab.TermInfo) (context? : Option Elab.ContextInfo)
: MetaM MVarId := do
let trM : MetaTranslateM MVarId := do
let type := termInfo.expectedType?.get!
let lctx' ← translateLCtx
let mvar ← Meta.withLCtx lctx' #[] do
let type' ← translateExpr type
Meta.mkFreshExprSyntheticOpaqueMVar type'
return mvar.mvarId!
trM.run {
sourceMCtx := context?.map (·.mctx) |>.getD {},
sourceLCtx := termInfo.lctx } |>.run' {}
def translateMVarFromTacticInfoBefore (tacticInfo : Elab.TacticInfo) (_context? : Option Elab.ContextInfo)
: MetaM (List MVarId) := do
let trM : MetaTranslateM (List MVarId) := do
tacticInfo.goalsBefore.mapM translateMVarId
trM.run {
sourceMCtx := tacticInfo.mctxBefore
} |>.run' {}
end MetaTranslate
export MetaTranslate (MetaTranslateM)
/-- /--
Since we cannot directly merge `MetavarContext`s, we have to get creative. This Since we cannot directly merge `MetavarContext`s, we have to get creative. This
@ -299,7 +189,7 @@ the current `MetavarContext`.
@[export pantograph_frontend_sorrys_to_goal_state] @[export pantograph_frontend_sorrys_to_goal_state]
def sorrysToGoalState (sorrys : List InfoWithContext) : MetaM GoalState := do def sorrysToGoalState (sorrys : List InfoWithContext) : MetaM GoalState := do
assert! !sorrys.isEmpty assert! !sorrys.isEmpty
let goals sorrys.mapM λ i => do let goalsM := sorrys.mapM λ i => do
match i.info with match i.info with
| .ofTermInfo termInfo => do | .ofTermInfo termInfo => do
let mvarId ← MetaTranslate.translateMVarFromTermInfo termInfo i.context? let mvarId ← MetaTranslate.translateMVarFromTermInfo termInfo i.context?
@ -307,7 +197,7 @@ def sorrysToGoalState (sorrys : List InfoWithContext) : MetaM GoalState := do
| .ofTacticInfo tacticInfo => do | .ofTacticInfo tacticInfo => do
MetaTranslate.translateMVarFromTacticInfoBefore tacticInfo i.context? MetaTranslate.translateMVarFromTacticInfoBefore tacticInfo i.context?
| _ => panic! "Invalid info" | _ => panic! "Invalid info"
let goals := goals.bind id let goals := (← goalsM.run {} |>.run' {}).bind id
let root := match goals with let root := match goals with
| [] => panic! "This function cannot be called on an empty list" | [] => panic! "This function cannot be called on an empty list"
| [g] => g | [g] => g

View File

@ -0,0 +1,129 @@
import Lean.Meta
open Lean
namespace Pantograph.Frontend
namespace MetaTranslate
structure Context where
sourceMCtx : MetavarContext := {}
sourceLCtx : LocalContext := {}
abbrev FVarMap := HashMap FVarId FVarId
structure State where
-- Stores mapping from old to new mvar/fvars
mvarMap: HashMap MVarId MVarId := {}
fvarMap: HashMap FVarId FVarId := {}
/-
Monadic state for translating a frozen meta state. The underlying `MetaM`
operates in the "target" context and state.
-/
abbrev MetaTranslateM := ReaderT Context StateRefT State MetaM
def getSourceLCtx : MetaTranslateM LocalContext := do pure (← read).sourceLCtx
def getSourceMCtx : MetaTranslateM MetavarContext := do pure (← read).sourceMCtx
def addTranslatedFVar (src dst: FVarId) : MetaTranslateM Unit := do
modifyGet λ state => ((), { state with fvarMap := state.fvarMap.insert src dst })
def addTranslatedMVar (src dst: MVarId) : MetaTranslateM Unit := do
modifyGet λ state => ((), { state with mvarMap := state.mvarMap.insert src dst })
def saveFVarMap : MetaTranslateM FVarMap := do
return (← get).fvarMap
def restoreFVarMap (map: FVarMap) : MetaTranslateM Unit := do
modifyGet λ state => ((), { state with fvarMap := map })
def resetFVarMap : MetaTranslateM Unit := do
modifyGet λ state => ((), { state with fvarMap := {} })
mutual
private partial def translateExpr (srcExpr: Expr) : MetaTranslateM Expr := do
let sourceMCtx ← getSourceMCtx
let (srcExpr, _) := instantiateMVarsCore (mctx := sourceMCtx) srcExpr
--IO.println s!"Transform src: {srcExpr}"
let result ← Core.transform srcExpr λ e => do
let state ← get
match e with
| .fvar fvarId =>
let .some fvarId' := state.fvarMap.find? fvarId | panic! s!"FVar id not registered: {fvarId.name}"
assert! (← getLCtx).contains fvarId'
return .done $ .fvar fvarId'
| .mvar mvarId => do
assert! !(sourceMCtx.dAssignment.contains mvarId)
assert! !(sourceMCtx.eAssignment.contains mvarId)
match state.mvarMap.find? mvarId with
| .some mvarId' => do
return .done $ .mvar mvarId'
| .none => do
-- Entering another LCtx, must save the current one
let fvarMap ← saveFVarMap
let mvarId' ← translateMVarId mvarId
restoreFVarMap fvarMap
return .done $ .mvar mvarId'
| _ => return .continue
Meta.check result
return result
partial def translateLocalInstance (srcInstance: LocalInstance) : MetaTranslateM LocalInstance := do
return {
className := srcInstance.className,
fvar := ← translateExpr srcInstance.fvar
}
partial def translateLocalDecl (srcLocalDecl: LocalDecl) : MetaTranslateM LocalDecl := do
let fvarId ← mkFreshFVarId
addTranslatedFVar srcLocalDecl.fvarId fvarId
match srcLocalDecl with
| .cdecl index _ userName type bi kind => do
--IO.println s!"[CD] {userName} {toString type}"
return .cdecl index fvarId userName (← translateExpr type) bi kind
| .ldecl index _ userName type value nonDep kind => do
--IO.println s!"[LD] {toString type} := {toString value}"
return .ldecl index fvarId userName (← translateExpr type) (← translateExpr value) nonDep kind
partial def translateLCtx : MetaTranslateM LocalContext := do
resetFVarMap
(← getSourceLCtx).foldlM (λ lctx srcLocalDecl => do
let localDecl ← Meta.withLCtx lctx #[] do translateLocalDecl srcLocalDecl
pure $ lctx.addDecl localDecl
) (← MonadLCtx.getLCtx)
partial def translateMVarId (srcMVarId: MVarId) : MetaTranslateM MVarId := do
if let .some mvarId' := (← get).mvarMap.find? srcMVarId then
return mvarId'
let srcDecl := (← getSourceMCtx).findDecl? srcMVarId |>.get!
let mvar ← withTheReader Context (λ ctx => { ctx with sourceLCtx := srcDecl.lctx }) do
let lctx' ← translateLCtx
let localInstances' ← srcDecl.localInstances.mapM translateLocalInstance
Meta.withLCtx lctx' localInstances' do
let target' ← translateExpr srcDecl.type
Meta.mkFreshExprMVar target' srcDecl.kind srcDecl.userName
addTranslatedMVar srcMVarId mvar.mvarId!
return mvar.mvarId!
end
def translateMVarFromTermInfo (termInfo : Elab.TermInfo) (context? : Option Elab.ContextInfo)
: MetaTranslateM MVarId := do
withTheReader Context (λ ctx => { ctx with
sourceMCtx := context?.map (·.mctx) |>.getD {},
sourceLCtx := termInfo.lctx,
}) do
let type := termInfo.expectedType?.get!
let lctx' ← translateLCtx
let mvar ← Meta.withLCtx lctx' #[] do
let type' ← translateExpr type
Meta.mkFreshExprSyntheticOpaqueMVar type'
return mvar.mvarId!
def translateMVarFromTacticInfoBefore (tacticInfo : Elab.TacticInfo) (_context? : Option Elab.ContextInfo)
: MetaTranslateM (List MVarId) := do
withTheReader Context (λ ctx => { ctx with sourceMCtx := tacticInfo.mctxBefore }) do
tacticInfo.goalsBefore.mapM translateMVarId
end MetaTranslate
export MetaTranslate (MetaTranslateM)
end Pantograph.Frontend

View File

@ -26,8 +26,8 @@ theorem plus_n_Sm_proved_formal_sketch : ∀ n m : Nat, n + (m + 1) = (n + m) +
sorry sorry
" "
let goalStates ← (collectSorrysFromSource sketch).run' {} let goalStates ← (collectSorrysFromSource sketch).run' {}
let [goalState] := goalStates | panic! "Illegal number of states" let [goalState] := goalStates | panic! "Incorrect number of states"
addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
{ {
target := { pp? := "∀ (n m : Nat), n = m" }, target := { pp? := "∀ (n m : Nat), n = m" },
vars := #[ vars := #[
@ -49,8 +49,8 @@ example : ∀ (n m: Nat), n + m = m + n := by
sorry sorry
" "
let goalStates ← (collectSorrysFromSource sketch).run' {} let goalStates ← (collectSorrysFromSource sketch).run' {}
let [goalState] := goalStates | panic! s!"Illegal number of states: {goalStates.length}" let [goalState] := goalStates | panic! s!"Incorrect number of states: {goalStates.length}"
addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
{ {
target := { pp? := "n + m = m + n" }, target := { pp? := "n + m = m + n" },
vars := #[{ vars := #[{
@ -77,8 +77,8 @@ example : ∀ (n m: Nat), n + m = m + n := by
sorry sorry
" "
let goalStates ← (collectSorrysFromSource sketch).run' {} let goalStates ← (collectSorrysFromSource sketch).run' {}
let [goalState] := goalStates | panic! s!"Illegal number of states: {goalStates.length}" let [goalState] := goalStates | panic! s!"Incorrect number of states: {goalStates.length}"
addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
{ {
target := { pp? := "0 + m = m" }, target := { pp? := "0 + m = m" },
vars := #[{ vars := #[{
@ -87,6 +87,7 @@ example : ∀ (n m: Nat), n + m = m + n := by
}] }]
}, },
{ {
userName? := .some "zero",
target := { pp? := "0 + m = m + 0" }, target := { pp? := "0 + m = m + 0" },
vars := #[{ vars := #[{
userName := "m", userName := "m",
@ -110,6 +111,7 @@ example : ∀ (n m: Nat), n + m = m + n := by
}] }]
}, },
{ {
userName? := .some "succ",
target := { pp? := "n + 1 + m = m + (n + 1)" }, target := { pp? := "n + 1 + m = m + (n + 1)" },
vars := #[{ vars := #[{
userName := "m", userName := "m",
@ -127,12 +129,47 @@ example : ∀ (n m: Nat), n + m = m + n := by
} }
]) ])
def test_sorry_in_coupled: TestT MetaM Unit := do
let sketch := "
example : ∀ (y: Nat), ∃ (x: Nat), y + 1 = x := by
intro y
apply Exists.intro
case h => sorry
case w => sorry
"
let goalStates ← (collectSorrysFromSource sketch).run' {}
let [goalState] := goalStates | panic! s!"Incorrect number of states: {goalStates.length}"
addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
{
target := { pp? := "y + 1 = ?w" },
vars := #[{
userName := "y",
type? := .some { pp? := "Nat" },
}
],
},
{
userName? := .some "w",
target := { pp? := "Nat" },
vars := #[{
userName := "y✝",
isInaccessible := true,
type? := .some { pp? := "Nat" },
}, {
userName := "y",
type? := .some { pp? := "Nat" },
}
],
}
])
def suite (env : Environment): List (String × IO LSpec.TestSeq) := def suite (env : Environment): List (String × IO LSpec.TestSeq) :=
let tests := [ let tests := [
("multiple_sorrys_in_proof", test_multiple_sorrys_in_proof), ("multiple_sorrys_in_proof", test_multiple_sorrys_in_proof),
("sorry_in_middle", test_sorry_in_middle), ("sorry_in_middle", test_sorry_in_middle),
("sorry_in_induction", test_sorry_in_induction), ("sorry_in_induction", test_sorry_in_induction),
("sorry_in_coupled", test_sorry_in_coupled),
] ]
tests.map (fun (name, test) => (name, runMetaMSeq env $ runTest test)) tests.map (fun (name, test) => (name, runMetaMSeq env $ runTest test))