From 83fcec5d609580a5f46c8dfcbc7c07eda5482c6c Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Mon, 22 Apr 2024 13:11:28 -0700 Subject: [PATCH] feat: Add tactic object --- pantograph/expr.py | 11 ++++++++++- pantograph/server.py | 14 +++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/pantograph/expr.py b/pantograph/expr.py index bf80d8c..300ab82 100644 --- a/pantograph/expr.py +++ b/pantograph/expr.py @@ -2,7 +2,7 @@ Data structuers for expressions and goals """ from dataclasses import dataclass -from typing import Optional, Self +from typing import Optional, Self, Union Expr = str @@ -60,3 +60,12 @@ class GoalState: @property def is_solved(self) -> bool: return not self.goals + +@dataclass(frozen=True) +class TacticNormal: + payload: str +@dataclass(frozen=True) +class TacticHave: + branch: str + +Tactic = Union[TacticNormal, TacticHave] diff --git a/pantograph/server.py b/pantograph/server.py index eb54f6d..85bc436 100644 --- a/pantograph/server.py +++ b/pantograph/server.py @@ -3,7 +3,7 @@ Class which manages a Pantograph instance. All calls to the kernel uses this interface. """ import json, pexpect, pathlib, unittest -from pantograph.expr import Variable, Goal, GoalState +from pantograph.expr import Variable, Goal, GoalState, Tactic, TacticNormal def _get_proc_cwd(): return pathlib.Path(__file__).parent @@ -65,9 +65,13 @@ class Server: raise ServerError(result["desc"]) return GoalState(state_id = result["stateId"], goals = [Goal.sentence(expr)]) - def goal_tactic(self, state: GoalState, goalId: int, tactic: str) -> GoalState: - result = self.run('goal.tactic', { - "stateId": state.state_id, "goalId": goalId, "tactic": tactic}) + def goal_tactic(self, state: GoalState, goalId: int, tactic: Tactic) -> GoalState: + args = { "stateId": state.state_id, "goalId": goalId } + if isinstance(tactic, TacticNormal): + args["tactic"] = tactic.payload + else: + raise Exception(f"Invalid tactic type: {tactic}") + result = self.run('goal.tactic', args) if "error" in result: raise ServerError(result["desc"]) if "tacticErrors" in result: @@ -95,7 +99,7 @@ class TestServer(unittest.TestCase): server = Server() state0 = server.goal_start("forall (p q: Prop), Or p q -> Or q p") self.assertEqual(state0.state_id, 0) - state1 = server.goal_tactic(state0, goalId=0, tactic="intro a") + state1 = server.goal_tactic(state0, goalId=0, tactic=TacticNormal("intro a")) self.assertEqual(state1.state_id, 1) self.assertEqual(state1.goals, [Goal( variables=[Variable(name="a", t="Prop")],