diff --git a/Pantograph/Frontend/Elab.lean b/Pantograph/Frontend/Elab.lean index ec86df3..2e0c14e 100644 --- a/Pantograph/Frontend/Elab.lean +++ b/Pantograph/Frontend/Elab.lean @@ -124,12 +124,12 @@ protected def ppExpr (t : TacticInvocation) (e : Expr) : IO Format := end TacticInvocation /-- Analogue of `Lean.Elab.InfoTree.findInfo?`, but that returns a list of all results. -/ -partial def findAllInfo (t : Elab.InfoTree) (ctx : Option Elab.ContextInfo) (pred : Elab.Info → Bool) : +partial def findAllInfo (t : Elab.InfoTree) (context?: Option Elab.ContextInfo) (pred : Elab.Info → Bool) : List (Elab.Info × Option Elab.ContextInfo × PersistentArray Elab.InfoTree) := match t with - | .context inner t => findAllInfo t (inner.mergeIntoOuter? ctx) pred + | .context inner t => findAllInfo t (inner.mergeIntoOuter? context?) pred | .node i children => - (if pred i then [(i, ctx, children)] else []) ++ children.toList.bind (fun t => findAllInfo t ctx pred) + (if pred i then [(i, context?, children)] else []) ++ children.toList.bind (fun t => findAllInfo t context? pred) | _ => [] /-- Return all `TacticInfo` nodes in an `InfoTree` corresponding to tactics, @@ -159,7 +159,11 @@ def collectTacticsFromCompilationStep (step : CompilationStep) : IO (List Protoc return t.pretty return { goalBefore, goalAfter, tactic } -private def collectSorrysInTree (t : Elab.InfoTree) : List Elab.Info := +structure InfoWithContext where + info: Elab.Info + context?: Option Elab.ContextInfo := .none + +private def collectSorrysInTree (t : Elab.InfoTree) : List InfoWithContext := let infos := findAllInfo t none fun i => match i with | .ofTermInfo { expectedType?, expr, stx, .. } => expr.isSorry ∧ expectedType?.isSome ∧ stx.isOfKind `Lean.Parser.Term.sorry @@ -167,11 +171,11 @@ private def collectSorrysInTree (t : Elab.InfoTree) : List Elab.Info := -- The `sorry` term is distinct from the `sorry` tactic stx.isOfKind `Lean.Parser.Tactic.tacticSorry | _ => false - infos.map fun (i, _, _) => i + infos.map fun (info, context?, _) => { info, context? } -- NOTE: Plural deliberately not spelled "sorries" @[export pantograph_frontend_collect_sorrys_m] -def collectSorrys (step: CompilationStep) : List Elab.Info := +def collectSorrys (step: CompilationStep) : List InfoWithContext := step.trees.bind collectSorrysInTree @@ -181,55 +185,106 @@ structure Context where sourceMCtx : MetavarContext := {} sourceLCtx : LocalContext := {} +structure State where + -- Stores mapping from old to new mvar/fvars + mvarMap: HashMap MVarId MVarId := {} + fvarMap: HashMap FVarId FVarId := {} + /- Monadic state for translating a frozen meta state. The underlying `MetaM` operates in the "target" context and state. -/ -abbrev MetaTranslateM := ReaderT Context MetaM +abbrev MetaTranslateM := ReaderT Context StateRefT State MetaM def getSourceLCtx : MetaTranslateM LocalContext := do pure (← read).sourceLCtx def getSourceMCtx : MetaTranslateM MetavarContext := do pure (← read).sourceMCtx +def addTranslatedFVar (src dst: FVarId) : MetaTranslateM Unit := do + let state ← get + set { state with fvarMap := state.fvarMap.insert src dst } +def addTranslatedMVar (src dst: MVarId) : MetaTranslateM Unit := do + let state ← get + set { state with mvarMap := state.mvarMap.insert src dst } -private def translateExpr (expr: Expr) : MetaTranslateM Expr := do - let (expr, _) := instantiateMVarsCore (mctx := ← getSourceMCtx) expr - return expr +def resetFVarMap : MetaTranslateM Unit := do + let state ← get + set { state with fvarMap := {} } -def translateLocalDecl (frozenLocalDecl: LocalDecl) : MetaTranslateM LocalDecl := do +private partial def translateExpr (srcExpr: Expr) : MetaTranslateM Expr := do + let (srcExpr, _) := instantiateMVarsCore (mctx := ← getSourceMCtx) srcExpr + --IO.println s!"Transform src: {srcExpr}" + let result ← Core.transform srcExpr λ e => do + let state ← get + match e with + | .fvar fvarId => + let .some fvarId' := state.fvarMap.find? fvarId | panic! s!"FVar id not registered: {fvarId.name}" + return .done $ .fvar fvarId' + | .mvar mvarId => do + match state.mvarMap.find? mvarId with + | .some mvarId' => do + return .done $ .mvar mvarId' + | .none => do + --let t := (← getSourceMCtx).findDecl? mvarId |>.get!.type + --let t' ← translateExpr t + let mvar' ← Meta.mkFreshExprMVar .none + addTranslatedMVar mvarId mvar'.mvarId! + return .done mvar' + | _ => return .continue + try + Meta.check result + catch ex => + panic! s!"Check failed: {← ex.toMessageData.toString}" + return result + +def translateLocalDecl (srcLocalDecl: LocalDecl) : MetaTranslateM LocalDecl := do let fvarId ← mkFreshFVarId - match frozenLocalDecl with - | .cdecl index _ userName type bi kind => - return .cdecl index fvarId userName type bi kind - | .ldecl index _ userName type value nonDep kind => - return .ldecl index fvarId userName type value nonDep kind + addTranslatedFVar srcLocalDecl.fvarId fvarId + match srcLocalDecl with + | .cdecl index _ userName type bi kind => do + --IO.println s!"[CD] {userName} {toString type}" + return .cdecl index fvarId userName (← translateExpr type) bi kind + | .ldecl index _ userName type value nonDep kind => do + --IO.println s!"[LD] {toString type} := {toString value}" + return .ldecl index fvarId userName (← translateExpr type) (← translateExpr value) nonDep kind -def translateMVarId (mvarId: MVarId) : MetaTranslateM MVarId := do - let shadowDecl := (← getSourceMCtx).findDecl? mvarId |>.get! - let target ← translateExpr shadowDecl.type - let mvar ← withTheReader Context (λ ctx => { ctx with sourceLCtx := shadowDecl.lctx }) do - let lctx ← MonadLCtx.getLCtx - let lctx ← (← getSourceLCtx).foldlM (λ lctx frozenLocalDecl => do - let localDecl ← translateLocalDecl frozenLocalDecl - let lctx := lctx.addDecl localDecl - pure lctx - ) lctx - withTheReader Meta.Context (fun ctx => { ctx with lctx }) do - Meta.mkFreshExprSyntheticOpaqueMVar target +def translateLCtx : MetaTranslateM LocalContext := do + resetFVarMap + (← getSourceLCtx).foldlM (λ lctx srcLocalDecl => do + let localDecl ← Meta.withLCtx lctx #[] do translateLocalDecl srcLocalDecl + pure $ lctx.addDecl localDecl + ) (← MonadLCtx.getLCtx) + + +def translateMVarId (srcMVarId: MVarId) : MetaTranslateM MVarId := do + let srcDecl := (← getSourceMCtx).findDecl? srcMVarId |>.get! + let mvar ← withTheReader Context (λ ctx => { ctx with sourceLCtx := srcDecl.lctx }) do + let lctx' ← translateLCtx + Meta.withLCtx lctx' #[] do + let target' ← translateExpr srcDecl.type + Meta.mkFreshExprSyntheticOpaqueMVar target' + addTranslatedMVar srcMVarId mvar.mvarId! return mvar.mvarId! -def translateTermInfo (termInfo: Elab.TermInfo) : MetaM MVarId := do +def translateMVarFromTermInfo (termInfo : Elab.TermInfo) (context? : Option Elab.ContextInfo) + : MetaM MVarId := do let trM : MetaTranslateM MVarId := do let type := termInfo.expectedType?.get! - let lctx ← getSourceLCtx - let mvar ← withTheReader Meta.Context (fun ctx => { ctx with lctx }) do - Meta.mkFreshExprSyntheticOpaqueMVar type + let lctx' ← translateLCtx + let mvar ← Meta.withLCtx lctx' #[] do + let type' ← translateExpr type + Meta.mkFreshExprSyntheticOpaqueMVar type' return mvar.mvarId! - trM.run { sourceLCtx := termInfo.lctx } + trM.run { + sourceMCtx := context?.map (·.mctx) |>.getD {}, + sourceLCtx := termInfo.lctx } |>.run' {} -def translateTacticInfoBefore (tacticInfo: Elab.TacticInfo) : MetaM (List MVarId) := do +def translateMVarFromTacticInfoBefore (tacticInfo : Elab.TacticInfo) (_context? : Option Elab.ContextInfo) + : MetaM (List MVarId) := do let trM : MetaTranslateM (List MVarId) := do tacticInfo.goalsBefore.mapM translateMVarId - trM.run { sourceMCtx := tacticInfo.mctxBefore } + trM.run { + sourceMCtx := tacticInfo.mctxBefore + } |>.run' {} end MetaTranslate @@ -242,15 +297,15 @@ function duplicates frozen mvars in term and tactic info nodes, and add them to the current `MetavarContext`. -/ @[export pantograph_frontend_sorrys_to_goal_state] -def sorrysToGoalState (sorrys : List Elab.Info) : MetaM GoalState := do +def sorrysToGoalState (sorrys : List InfoWithContext) : MetaM GoalState := do assert! !sorrys.isEmpty - let goals ← sorrys.mapM λ info => Meta.withLCtx info.lctx #[] do - match info with + let goals ← sorrys.mapM λ i => do + match i.info with | .ofTermInfo termInfo => do - let mvarId ← MetaTranslate.translateTermInfo termInfo + let mvarId ← MetaTranslate.translateMVarFromTermInfo termInfo i.context? return [mvarId] | .ofTacticInfo tacticInfo => do - MetaTranslate.translateTacticInfoBefore tacticInfo + MetaTranslate.translateMVarFromTacticInfoBefore tacticInfo i.context? | _ => panic! "Invalid info" let goals := goals.bind id let root := match goals with diff --git a/Test/Frontend.lean b/Test/Frontend.lean index ac347e6..c186503 100644 --- a/Test/Frontend.lean +++ b/Test/Frontend.lean @@ -19,7 +19,7 @@ def collectSorrysFromSource (source: String) : MetaM (List GoalState) := do return .some goalState return goalStates -def test_multiple_sorries_in_proof : TestT MetaM Unit := do +def test_multiple_sorrys_in_proof : TestT MetaM Unit := do let sketch := " theorem plus_n_Sm_proved_formal_sketch : ∀ n m : Nat, n + (m + 1) = (n + m) + 1 := by have h_nat_add_succ: ∀ n m : Nat, n = m := sorry @@ -27,28 +27,112 @@ theorem plus_n_Sm_proved_formal_sketch : ∀ n m : Nat, n + (m + 1) = (n + m) + " let goalStates ← (collectSorrysFromSource sketch).run' {} let [goalState] := goalStates | panic! "Illegal number of states" - addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})) = #[ + addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ { - name := "_uniq.1", target := { pp? := "∀ (n m : Nat), n = m" }, vars := #[ ] }, { - name := "_uniq.4", target := { pp? := "∀ (n m : Nat), n + (m + 1) = n + m + 1" }, vars := #[{ - name := "_uniq.3", userName := "h_nat_add_succ", type? := .some { pp? := "∀ (n m : Nat), n = m" }, }], } ]) +def test_sorry_in_middle: TestT MetaM Unit := do + let sketch := " +example : ∀ (n m: Nat), n + m = m + n := by + intros n m + sorry + " + let goalStates ← (collectSorrysFromSource sketch).run' {} + let [goalState] := goalStates | panic! s!"Illegal number of states: {goalStates.length}" + addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ + { + target := { pp? := "n + m = m + n" }, + vars := #[{ + userName := "n", + type? := .some { pp? := "Nat" }, + }, { + userName := "m", + type? := .some { pp? := "Nat" }, + } + ], + } + ]) + +def test_sorry_in_induction : TestT MetaM Unit := do + let sketch := " +example : ∀ (n m: Nat), n + m = m + n := by + intros n m + induction n with + | zero => + have h1 : 0 + m = m := sorry + sorry + | succ n ih => + have h2 : n + m = m := sorry + sorry + " + let goalStates ← (collectSorrysFromSource sketch).run' {} + let [goalState] := goalStates | panic! s!"Illegal number of states: {goalStates.length}" + addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ + { + target := { pp? := "0 + m = m" }, + vars := #[{ + userName := "m", + type? := .some { pp? := "Nat" }, + }] + }, + { + target := { pp? := "0 + m = m + 0" }, + vars := #[{ + userName := "m", + type? := .some { pp? := "Nat" }, + }, { + userName := "h1", + type? := .some { pp? := "0 + m = m" }, + }] + }, + { + target := { pp? := "n + m = m" }, + vars := #[{ + userName := "m", + type? := .some { pp? := "Nat" }, + }, { + userName := "n", + type? := .some { pp? := "Nat" }, + }, { + userName := "ih", + type? := .some { pp? := "n + m = m + n" }, + }] + }, + { + target := { pp? := "n + 1 + m = m + (n + 1)" }, + vars := #[{ + userName := "m", + type? := .some { pp? := "Nat" }, + }, { + userName := "n", + type? := .some { pp? := "Nat" }, + }, { + userName := "ih", + type? := .some { pp? := "n + m = m + n" }, + }, { + userName := "h2", + type? := .some { pp? := "n + m = m" }, + }] + } + ]) + def suite (env : Environment): List (String × IO LSpec.TestSeq) := let tests := [ - ("multiple_sorrys_in_proof", test_multiple_sorries_in_proof), + ("multiple_sorrys_in_proof", test_multiple_sorrys_in_proof), + ("sorry_in_middle", test_sorry_in_middle), + ("sorry_in_induction", test_sorry_in_induction), ] tests.map (fun (name, test) => (name, runMetaMSeq env $ runTest test)) diff --git a/Test/Integration.lean b/Test/Integration.lean index 4a8e418..b3d49fe 100644 --- a/Test/Integration.lean +++ b/Test/Integration.lean @@ -201,9 +201,9 @@ def test_frontend_process_sorry : Test := [ let file := s!"{solved}{withSorry}" let goal1: Protocol.Goal := { - name := "_uniq.1", + name := "_uniq.6", target := { pp? := .some "p → p" }, - vars := #[{ name := "_uniq.168", userName := "p", type? := .some { pp? := .some "Prop" }}], + vars := #[{ name := "_uniq.4", userName := "p", type? := .some { pp? := .some "Prop" }}], } step "frontend.process" [