fix: Extraction of sorry's from nested tactics

This commit is contained in:
Leni Aniva 2024-10-03 01:29:46 -07:00
parent 18cd1d0388
commit 143cd289bb
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
3 changed files with 187 additions and 48 deletions

View File

@ -124,12 +124,12 @@ protected def ppExpr (t : TacticInvocation) (e : Expr) : IO Format :=
end TacticInvocation end TacticInvocation
/-- Analogue of `Lean.Elab.InfoTree.findInfo?`, but that returns a list of all results. -/ /-- Analogue of `Lean.Elab.InfoTree.findInfo?`, but that returns a list of all results. -/
partial def findAllInfo (t : Elab.InfoTree) (ctx : Option Elab.ContextInfo) (pred : Elab.Info → Bool) : partial def findAllInfo (t : Elab.InfoTree) (context?: Option Elab.ContextInfo) (pred : Elab.Info → Bool) :
List (Elab.Info × Option Elab.ContextInfo × PersistentArray Elab.InfoTree) := List (Elab.Info × Option Elab.ContextInfo × PersistentArray Elab.InfoTree) :=
match t with match t with
| .context inner t => findAllInfo t (inner.mergeIntoOuter? ctx) pred | .context inner t => findAllInfo t (inner.mergeIntoOuter? context?) pred
| .node i children => | .node i children =>
(if pred i then [(i, ctx, children)] else []) ++ children.toList.bind (fun t => findAllInfo t ctx pred) (if pred i then [(i, context?, children)] else []) ++ children.toList.bind (fun t => findAllInfo t context? pred)
| _ => [] | _ => []
/-- Return all `TacticInfo` nodes in an `InfoTree` corresponding to tactics, /-- Return all `TacticInfo` nodes in an `InfoTree` corresponding to tactics,
@ -159,7 +159,11 @@ def collectTacticsFromCompilationStep (step : CompilationStep) : IO (List Protoc
return t.pretty return t.pretty
return { goalBefore, goalAfter, tactic } return { goalBefore, goalAfter, tactic }
private def collectSorrysInTree (t : Elab.InfoTree) : List Elab.Info := structure InfoWithContext where
info: Elab.Info
context?: Option Elab.ContextInfo := .none
private def collectSorrysInTree (t : Elab.InfoTree) : List InfoWithContext :=
let infos := findAllInfo t none fun i => match i with let infos := findAllInfo t none fun i => match i with
| .ofTermInfo { expectedType?, expr, stx, .. } => | .ofTermInfo { expectedType?, expr, stx, .. } =>
expr.isSorry ∧ expectedType?.isSome ∧ stx.isOfKind `Lean.Parser.Term.sorry expr.isSorry ∧ expectedType?.isSome ∧ stx.isOfKind `Lean.Parser.Term.sorry
@ -167,11 +171,11 @@ private def collectSorrysInTree (t : Elab.InfoTree) : List Elab.Info :=
-- The `sorry` term is distinct from the `sorry` tactic -- The `sorry` term is distinct from the `sorry` tactic
stx.isOfKind `Lean.Parser.Tactic.tacticSorry stx.isOfKind `Lean.Parser.Tactic.tacticSorry
| _ => false | _ => false
infos.map fun (i, _, _) => i infos.map fun (info, context?, _) => { info, context? }
-- NOTE: Plural deliberately not spelled "sorries" -- NOTE: Plural deliberately not spelled "sorries"
@[export pantograph_frontend_collect_sorrys_m] @[export pantograph_frontend_collect_sorrys_m]
def collectSorrys (step: CompilationStep) : List Elab.Info := def collectSorrys (step: CompilationStep) : List InfoWithContext :=
step.trees.bind collectSorrysInTree step.trees.bind collectSorrysInTree
@ -181,55 +185,106 @@ structure Context where
sourceMCtx : MetavarContext := {} sourceMCtx : MetavarContext := {}
sourceLCtx : LocalContext := {} sourceLCtx : LocalContext := {}
structure State where
-- Stores mapping from old to new mvar/fvars
mvarMap: HashMap MVarId MVarId := {}
fvarMap: HashMap FVarId FVarId := {}
/- /-
Monadic state for translating a frozen meta state. The underlying `MetaM` Monadic state for translating a frozen meta state. The underlying `MetaM`
operates in the "target" context and state. operates in the "target" context and state.
-/ -/
abbrev MetaTranslateM := ReaderT Context MetaM abbrev MetaTranslateM := ReaderT Context StateRefT State MetaM
def getSourceLCtx : MetaTranslateM LocalContext := do pure (← read).sourceLCtx def getSourceLCtx : MetaTranslateM LocalContext := do pure (← read).sourceLCtx
def getSourceMCtx : MetaTranslateM MetavarContext := do pure (← read).sourceMCtx def getSourceMCtx : MetaTranslateM MetavarContext := do pure (← read).sourceMCtx
def addTranslatedFVar (src dst: FVarId) : MetaTranslateM Unit := do
let state ← get
set { state with fvarMap := state.fvarMap.insert src dst }
def addTranslatedMVar (src dst: MVarId) : MetaTranslateM Unit := do
let state ← get
set { state with mvarMap := state.mvarMap.insert src dst }
private def translateExpr (expr: Expr) : MetaTranslateM Expr := do def resetFVarMap : MetaTranslateM Unit := do
let (expr, _) := instantiateMVarsCore (mctx := ← getSourceMCtx) expr let state ← get
return expr set { state with fvarMap := {} }
def translateLocalDecl (frozenLocalDecl: LocalDecl) : MetaTranslateM LocalDecl := do private partial def translateExpr (srcExpr: Expr) : MetaTranslateM Expr := do
let (srcExpr, _) := instantiateMVarsCore (mctx := ← getSourceMCtx) srcExpr
--IO.println s!"Transform src: {srcExpr}"
let result ← Core.transform srcExpr λ e => do
let state ← get
match e with
| .fvar fvarId =>
let .some fvarId' := state.fvarMap.find? fvarId | panic! s!"FVar id not registered: {fvarId.name}"
return .done $ .fvar fvarId'
| .mvar mvarId => do
match state.mvarMap.find? mvarId with
| .some mvarId' => do
return .done $ .mvar mvarId'
| .none => do
--let t := (← getSourceMCtx).findDecl? mvarId |>.get!.type
--let t' ← translateExpr t
let mvar' ← Meta.mkFreshExprMVar .none
addTranslatedMVar mvarId mvar'.mvarId!
return .done mvar'
| _ => return .continue
try
Meta.check result
catch ex =>
panic! s!"Check failed: {← ex.toMessageData.toString}"
return result
def translateLocalDecl (srcLocalDecl: LocalDecl) : MetaTranslateM LocalDecl := do
let fvarId ← mkFreshFVarId let fvarId ← mkFreshFVarId
match frozenLocalDecl with addTranslatedFVar srcLocalDecl.fvarId fvarId
| .cdecl index _ userName type bi kind => match srcLocalDecl with
return .cdecl index fvarId userName type bi kind | .cdecl index _ userName type bi kind => do
| .ldecl index _ userName type value nonDep kind => --IO.println s!"[CD] {userName} {toString type}"
return .ldecl index fvarId userName type value nonDep kind return .cdecl index fvarId userName (← translateExpr type) bi kind
| .ldecl index _ userName type value nonDep kind => do
--IO.println s!"[LD] {toString type} := {toString value}"
return .ldecl index fvarId userName (← translateExpr type) (← translateExpr value) nonDep kind
def translateMVarId (mvarId: MVarId) : MetaTranslateM MVarId := do def translateLCtx : MetaTranslateM LocalContext := do
let shadowDecl := (← getSourceMCtx).findDecl? mvarId |>.get! resetFVarMap
let target ← translateExpr shadowDecl.type (← getSourceLCtx).foldlM (λ lctx srcLocalDecl => do
let mvar ← withTheReader Context (λ ctx => { ctx with sourceLCtx := shadowDecl.lctx }) do let localDecl ← Meta.withLCtx lctx #[] do translateLocalDecl srcLocalDecl
let lctx ← MonadLCtx.getLCtx pure $ lctx.addDecl localDecl
let lctx ← (← getSourceLCtx).foldlM (λ lctx frozenLocalDecl => do ) (← MonadLCtx.getLCtx)
let localDecl ← translateLocalDecl frozenLocalDecl
let lctx := lctx.addDecl localDecl
pure lctx def translateMVarId (srcMVarId: MVarId) : MetaTranslateM MVarId := do
) lctx let srcDecl := (← getSourceMCtx).findDecl? srcMVarId |>.get!
withTheReader Meta.Context (fun ctx => { ctx with lctx }) do let mvar ← withTheReader Context (λ ctx => { ctx with sourceLCtx := srcDecl.lctx }) do
Meta.mkFreshExprSyntheticOpaqueMVar target let lctx' ← translateLCtx
Meta.withLCtx lctx' #[] do
let target' ← translateExpr srcDecl.type
Meta.mkFreshExprSyntheticOpaqueMVar target'
addTranslatedMVar srcMVarId mvar.mvarId!
return mvar.mvarId! return mvar.mvarId!
def translateTermInfo (termInfo: Elab.TermInfo) : MetaM MVarId := do def translateMVarFromTermInfo (termInfo : Elab.TermInfo) (context? : Option Elab.ContextInfo)
: MetaM MVarId := do
let trM : MetaTranslateM MVarId := do let trM : MetaTranslateM MVarId := do
let type := termInfo.expectedType?.get! let type := termInfo.expectedType?.get!
let lctx ← getSourceLCtx let lctx' ← translateLCtx
let mvar ← withTheReader Meta.Context (fun ctx => { ctx with lctx }) do let mvar ← Meta.withLCtx lctx' #[] do
Meta.mkFreshExprSyntheticOpaqueMVar type let type' ← translateExpr type
Meta.mkFreshExprSyntheticOpaqueMVar type'
return mvar.mvarId! return mvar.mvarId!
trM.run { sourceLCtx := termInfo.lctx } trM.run {
sourceMCtx := context?.map (·.mctx) |>.getD {},
sourceLCtx := termInfo.lctx } |>.run' {}
def translateTacticInfoBefore (tacticInfo: Elab.TacticInfo) : MetaM (List MVarId) := do def translateMVarFromTacticInfoBefore (tacticInfo : Elab.TacticInfo) (_context? : Option Elab.ContextInfo)
: MetaM (List MVarId) := do
let trM : MetaTranslateM (List MVarId) := do let trM : MetaTranslateM (List MVarId) := do
tacticInfo.goalsBefore.mapM translateMVarId tacticInfo.goalsBefore.mapM translateMVarId
trM.run { sourceMCtx := tacticInfo.mctxBefore } trM.run {
sourceMCtx := tacticInfo.mctxBefore
} |>.run' {}
end MetaTranslate end MetaTranslate
@ -242,15 +297,15 @@ function duplicates frozen mvars in term and tactic info nodes, and add them to
the current `MetavarContext`. the current `MetavarContext`.
-/ -/
@[export pantograph_frontend_sorrys_to_goal_state] @[export pantograph_frontend_sorrys_to_goal_state]
def sorrysToGoalState (sorrys : List Elab.Info) : MetaM GoalState := do def sorrysToGoalState (sorrys : List InfoWithContext) : MetaM GoalState := do
assert! !sorrys.isEmpty assert! !sorrys.isEmpty
let goals ← sorrys.mapM λ info => Meta.withLCtx info.lctx #[] do let goals ← sorrys.mapM λ i => do
match info with match i.info with
| .ofTermInfo termInfo => do | .ofTermInfo termInfo => do
let mvarId ← MetaTranslate.translateTermInfo termInfo let mvarId ← MetaTranslate.translateMVarFromTermInfo termInfo i.context?
return [mvarId] return [mvarId]
| .ofTacticInfo tacticInfo => do | .ofTacticInfo tacticInfo => do
MetaTranslate.translateTacticInfoBefore tacticInfo MetaTranslate.translateMVarFromTacticInfoBefore tacticInfo i.context?
| _ => panic! "Invalid info" | _ => panic! "Invalid info"
let goals := goals.bind id let goals := goals.bind id
let root := match goals with let root := match goals with

View File

@ -19,7 +19,7 @@ def collectSorrysFromSource (source: String) : MetaM (List GoalState) := do
return .some goalState return .some goalState
return goalStates return goalStates
def test_multiple_sorries_in_proof : TestT MetaM Unit := do def test_multiple_sorrys_in_proof : TestT MetaM Unit := do
let sketch := " let sketch := "
theorem plus_n_Sm_proved_formal_sketch : ∀ n m : Nat, n + (m + 1) = (n + m) + 1 := by theorem plus_n_Sm_proved_formal_sketch : ∀ n m : Nat, n + (m + 1) = (n + m) + 1 := by
have h_nat_add_succ: ∀ n m : Nat, n = m := sorry have h_nat_add_succ: ∀ n m : Nat, n = m := sorry
@ -27,28 +27,112 @@ theorem plus_n_Sm_proved_formal_sketch : ∀ n m : Nat, n + (m + 1) = (n + m) +
" "
let goalStates ← (collectSorrysFromSource sketch).run' {} let goalStates ← (collectSorrysFromSource sketch).run' {}
let [goalState] := goalStates | panic! "Illegal number of states" let [goalState] := goalStates | panic! "Illegal number of states"
addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})) = #[ addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
{ {
name := "_uniq.1",
target := { pp? := "∀ (n m : Nat), n = m" }, target := { pp? := "∀ (n m : Nat), n = m" },
vars := #[ vars := #[
] ]
}, },
{ {
name := "_uniq.4",
target := { pp? := "∀ (n m : Nat), n + (m + 1) = n + m + 1" }, target := { pp? := "∀ (n m : Nat), n + (m + 1) = n + m + 1" },
vars := #[{ vars := #[{
name := "_uniq.3",
userName := "h_nat_add_succ", userName := "h_nat_add_succ",
type? := .some { pp? := "∀ (n m : Nat), n = m" }, type? := .some { pp? := "∀ (n m : Nat), n = m" },
}], }],
} }
]) ])
def test_sorry_in_middle: TestT MetaM Unit := do
let sketch := "
example : ∀ (n m: Nat), n + m = m + n := by
intros n m
sorry
"
let goalStates ← (collectSorrysFromSource sketch).run' {}
let [goalState] := goalStates | panic! s!"Illegal number of states: {goalStates.length}"
addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
{
target := { pp? := "n + m = m + n" },
vars := #[{
userName := "n",
type? := .some { pp? := "Nat" },
}, {
userName := "m",
type? := .some { pp? := "Nat" },
}
],
}
])
def test_sorry_in_induction : TestT MetaM Unit := do
let sketch := "
example : ∀ (n m: Nat), n + m = m + n := by
intros n m
induction n with
| zero =>
have h1 : 0 + m = m := sorry
sorry
| succ n ih =>
have h2 : n + m = m := sorry
sorry
"
let goalStates ← (collectSorrysFromSource sketch).run' {}
let [goalState] := goalStates | panic! s!"Illegal number of states: {goalStates.length}"
addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[
{
target := { pp? := "0 + m = m" },
vars := #[{
userName := "m",
type? := .some { pp? := "Nat" },
}]
},
{
target := { pp? := "0 + m = m + 0" },
vars := #[{
userName := "m",
type? := .some { pp? := "Nat" },
}, {
userName := "h1",
type? := .some { pp? := "0 + m = m" },
}]
},
{
target := { pp? := "n + m = m" },
vars := #[{
userName := "m",
type? := .some { pp? := "Nat" },
}, {
userName := "n",
type? := .some { pp? := "Nat" },
}, {
userName := "ih",
type? := .some { pp? := "n + m = m + n" },
}]
},
{
target := { pp? := "n + 1 + m = m + (n + 1)" },
vars := #[{
userName := "m",
type? := .some { pp? := "Nat" },
}, {
userName := "n",
type? := .some { pp? := "Nat" },
}, {
userName := "ih",
type? := .some { pp? := "n + m = m + n" },
}, {
userName := "h2",
type? := .some { pp? := "n + m = m" },
}]
}
])
def suite (env : Environment): List (String × IO LSpec.TestSeq) := def suite (env : Environment): List (String × IO LSpec.TestSeq) :=
let tests := [ let tests := [
("multiple_sorrys_in_proof", test_multiple_sorries_in_proof), ("multiple_sorrys_in_proof", test_multiple_sorrys_in_proof),
("sorry_in_middle", test_sorry_in_middle),
("sorry_in_induction", test_sorry_in_induction),
] ]
tests.map (fun (name, test) => (name, runMetaMSeq env $ runTest test)) tests.map (fun (name, test) => (name, runMetaMSeq env $ runTest test))

View File

@ -201,9 +201,9 @@ def test_frontend_process_sorry : Test :=
[ [
let file := s!"{solved}{withSorry}" let file := s!"{solved}{withSorry}"
let goal1: Protocol.Goal := { let goal1: Protocol.Goal := {
name := "_uniq.1", name := "_uniq.6",
target := { pp? := .some "p → p" }, target := { pp? := .some "p → p" },
vars := #[{ name := "_uniq.168", userName := "p", type? := .some { pp? := .some "Prop" }}], vars := #[{ name := "_uniq.4", userName := "p", type? := .some { pp? := .some "Prop" }}],
} }
step "frontend.process" step "frontend.process"
[ [