From 530a1a1a97273314bd7b01c542ce686a366aa0b9 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Thu, 3 Oct 2024 11:35:54 -0700 Subject: [PATCH] fix: Extracting `sorry`s from coupled goals --- Pantograph/Expr.lean | 87 ++++++++--------- Pantograph/Frontend.lean | 2 +- Pantograph/Frontend/Elab.lean | 118 +--------------------- Pantograph/Frontend/MetaTranslate.lean | 129 +++++++++++++++++++++++++ Test/Frontend.lean | 49 ++++++++-- 5 files changed, 221 insertions(+), 164 deletions(-) create mode 100644 Pantograph/Frontend/MetaTranslate.lean diff --git a/Pantograph/Expr.lean b/Pantograph/Expr.lean index f989575..ad064a7 100644 --- a/Pantograph/Expr.lean +++ b/Pantograph/Expr.lean @@ -60,53 +60,54 @@ partial def instantiateDelayedMVars (eOrig: Expr) : MetaM Expr := do -- nested mvars. mvarId.setKind .syntheticOpaque - let lctx ← MonadLCtx.getLCtx - if mvarDecl.lctx.any (λ decl => !lctx.contains decl.fvarId) then - let violations := mvarDecl.lctx.decls.foldl (λ acc decl? => match decl? with - | .some decl => if lctx.contains decl.fvarId then acc else acc ++ [decl.fvarId.name] - | .none => acc) [] - panic! s!"Local context variable violation: {violations}" + mvarId.withContext do + let lctx ← MonadLCtx.getLCtx + if mvarDecl.lctx.any (λ decl => !lctx.contains decl.fvarId) then + let violations := mvarDecl.lctx.decls.foldl (λ acc decl? => match decl? with + | .some decl => if lctx.contains decl.fvarId then acc else acc ++ [decl.fvarId.name] + | .none => acc) [] + panic! s!"In the context of {mvarId.name}, there are local context variable violations: {violations}" - if let .some assign ← getExprMVarAssignment? mvarId then - --IO.println s!"{padding}├A ?{mvarId.name}" - assert! !(← mvarId.isDelayedAssigned) - return .visit (mkAppN assign args) - else if let some { fvars, mvarIdPending } ← getDelayedMVarAssignment? mvarId then - --let substTableStr := String.intercalate ", " $ Array.zipWith fvars args (λ fvar assign => s!"{fvar.fvarId!.name} := {assign}") |>.toList - --IO.println s!"{padding}├MD ?{mvarId.name} := ?{mvarIdPending.name} [{substTableStr}]" + if let .some assign ← getExprMVarAssignment? mvarId then + --IO.println s!"{padding}├A ?{mvarId.name}" + assert! !(← mvarId.isDelayedAssigned) + return .visit (mkAppN assign args) + else if let some { fvars, mvarIdPending } ← getDelayedMVarAssignment? mvarId then + --let substTableStr := String.intercalate ", " $ Array.zipWith fvars args (λ fvar assign => s!"{fvar.fvarId!.name} := {assign}") |>.toList + --IO.println s!"{padding}├MD ?{mvarId.name} := ?{mvarIdPending.name} [{substTableStr}]" - if args.size < fvars.size then - throwError "Not enough arguments to instantiate a delay assigned mvar. This is due to bad implementations of a tactic: {args.size} < {fvars.size}. Expr: {toString e}; Origin: {toString eOrig}" - --if !args.isEmpty then - --IO.println s!"{padding}├── Arguments Begin" - let args ← args.mapM self - --if !args.isEmpty then - --IO.println s!"{padding}├── Arguments End" - if !(← mvarIdPending.isAssignedOrDelayedAssigned) then - --IO.println s!"{padding}├T1" - let result := mkAppN f args + if args.size < fvars.size then + throwError "Not enough arguments to instantiate a delay assigned mvar. This is due to bad implementations of a tactic: {args.size} < {fvars.size}. Expr: {toString e}; Origin: {toString eOrig}" + --if !args.isEmpty then + --IO.println s!"{padding}├── Arguments Begin" + let args ← args.mapM self + --if !args.isEmpty then + --IO.println s!"{padding}├── Arguments End" + if !(← mvarIdPending.isAssignedOrDelayedAssigned) then + --IO.println s!"{padding}├T1" + let result := mkAppN f args + return .done result + + let pending ← mvarIdPending.withContext do + let inner ← instantiateDelayedMVars (.mvar mvarIdPending) --(level := level + 1) + --IO.println s!"{padding}├Pre: {inner}" + pure <| (← inner.abstractM fvars).instantiateRev args + + -- Tail arguments + let result := mkAppRange pending fvars.size args.size args + --IO.println s!"{padding}├MD {result}" return .done result + else + assert! !(← mvarId.isAssigned) + assert! !(← mvarId.isDelayedAssigned) + --if !args.isEmpty then + -- IO.println s!"{padding}├── Arguments Begin" + let args ← args.mapM self + --if !args.isEmpty then + -- IO.println s!"{padding}├── Arguments End" - let pending ← mvarIdPending.withContext do - let inner ← instantiateDelayedMVars (.mvar mvarIdPending) --(level := level + 1) - --IO.println s!"{padding}├Pre: {inner}" - pure <| (← inner.abstractM fvars).instantiateRev args - - -- Tail arguments - let result := mkAppRange pending fvars.size args.size args - --IO.println s!"{padding}├MD {result}" - return .done result - else - assert! !(← mvarId.isAssigned) - assert! !(← mvarId.isDelayedAssigned) - --if !args.isEmpty then - -- IO.println s!"{padding}├── Arguments Begin" - let args ← args.mapM self - --if !args.isEmpty then - -- IO.println s!"{padding}├── Arguments End" - - --IO.println s!"{padding}├M ?{mvarId.name}" - return .done (mkAppN f args)) + --IO.println s!"{padding}├M ?{mvarId.name}" + return .done (mkAppN f args)) --IO.println s!"{padding}└Result {result}" return result where diff --git a/Pantograph/Frontend.lean b/Pantograph/Frontend.lean index ffeeec5..fd91823 100644 --- a/Pantograph/Frontend.lean +++ b/Pantograph/Frontend.lean @@ -1,4 +1,4 @@ /- Adapted from lean-training-data by semorrison -/ -import Pantograph.Protocol import Pantograph.Frontend.Basic import Pantograph.Frontend.Elab +import Pantograph.Frontend.MetaTranslate diff --git a/Pantograph/Frontend/Elab.lean b/Pantograph/Frontend/Elab.lean index 2e0c14e..2036aea 100644 --- a/Pantograph/Frontend/Elab.lean +++ b/Pantograph/Frontend/Elab.lean @@ -3,9 +3,10 @@ import Lean.Elab.Import import Lean.Elab.Command import Lean.Elab.InfoTree -import Pantograph.Protocol import Pantograph.Frontend.Basic +import Pantograph.Frontend.MetaTranslate import Pantograph.Goal +import Pantograph.Protocol open Lean @@ -179,117 +180,6 @@ def collectSorrys (step: CompilationStep) : List InfoWithContext := step.trees.bind collectSorrysInTree -namespace MetaTranslate - -structure Context where - sourceMCtx : MetavarContext := {} - sourceLCtx : LocalContext := {} - -structure State where - -- Stores mapping from old to new mvar/fvars - mvarMap: HashMap MVarId MVarId := {} - fvarMap: HashMap FVarId FVarId := {} - -/- -Monadic state for translating a frozen meta state. The underlying `MetaM` -operates in the "target" context and state. --/ -abbrev MetaTranslateM := ReaderT Context StateRefT State MetaM - -def getSourceLCtx : MetaTranslateM LocalContext := do pure (← read).sourceLCtx -def getSourceMCtx : MetaTranslateM MetavarContext := do pure (← read).sourceMCtx -def addTranslatedFVar (src dst: FVarId) : MetaTranslateM Unit := do - let state ← get - set { state with fvarMap := state.fvarMap.insert src dst } -def addTranslatedMVar (src dst: MVarId) : MetaTranslateM Unit := do - let state ← get - set { state with mvarMap := state.mvarMap.insert src dst } - -def resetFVarMap : MetaTranslateM Unit := do - let state ← get - set { state with fvarMap := {} } - -private partial def translateExpr (srcExpr: Expr) : MetaTranslateM Expr := do - let (srcExpr, _) := instantiateMVarsCore (mctx := ← getSourceMCtx) srcExpr - --IO.println s!"Transform src: {srcExpr}" - let result ← Core.transform srcExpr λ e => do - let state ← get - match e with - | .fvar fvarId => - let .some fvarId' := state.fvarMap.find? fvarId | panic! s!"FVar id not registered: {fvarId.name}" - return .done $ .fvar fvarId' - | .mvar mvarId => do - match state.mvarMap.find? mvarId with - | .some mvarId' => do - return .done $ .mvar mvarId' - | .none => do - --let t := (← getSourceMCtx).findDecl? mvarId |>.get!.type - --let t' ← translateExpr t - let mvar' ← Meta.mkFreshExprMVar .none - addTranslatedMVar mvarId mvar'.mvarId! - return .done mvar' - | _ => return .continue - try - Meta.check result - catch ex => - panic! s!"Check failed: {← ex.toMessageData.toString}" - return result - -def translateLocalDecl (srcLocalDecl: LocalDecl) : MetaTranslateM LocalDecl := do - let fvarId ← mkFreshFVarId - addTranslatedFVar srcLocalDecl.fvarId fvarId - match srcLocalDecl with - | .cdecl index _ userName type bi kind => do - --IO.println s!"[CD] {userName} {toString type}" - return .cdecl index fvarId userName (← translateExpr type) bi kind - | .ldecl index _ userName type value nonDep kind => do - --IO.println s!"[LD] {toString type} := {toString value}" - return .ldecl index fvarId userName (← translateExpr type) (← translateExpr value) nonDep kind - -def translateLCtx : MetaTranslateM LocalContext := do - resetFVarMap - (← getSourceLCtx).foldlM (λ lctx srcLocalDecl => do - let localDecl ← Meta.withLCtx lctx #[] do translateLocalDecl srcLocalDecl - pure $ lctx.addDecl localDecl - ) (← MonadLCtx.getLCtx) - - -def translateMVarId (srcMVarId: MVarId) : MetaTranslateM MVarId := do - let srcDecl := (← getSourceMCtx).findDecl? srcMVarId |>.get! - let mvar ← withTheReader Context (λ ctx => { ctx with sourceLCtx := srcDecl.lctx }) do - let lctx' ← translateLCtx - Meta.withLCtx lctx' #[] do - let target' ← translateExpr srcDecl.type - Meta.mkFreshExprSyntheticOpaqueMVar target' - addTranslatedMVar srcMVarId mvar.mvarId! - return mvar.mvarId! - -def translateMVarFromTermInfo (termInfo : Elab.TermInfo) (context? : Option Elab.ContextInfo) - : MetaM MVarId := do - let trM : MetaTranslateM MVarId := do - let type := termInfo.expectedType?.get! - let lctx' ← translateLCtx - let mvar ← Meta.withLCtx lctx' #[] do - let type' ← translateExpr type - Meta.mkFreshExprSyntheticOpaqueMVar type' - return mvar.mvarId! - trM.run { - sourceMCtx := context?.map (·.mctx) |>.getD {}, - sourceLCtx := termInfo.lctx } |>.run' {} - - -def translateMVarFromTacticInfoBefore (tacticInfo : Elab.TacticInfo) (_context? : Option Elab.ContextInfo) - : MetaM (List MVarId) := do - let trM : MetaTranslateM (List MVarId) := do - tacticInfo.goalsBefore.mapM translateMVarId - trM.run { - sourceMCtx := tacticInfo.mctxBefore - } |>.run' {} - - -end MetaTranslate - -export MetaTranslate (MetaTranslateM) /-- Since we cannot directly merge `MetavarContext`s, we have to get creative. This @@ -299,7 +189,7 @@ the current `MetavarContext`. @[export pantograph_frontend_sorrys_to_goal_state] def sorrysToGoalState (sorrys : List InfoWithContext) : MetaM GoalState := do assert! !sorrys.isEmpty - let goals ← sorrys.mapM λ i => do + let goalsM := sorrys.mapM λ i => do match i.info with | .ofTermInfo termInfo => do let mvarId ← MetaTranslate.translateMVarFromTermInfo termInfo i.context? @@ -307,7 +197,7 @@ def sorrysToGoalState (sorrys : List InfoWithContext) : MetaM GoalState := do | .ofTacticInfo tacticInfo => do MetaTranslate.translateMVarFromTacticInfoBefore tacticInfo i.context? | _ => panic! "Invalid info" - let goals := goals.bind id + let goals := (← goalsM.run {} |>.run' {}).bind id let root := match goals with | [] => panic! "This function cannot be called on an empty list" | [g] => g diff --git a/Pantograph/Frontend/MetaTranslate.lean b/Pantograph/Frontend/MetaTranslate.lean new file mode 100644 index 0000000..82f8dfc --- /dev/null +++ b/Pantograph/Frontend/MetaTranslate.lean @@ -0,0 +1,129 @@ +import Lean.Meta + +open Lean + +namespace Pantograph.Frontend + +namespace MetaTranslate + +structure Context where + sourceMCtx : MetavarContext := {} + sourceLCtx : LocalContext := {} + +abbrev FVarMap := HashMap FVarId FVarId + +structure State where + -- Stores mapping from old to new mvar/fvars + mvarMap: HashMap MVarId MVarId := {} + fvarMap: HashMap FVarId FVarId := {} + +/- +Monadic state for translating a frozen meta state. The underlying `MetaM` +operates in the "target" context and state. +-/ +abbrev MetaTranslateM := ReaderT Context StateRefT State MetaM + +def getSourceLCtx : MetaTranslateM LocalContext := do pure (← read).sourceLCtx +def getSourceMCtx : MetaTranslateM MetavarContext := do pure (← read).sourceMCtx +def addTranslatedFVar (src dst: FVarId) : MetaTranslateM Unit := do + modifyGet λ state => ((), { state with fvarMap := state.fvarMap.insert src dst }) +def addTranslatedMVar (src dst: MVarId) : MetaTranslateM Unit := do + modifyGet λ state => ((), { state with mvarMap := state.mvarMap.insert src dst }) + +def saveFVarMap : MetaTranslateM FVarMap := do + return (← get).fvarMap +def restoreFVarMap (map: FVarMap) : MetaTranslateM Unit := do + modifyGet λ state => ((), { state with fvarMap := map }) +def resetFVarMap : MetaTranslateM Unit := do + modifyGet λ state => ((), { state with fvarMap := {} }) + +mutual +private partial def translateExpr (srcExpr: Expr) : MetaTranslateM Expr := do + let sourceMCtx ← getSourceMCtx + let (srcExpr, _) := instantiateMVarsCore (mctx := sourceMCtx) srcExpr + --IO.println s!"Transform src: {srcExpr}" + let result ← Core.transform srcExpr λ e => do + let state ← get + match e with + | .fvar fvarId => + let .some fvarId' := state.fvarMap.find? fvarId | panic! s!"FVar id not registered: {fvarId.name}" + assert! (← getLCtx).contains fvarId' + return .done $ .fvar fvarId' + | .mvar mvarId => do + assert! !(sourceMCtx.dAssignment.contains mvarId) + assert! !(sourceMCtx.eAssignment.contains mvarId) + match state.mvarMap.find? mvarId with + | .some mvarId' => do + return .done $ .mvar mvarId' + | .none => do + -- Entering another LCtx, must save the current one + let fvarMap ← saveFVarMap + let mvarId' ← translateMVarId mvarId + restoreFVarMap fvarMap + return .done $ .mvar mvarId' + | _ => return .continue + Meta.check result + return result + +partial def translateLocalInstance (srcInstance: LocalInstance) : MetaTranslateM LocalInstance := do + return { + className := srcInstance.className, + fvar := ← translateExpr srcInstance.fvar + } +partial def translateLocalDecl (srcLocalDecl: LocalDecl) : MetaTranslateM LocalDecl := do + let fvarId ← mkFreshFVarId + addTranslatedFVar srcLocalDecl.fvarId fvarId + match srcLocalDecl with + | .cdecl index _ userName type bi kind => do + --IO.println s!"[CD] {userName} {toString type}" + return .cdecl index fvarId userName (← translateExpr type) bi kind + | .ldecl index _ userName type value nonDep kind => do + --IO.println s!"[LD] {toString type} := {toString value}" + return .ldecl index fvarId userName (← translateExpr type) (← translateExpr value) nonDep kind + +partial def translateLCtx : MetaTranslateM LocalContext := do + resetFVarMap + (← getSourceLCtx).foldlM (λ lctx srcLocalDecl => do + let localDecl ← Meta.withLCtx lctx #[] do translateLocalDecl srcLocalDecl + pure $ lctx.addDecl localDecl + ) (← MonadLCtx.getLCtx) + +partial def translateMVarId (srcMVarId: MVarId) : MetaTranslateM MVarId := do + if let .some mvarId' := (← get).mvarMap.find? srcMVarId then + return mvarId' + let srcDecl := (← getSourceMCtx).findDecl? srcMVarId |>.get! + let mvar ← withTheReader Context (λ ctx => { ctx with sourceLCtx := srcDecl.lctx }) do + let lctx' ← translateLCtx + let localInstances' ← srcDecl.localInstances.mapM translateLocalInstance + Meta.withLCtx lctx' localInstances' do + let target' ← translateExpr srcDecl.type + Meta.mkFreshExprMVar target' srcDecl.kind srcDecl.userName + addTranslatedMVar srcMVarId mvar.mvarId! + return mvar.mvarId! +end + +def translateMVarFromTermInfo (termInfo : Elab.TermInfo) (context? : Option Elab.ContextInfo) + : MetaTranslateM MVarId := do + withTheReader Context (λ ctx => { ctx with + sourceMCtx := context?.map (·.mctx) |>.getD {}, + sourceLCtx := termInfo.lctx, + }) do + let type := termInfo.expectedType?.get! + let lctx' ← translateLCtx + let mvar ← Meta.withLCtx lctx' #[] do + let type' ← translateExpr type + Meta.mkFreshExprSyntheticOpaqueMVar type' + return mvar.mvarId! + + +def translateMVarFromTacticInfoBefore (tacticInfo : Elab.TacticInfo) (_context? : Option Elab.ContextInfo) + : MetaTranslateM (List MVarId) := do + withTheReader Context (λ ctx => { ctx with sourceMCtx := tacticInfo.mctxBefore }) do + tacticInfo.goalsBefore.mapM translateMVarId + + +end MetaTranslate + +export MetaTranslate (MetaTranslateM) + +end Pantograph.Frontend diff --git a/Test/Frontend.lean b/Test/Frontend.lean index c186503..b09ef81 100644 --- a/Test/Frontend.lean +++ b/Test/Frontend.lean @@ -26,8 +26,8 @@ theorem plus_n_Sm_proved_formal_sketch : ∀ n m : Nat, n + (m + 1) = (n + m) + sorry " let goalStates ← (collectSorrysFromSource sketch).run' {} - let [goalState] := goalStates | panic! "Illegal number of states" - addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ + let [goalState] := goalStates | panic! "Incorrect number of states" + addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ { target := { pp? := "∀ (n m : Nat), n = m" }, vars := #[ @@ -49,8 +49,8 @@ example : ∀ (n m: Nat), n + m = m + n := by sorry " let goalStates ← (collectSorrysFromSource sketch).run' {} - let [goalState] := goalStates | panic! s!"Illegal number of states: {goalStates.length}" - addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ + let [goalState] := goalStates | panic! s!"Incorrect number of states: {goalStates.length}" + addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ { target := { pp? := "n + m = m + n" }, vars := #[{ @@ -77,8 +77,8 @@ example : ∀ (n m: Nat), n + m = m + n := by sorry " let goalStates ← (collectSorrysFromSource sketch).run' {} - let [goalState] := goalStates | panic! s!"Illegal number of states: {goalStates.length}" - addTest $ LSpec.check "plus_n_Sm" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ + let [goalState] := goalStates | panic! s!"Incorrect number of states: {goalStates.length}" + addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ { target := { pp? := "0 + m = m" }, vars := #[{ @@ -87,6 +87,7 @@ example : ∀ (n m: Nat), n + m = m + n := by }] }, { + userName? := .some "zero", target := { pp? := "0 + m = m + 0" }, vars := #[{ userName := "m", @@ -110,6 +111,7 @@ example : ∀ (n m: Nat), n + m = m + n := by }] }, { + userName? := .some "succ", target := { pp? := "n + 1 + m = m + (n + 1)" }, vars := #[{ userName := "m", @@ -127,12 +129,47 @@ example : ∀ (n m: Nat), n + m = m + n := by } ]) +def test_sorry_in_coupled: TestT MetaM Unit := do + let sketch := " +example : ∀ (y: Nat), ∃ (x: Nat), y + 1 = x := by + intro y + apply Exists.intro + case h => sorry + case w => sorry + " + let goalStates ← (collectSorrysFromSource sketch).run' {} + let [goalState] := goalStates | panic! s!"Incorrect number of states: {goalStates.length}" + addTest $ LSpec.check "goals" ((← goalState.serializeGoals (options := {})).map (·.devolatilize) = #[ + { + target := { pp? := "y + 1 = ?w" }, + vars := #[{ + userName := "y", + type? := .some { pp? := "Nat" }, + } + ], + }, + { + userName? := .some "w", + target := { pp? := "Nat" }, + vars := #[{ + userName := "y✝", + isInaccessible := true, + type? := .some { pp? := "Nat" }, + }, { + userName := "y", + type? := .some { pp? := "Nat" }, + } + ], + } + ]) + def suite (env : Environment): List (String × IO LSpec.TestSeq) := let tests := [ ("multiple_sorrys_in_proof", test_multiple_sorrys_in_proof), ("sorry_in_middle", test_sorry_in_middle), ("sorry_in_induction", test_sorry_in_induction), + ("sorry_in_coupled", test_sorry_in_coupled), ] tests.map (fun (name, test) => (name, runMetaMSeq env $ runTest test))