feat: Goal state garbage collection
This commit is contained in:
parent
44f982d689
commit
7c3b64562b
|
@ -60,6 +60,11 @@ class GoalState:
|
||||||
state_id: int
|
state_id: int
|
||||||
goals: list[Goal]
|
goals: list[Goal]
|
||||||
|
|
||||||
|
_sentinel: list[int]
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self._sentinel.append(self.state_id)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_solved(self) -> bool:
|
def is_solved(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -68,10 +73,10 @@ class GoalState:
|
||||||
return not self.goals
|
return not self.goals
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse(payload: dict) -> Self:
|
def parse(payload: dict, _sentinel: list[int]) -> Self:
|
||||||
state_id = payload["nextStateId"]
|
state_id = payload["nextStateId"]
|
||||||
goals = [Goal.parse(g) for g in payload["goals"]]
|
goals = [Goal.parse(g) for g in payload["goals"]]
|
||||||
return GoalState(state_id, goals)
|
return GoalState(state_id, goals, _sentinel)
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TacticHave:
|
class TacticHave:
|
||||||
|
|
|
@ -36,6 +36,9 @@ class Server:
|
||||||
self.proc = None
|
self.proc = None
|
||||||
self.restart()
|
self.restart()
|
||||||
|
|
||||||
|
# List of goal states that should be garbage collected
|
||||||
|
self.to_remove_goal_states = []
|
||||||
|
|
||||||
def restart(self):
|
def restart(self):
|
||||||
if self.proc is not None:
|
if self.proc is not None:
|
||||||
self.proc.close()
|
self.proc.close()
|
||||||
|
@ -48,14 +51,26 @@ class Server:
|
||||||
self.proc.setecho(False)
|
self.proc.setecho(False)
|
||||||
|
|
||||||
def run(self, cmd, payload):
|
def run(self, cmd, payload):
|
||||||
|
"""
|
||||||
|
Runs a raw JSON command. Preferably use one of the commands below.
|
||||||
|
"""
|
||||||
s = json.dumps(payload)
|
s = json.dumps(payload)
|
||||||
self.proc.sendline(f"{cmd} {s}")
|
self.proc.sendline(f"{cmd} {s}")
|
||||||
try:
|
try:
|
||||||
self.proc.expect("{.*}\r\n", timeout=self.timeout)
|
self.proc.expect("{.*}\r\n", timeout=self.timeout)
|
||||||
output = self.proc.match.group()
|
output = self.proc.match.group()
|
||||||
return json.loads(output)
|
return json.loads(output)
|
||||||
except pexpect.exceptions.TIMEOUT:
|
except pexpect.exceptions.TIMEOUT as exc:
|
||||||
raise pexpect.exceptions.TIMEOUT
|
raise exc
|
||||||
|
|
||||||
|
def gc(self):
|
||||||
|
"""
|
||||||
|
Garbage collect deleted goal states.
|
||||||
|
|
||||||
|
Must be called periodically.
|
||||||
|
"""
|
||||||
|
self.run('goal.delete', {'stateIds': self.to_remove_goal_states})
|
||||||
|
self.to_remove_goal_states.clear()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
return self.run("reset", {})
|
return self.run("reset", {})
|
||||||
|
@ -64,7 +79,7 @@ class Server:
|
||||||
result = self.run('goal.start', {"expr": str(expr)})
|
result = self.run('goal.start', {"expr": str(expr)})
|
||||||
if "error" in result:
|
if "error" in result:
|
||||||
raise ServerError(result["desc"])
|
raise ServerError(result["desc"])
|
||||||
return GoalState(state_id = result["stateId"], goals = [Goal.sentence(expr)])
|
return GoalState(state_id=result["stateId"], goals=[Goal.sentence(expr)], _sentinel=self.to_remove_goal_states)
|
||||||
|
|
||||||
def goal_tactic(self, state: GoalState, goal_id: int, tactic: Tactic) -> GoalState:
|
def goal_tactic(self, state: GoalState, goal_id: int, tactic: Tactic) -> GoalState:
|
||||||
args = {"stateId": state.state_id, "goalId": goal_id}
|
args = {"stateId": state.state_id, "goalId": goal_id}
|
||||||
|
@ -83,7 +98,7 @@ class Server:
|
||||||
raise ServerError(result["tacticErrors"])
|
raise ServerError(result["tacticErrors"])
|
||||||
if "parseError" in result:
|
if "parseError" in result:
|
||||||
raise ServerError(result["parseError"])
|
raise ServerError(result["parseError"])
|
||||||
return GoalState.parse(result)
|
return GoalState.parse(result, self.to_remove_goal_states)
|
||||||
|
|
||||||
def goal_conv_begin(self, state: GoalState, goal_id: int) -> GoalState:
|
def goal_conv_begin(self, state: GoalState, goal_id: int) -> GoalState:
|
||||||
result = self.run('goal.tactic', {"stateId": state.state_id, "goalId": goal_id, "conv": True})
|
result = self.run('goal.tactic', {"stateId": state.state_id, "goalId": goal_id, "conv": True})
|
||||||
|
@ -93,7 +108,7 @@ class Server:
|
||||||
raise ServerError(result["tacticErrors"])
|
raise ServerError(result["tacticErrors"])
|
||||||
if "parseError" in result:
|
if "parseError" in result:
|
||||||
raise ServerError(result["parseError"])
|
raise ServerError(result["parseError"])
|
||||||
return GoalState.parse(result)
|
return GoalState.parse(result, self.to_remove_goal_states)
|
||||||
|
|
||||||
def goal_conv_end(self, state: GoalState) -> GoalState:
|
def goal_conv_end(self, state: GoalState) -> GoalState:
|
||||||
result = self.run('goal.tactic', {"stateId": state.state_id, "goalId": 0, "conv": False})
|
result = self.run('goal.tactic', {"stateId": state.state_id, "goalId": 0, "conv": False})
|
||||||
|
@ -103,7 +118,7 @@ class Server:
|
||||||
raise ServerError(result["tacticErrors"])
|
raise ServerError(result["tacticErrors"])
|
||||||
if "parseError" in result:
|
if "parseError" in result:
|
||||||
raise ServerError(result["parseError"])
|
raise ServerError(result["parseError"])
|
||||||
return GoalState.parse(result)
|
return GoalState.parse(result, self.to_remove_goal_states)
|
||||||
|
|
||||||
|
|
||||||
def get_version():
|
def get_version():
|
||||||
|
@ -122,6 +137,7 @@ class TestServer(unittest.TestCase):
|
||||||
def test_goal_start(self):
|
def test_goal_start(self):
|
||||||
server = Server()
|
server = Server()
|
||||||
state0 = server.goal_start("forall (p q: Prop), Or p q -> Or q p")
|
state0 = server.goal_start("forall (p q: Prop), Or p q -> Or q p")
|
||||||
|
self.assertEqual(len(server.to_remove_goal_states), 0)
|
||||||
self.assertEqual(state0.state_id, 0)
|
self.assertEqual(state0.state_id, 0)
|
||||||
state1 = server.goal_tactic(state0, goal_id=0, tactic="intro a")
|
state1 = server.goal_tactic(state0, goal_id=0, tactic="intro a")
|
||||||
self.assertEqual(state1.state_id, 1)
|
self.assertEqual(state1.state_id, 1)
|
||||||
|
@ -132,6 +148,17 @@ class TestServer(unittest.TestCase):
|
||||||
)])
|
)])
|
||||||
self.assertEqual(str(state1.goals[0]),"a : Prop\n⊢ ∀ (q : Prop), a ∨ q → q ∨ a")
|
self.assertEqual(str(state1.goals[0]),"a : Prop\n⊢ ∀ (q : Prop), a ∨ q → q ∨ a")
|
||||||
|
|
||||||
|
del state0
|
||||||
|
self.assertEqual(len(server.to_remove_goal_states), 1)
|
||||||
|
server.gc()
|
||||||
|
self.assertEqual(len(server.to_remove_goal_states), 0)
|
||||||
|
|
||||||
|
state0b = server.goal_start("forall (p: Prop), p -> p")
|
||||||
|
del state0b
|
||||||
|
self.assertEqual(len(server.to_remove_goal_states), 1)
|
||||||
|
server.gc()
|
||||||
|
self.assertEqual(len(server.to_remove_goal_states), 0)
|
||||||
|
|
||||||
def test_conv_calc(self):
|
def test_conv_calc(self):
|
||||||
server = Server()
|
server = Server()
|
||||||
state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b")
|
state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b")
|
||||||
|
|
Loading…
Reference in New Issue