From b381d89ff93f5fa94ccc161ce70f5f22e3ac2dcb Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Fri, 27 Oct 2023 15:15:22 -0700 Subject: [PATCH] feat: Assigning a goal with an expression --- Pantograph/Goal.lean | 67 +++++++++++++++++++++++++++++++++++----- Pantograph/Protocol.lean | 3 +- Pantograph/Serial.lean | 5 +-- Test/Proofs.lean | 61 +++++++++++++++++++++++++++++------- 4 files changed, 112 insertions(+), 24 deletions(-) diff --git a/Pantograph/Goal.lean b/Pantograph/Goal.lean index 1f3f71a..3be34ad 100644 --- a/Pantograph/Goal.lean +++ b/Pantograph/Goal.lean @@ -47,13 +47,15 @@ protected def GoalState.runM {α: Type} (state: GoalState) (m: Elab.TermElabM α protected def GoalState.mctx (state: GoalState): MetavarContext := state.savedState.term.meta.meta.mctx +protected def GoalState.env (state: GoalState): Environment := + state.savedState.term.meta.core.env private def GoalState.mvars (state: GoalState): SSet MVarId := state.mctx.decls.foldl (init := .empty) fun acc k _ => acc.insert k /-- Inner function for executing tactic on goal state -/ def executeTactic (state: Elab.Tactic.SavedState) (goal: MVarId) (tactic: Syntax) : - M (Except (Array String) (Elab.Tactic.SavedState × List MVarId)):= do - let tacticM (stx: Syntax): Elab.Tactic.TacticM (Except (Array String) (Elab.Tactic.SavedState × List MVarId)) := do + M (Except (Array String) Elab.Tactic.SavedState):= do + let tacticM (stx: Syntax): Elab.Tactic.TacticM (Except (Array String) Elab.Tactic.SavedState) := do state.restore Elab.Tactic.setGoals [goal] try @@ -63,9 +65,7 @@ def executeTactic (state: Elab.Tactic.SavedState) (goal: MVarId) (tactic: Syntax let errors ← (messages.map Message.data).mapM fun md => md.toString return .error errors else - let unsolved ← Elab.Tactic.getUnsolvedGoals - -- The order of evaluation is important here, since `getUnsolvedGoals` prunes the goals set - return .ok (← MonadBacktrack.saveState, unsolved) + return .ok (← MonadBacktrack.saveState) catch exception => return .error #[← exception.toMessageData.toString] tacticM tactic { elaborator := .anonymous } |>.run' state.tactic @@ -97,8 +97,7 @@ protected def GoalState.execute (state: GoalState) (goalId: Nat) (tactic: String match (← executeTactic (state := state.savedState) (goal := goal) (tactic := tactic)) with | .error errors => return .failure errors - | .ok (nextSavedState, nextGoals) => - assert! nextSavedState.tactic.goals.length == nextGoals.length + | .ok nextSavedState => -- Assert that the definition of metavariables are the same let nextMCtx := nextSavedState.term.meta.meta.mctx let prevMCtx := state.savedState.term.meta.meta.mctx @@ -112,12 +111,64 @@ protected def GoalState.execute (state: GoalState) (goalId: Nat) (tactic: String return acc.insert mvarId ) SSet.empty return .success { + state with savedState := nextSavedState - root := state.root, newMVars, parentGoalId := goalId, } +protected def GoalState.tryAssign (state: GoalState) (goalId: Nat) (expr: String): M TacticResult := do + let goal ← match state.savedState.tactic.goals.get? goalId with + | .some goal => pure goal + | .none => return .indexError goalId + let expr ← match Parser.runParserCategory + (env := state.env) + (catName := `term) + (input := expr) + (fileName := "") with + | .ok syn => pure syn + | .error error => return .parseError error + let tacticM: Elab.Tactic.TacticM TacticResult := do + state.savedState.restore + Elab.Tactic.setGoals [goal] + try + let expr ← Elab.Term.elabTerm (stx := expr) (expectedType? := .none) + -- Attempt to unify the expression + let goalType ← goal.getType + let exprType ← Meta.inferType expr + if !(← Meta.isDefEq goalType exprType) then + return .failure #["Type unification failed", toString (← Meta.ppExpr goalType), toString (← Meta.ppExpr exprType)] + goal.checkNotAssigned `GoalState.tryAssign + goal.assign expr + if (← getThe Core.State).messages.hasErrors then + let messages := (← getThe Core.State).messages.getErrorMessages |>.toList.toArray + let errors ← (messages.map Message.data).mapM fun md => md.toString + return .failure errors + else + let prevMCtx := state.savedState.term.meta.meta.mctx + let nextMCtx ← getMCtx + -- Generate a list of mvarIds that exist in the parent state; Also test the + -- assertion that the types have not changed on any mvars. + let newMVars ← nextMCtx.decls.foldlM (fun acc mvarId mvarDecl => do + if let .some prevMVarDecl := prevMCtx.decls.find? mvarId then + assert! prevMVarDecl.type == mvarDecl.type + return acc + else + return mvarId :: acc + ) [] + -- The new goals are the newMVars that lack an assignment + Elab.Tactic.setGoals (← newMVars.filterM (λ mvar => do pure !(← mvar.isAssigned))) + let nextSavedState ← MonadBacktrack.saveState + return .success { + state with + savedState := nextSavedState, + newMVars := newMVars.toSSet, + parentGoalId := goalId, + } + catch exception => + return .failure #[← exception.toMessageData.toString] + tacticM { elaborator := .anonymous } |>.run' state.savedState.tactic + /-- After finishing one branch of a proof (`graftee`), pick up from the point where the proof was left off (`target`) -/ protected def GoalState.continue (target: GoalState) (graftee: GoalState): Except String GoalState := if target.root != graftee.root then diff --git a/Pantograph/Protocol.lean b/Pantograph/Protocol.lean index 1c05227..b0e7744 100644 --- a/Pantograph/Protocol.lean +++ b/Pantograph/Protocol.lean @@ -172,7 +172,8 @@ structure GoalPrint where printContext: Bool := true printValue: Bool := true printNewMVars: Bool := false - printNonVisible: Bool := false + -- Print all mvars + printAll: Bool := false end Pantograph.Protocol diff --git a/Pantograph/Serial.lean b/Pantograph/Serial.lean index 87552eb..1a07444 100644 --- a/Pantograph/Serial.lean +++ b/Pantograph/Serial.lean @@ -262,9 +262,6 @@ protected def GoalState.serializeGoals (state: GoalState) (parent: Option GoalSt let parentGoal := parentState.goals.get! state.parentGoalId parentState.mctx.findDecl? parentGoal) goals.mapM fun goal => do - if options.noRepeat then - let key := if parentDecl?.isSome then "is some" else "is none" - IO.println s!"goal: {goal.name}, {key}" match state.mctx.findDecl? goal with | .some mvarDecl => let serializedGoal ← serialize_goal options mvarDecl (parentDecl? := parentDecl?) @@ -296,7 +293,7 @@ protected def GoalState.print (goalState: GoalState) (options: Protocol.GoalPrin else if mvarId == goalState.root then printMVar (pref := ">") mvarId decl -- Print the remainig ones that users don't see in Lean - else if options.printNonVisible then + else if options.printAll then let pref := if goalState.newMVars.contains mvarId then "~" else " " printMVar pref mvarId decl else diff --git a/Test/Proofs.lean b/Test/Proofs.lean index 79f0f38..809cf50 100644 --- a/Test/Proofs.lean +++ b/Test/Proofs.lean @@ -66,8 +66,9 @@ def startProof (start: Start): TestM (Option GoalState) := do def assertUnreachable (message: String): LSpec.TestSeq := LSpec.check message false -def buildGoal (nameType: List (String × String)) (target: String): Protocol.Goal := +def buildGoal (nameType: List (String × String)) (target: String) (caseName?: Option String := .none): Protocol.Goal := { + caseName?, target := { pp? := .some target}, vars := (nameType.map fun x => ({ userName := x.fst, @@ -187,21 +188,21 @@ def proof_arith: TestM Unit := do addTest $ assertUnreachable $ other.toString return () addTest $ LSpec.check "intros" (state1.goals.length = 1) - addTest $ LSpec.test "1 root" state1.rootExpr.isNone + addTest $ LSpec.test "(1 root)" state1.rootExpr.isNone let state2 ← match ← state1.execute (goalId := 0) (tactic := "simp [Nat.add_assoc, Nat.add_comm, Nat.add_left_comm, Nat.mul_comm, Nat.mul_assoc, Nat.mul_left_comm] at *") with | .success state => pure state | other => do addTest $ assertUnreachable $ other.toString return () addTest $ LSpec.check "simp ..." (state2.goals.length = 1) - addTest $ LSpec.check "2 root" state2.rootExpr.isNone + addTest $ LSpec.check "(2 root)" state2.rootExpr.isNone let state3 ← match ← state2.execute (goalId := 0) (tactic := "assumption") with | .success state => pure state | other => do addTest $ assertUnreachable $ other.toString return () addTest $ LSpec.test "assumption" state3.goals.isEmpty - addTest $ LSpec.check "3 root" state3.rootExpr.isSome + addTest $ LSpec.check "(3 root)" state3.rootExpr.isSome return () -- Two ways to write the same theorem @@ -253,7 +254,7 @@ def proof_or_comm: TestM Unit := do | other => do addTest $ assertUnreachable $ other.toString return () - addTest $ LSpec.check "· assumption" state4_1.goals.isEmpty + addTest $ LSpec.check " assumption" state4_1.goals.isEmpty addTest $ LSpec.check "(4_1 root)" state4_1.rootExpr.isNone let state3_2 ← match ← state2.execute (goalId := 1) (tactic := "apply Or.inl") with | .success state => pure state @@ -266,7 +267,7 @@ def proof_or_comm: TestM Unit := do | other => do addTest $ assertUnreachable $ other.toString return () - addTest $ LSpec.check "· assumption" state4_2.goals.isEmpty + addTest $ LSpec.check " assumption" state4_2.goals.isEmpty addTest $ LSpec.check "(4_2 root)" state4_2.rootExpr.isNone -- Ensure the proof can continue from `state4_2`. let state2b ← match state2.continue state4_2 with @@ -286,8 +287,8 @@ def proof_or_comm: TestM Unit := do | other => do addTest $ assertUnreachable $ other.toString return () - addTest $ LSpec.check "· assumption" state4_1.goals.isEmpty - addTest $ LSpec.check "4_1 root" state4_1.rootExpr.isSome + addTest $ LSpec.check " assumption" state4_1.goals.isEmpty + addTest $ LSpec.check "(4_1 root)" state4_1.rootExpr.isSome return () where @@ -336,7 +337,45 @@ def proof_m_couple: TestM Unit := do addTest $ LSpec.test "(2 root)" state1b.rootExpr.isNone return () -/-- Tests the most basic form of proofs whose goals do not relate to each other -/ +def proof_proposition_generation: TestM Unit := do + let state? ← startProof (.expr "Σ' p:Prop, p") + let state0 ← match state? with + | .some state => pure state + | .none => do + addTest $ assertUnreachable "Goal could not parse" + return () + + let state1 ← match ← state0.execute (goalId := 0) (tactic := "apply PSigma.mk") with + | .success state => pure state + | other => do + addTest $ assertUnreachable $ other.toString + return () + addTest $ LSpec.check "apply PSigma.mk" ((← state1.serializeGoals (options := ← read)).map (·.devolatilize) = + #[ + buildGoal [] "?fst" (caseName? := .some "snd"), + buildGoal [] "Prop" (caseName? := .some "fst") + ]) + addTest $ LSpec.test "(1 root)" state1.rootExpr.isNone + + let state2 ← match ← state1.tryAssign (goalId := 0) (expr := "λ (x: Nat) => _") with + | .success state => pure state + | other => do + addTest $ assertUnreachable $ other.toString + return () + addTest $ LSpec.check ":= λ (x: Nat), _" ((← state2.serializeGoals (options := ← read)).map (·.target.pp?) = + #[.some "Nat → Prop", .some "∀ (x : Nat), ?m.29 x"]) + addTest $ LSpec.test "(2 root)" state2.rootExpr.isNone + + let state3 ← match ← state2.tryAssign (goalId := 1) (expr := "fun x => Eq.refl x") with + | .success state => pure state + | other => do + addTest $ assertUnreachable $ other.toString + return () + addTest $ LSpec.check ":= Eq.refl" ((← state3.serializeGoals (options := ← read)).map (·.target.pp?) = + #[]) + addTest $ LSpec.test "(3 root)" state3.rootExpr.isSome + return () + def suite: IO LSpec.TestSeq := do let env: Lean.Environment ← Lean.importModules (imports := #[{ module := Name.append .anonymous "Init", runtimeOnly := false}]) @@ -348,8 +387,8 @@ def suite: IO LSpec.TestSeq := do ("Nat.add_comm delta", proof_delta_variable), ("arithmetic", proof_arith), ("Or.comm", proof_or_comm), - ("2 < 5", proof_m_couple) - --("delta variable", proof_delta_variable) + ("2 < 5", proof_m_couple), + ("Proposition Generation", proof_proposition_generation) ] let tests ← tests.foldlM (fun acc tests => do let (name, tests) := tests