diff --git a/pantograph/expr.py b/pantograph/expr.py index 4ae0367..dee9711 100644 --- a/pantograph/expr.py +++ b/pantograph/expr.py @@ -60,6 +60,11 @@ class GoalState: state_id: int goals: list[Goal] + _sentinel: list[int] + + def __del__(self): + self._sentinel.append(self.state_id) + @property def is_solved(self) -> bool: """ @@ -68,10 +73,10 @@ class GoalState: return not self.goals @staticmethod - def parse(payload: dict) -> Self: + def parse(payload: dict, _sentinel: list[int]) -> Self: state_id = payload["nextStateId"] goals = [Goal.parse(g) for g in payload["goals"]] - return GoalState(state_id, goals) + return GoalState(state_id, goals, _sentinel) @dataclass(frozen=True) class TacticHave: diff --git a/pantograph/server.py b/pantograph/server.py index 52a7661..19f087f 100644 --- a/pantograph/server.py +++ b/pantograph/server.py @@ -36,6 +36,9 @@ class Server: self.proc = None self.restart() + # List of goal states that should be garbage collected + self.to_remove_goal_states = [] + def restart(self): if self.proc is not None: self.proc.close() @@ -48,14 +51,26 @@ class Server: self.proc.setecho(False) def run(self, cmd, payload): + """ + Runs a raw JSON command. Preferably use one of the commands below. + """ s = json.dumps(payload) self.proc.sendline(f"{cmd} {s}") try: self.proc.expect("{.*}\r\n", timeout=self.timeout) output = self.proc.match.group() return json.loads(output) - except pexpect.exceptions.TIMEOUT: - raise pexpect.exceptions.TIMEOUT + except pexpect.exceptions.TIMEOUT as exc: + raise exc + + def gc(self): + """ + Garbage collect deleted goal states. + + Must be called periodically. + """ + self.run('goal.delete', {'stateIds': self.to_remove_goal_states}) + self.to_remove_goal_states.clear() def reset(self): return self.run("reset", {}) @@ -64,7 +79,7 @@ class Server: result = self.run('goal.start', {"expr": str(expr)}) if "error" in result: raise ServerError(result["desc"]) - return GoalState(state_id = result["stateId"], goals = [Goal.sentence(expr)]) + return GoalState(state_id=result["stateId"], goals=[Goal.sentence(expr)], _sentinel=self.to_remove_goal_states) def goal_tactic(self, state: GoalState, goal_id: int, tactic: Tactic) -> GoalState: args = {"stateId": state.state_id, "goalId": goal_id} @@ -83,7 +98,7 @@ class Server: raise ServerError(result["tacticErrors"]) if "parseError" in result: raise ServerError(result["parseError"]) - return GoalState.parse(result) + return GoalState.parse(result, self.to_remove_goal_states) def goal_conv_begin(self, state: GoalState, goal_id: int) -> GoalState: result = self.run('goal.tactic', {"stateId": state.state_id, "goalId": goal_id, "conv": True}) @@ -93,7 +108,7 @@ class Server: raise ServerError(result["tacticErrors"]) if "parseError" in result: raise ServerError(result["parseError"]) - return GoalState.parse(result) + return GoalState.parse(result, self.to_remove_goal_states) def goal_conv_end(self, state: GoalState) -> GoalState: result = self.run('goal.tactic', {"stateId": state.state_id, "goalId": 0, "conv": False}) @@ -103,7 +118,7 @@ class Server: raise ServerError(result["tacticErrors"]) if "parseError" in result: raise ServerError(result["parseError"]) - return GoalState.parse(result) + return GoalState.parse(result, self.to_remove_goal_states) def get_version(): @@ -122,6 +137,7 @@ class TestServer(unittest.TestCase): def test_goal_start(self): server = Server() state0 = server.goal_start("forall (p q: Prop), Or p q -> Or q p") + self.assertEqual(len(server.to_remove_goal_states), 0) self.assertEqual(state0.state_id, 0) state1 = server.goal_tactic(state0, goal_id=0, tactic="intro a") self.assertEqual(state1.state_id, 1) @@ -132,6 +148,17 @@ class TestServer(unittest.TestCase): )]) self.assertEqual(str(state1.goals[0]),"a : Prop\n⊢ ∀ (q : Prop), a ∨ q → q ∨ a") + del state0 + self.assertEqual(len(server.to_remove_goal_states), 1) + server.gc() + self.assertEqual(len(server.to_remove_goal_states), 0) + + state0b = server.goal_start("forall (p: Prop), p -> p") + del state0b + self.assertEqual(len(server.to_remove_goal_states), 1) + server.gc() + self.assertEqual(len(server.to_remove_goal_states), 0) + def test_conv_calc(self): server = Server() state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b")