refactor: Into CompilationUnit objects

This commit is contained in:
Leni Aniva 2024-12-11 16:25:45 -08:00
parent d3c14f321f
commit 47b2fbe38d
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
4 changed files with 88 additions and 53 deletions

View File

@ -1,21 +0,0 @@
from dataclasses import dataclass
@dataclass(frozen=True)
class TacticInvocation:
"""
One tactic invocation with the before/after goals extracted from Lean source
code.
"""
before: str
after: str
tactic: str
used_constants: list[str]
@staticmethod
def parse(payload: dict):
return TacticInvocation(
before=payload["goalBefore"],
after=payload["goalAfter"],
tactic=payload["tactic"],
used_constants=payload.get('usedConstants', []),
)

69
pantograph/data.py Normal file
View File

@ -0,0 +1,69 @@
from typing import Optional, Tuple
from dataclasses import dataclass, field
from pantograph.expr import GoalState
@dataclass(frozen=True)
class TacticInvocation:
"""
One tactic invocation with the before/after goals extracted from Lean source
code.
"""
before: str
after: str
tactic: str
used_constants: list[str]
@staticmethod
def parse(payload: dict):
return TacticInvocation(
before=payload["goalBefore"],
after=payload["goalAfter"],
tactic=payload["tactic"],
used_constants=payload.get('usedConstants', []),
)
@dataclass(frozen=True)
class CompilationUnit:
i_begin: int
i_end: int
messages: list[str] = field(default_factory=list)
invocations: Optional[list[TacticInvocation]] = None
# If `goal_state` is none, maybe error has occurred. See `messages`
goal_state: Optional[GoalState] = None
goal_src_boundaries: Optional[list[Tuple[int, int]]] = None
new_constants: Optional[list[str]] = None
@staticmethod
def parse(payload: dict, goal_state_sentinel=None):
i_begin = payload["boundary"][0]
i_end = payload["boundary"][1]
messages = payload["messages"]
if (invocation_payload := payload.get("invocations")) is not None:
invocations = [
TacticInvocation.parse(i) for i in invocation_payload
]
else:
invocations = None
if (state_id := payload.get("goalStateId")) is not None:
goal_state = GoalState.parse_inner(int(state_id), payload["goals"], goal_state_sentinel)
goal_src_boundaries = payload["goalSrcBoundaries"]
else:
goal_state = None
goal_src_boundaries = None
new_constants = payload.get("newConstants")
return CompilationUnit(
i_begin,
i_end,
messages,
invocations,
goal_state,
goal_src_boundaries,
new_constants
)

View File

@ -98,6 +98,7 @@ class GoalState:
return GoalState(state_id, goals, _sentinel) return GoalState(state_id, goals, _sentinel)
@staticmethod @staticmethod
def parse(payload: dict, _sentinel: list[int]): def parse(payload: dict, _sentinel: list[int]):
assert _sentinel is not None
return GoalState.parse_inner(payload["nextStateId"], payload["goals"], _sentinel) return GoalState.parse_inner(payload["nextStateId"], payload["goals"], _sentinel)
def __str__(self): def __str__(self):

View File

@ -17,7 +17,7 @@ from pantograph.expr import (
TacticCalc, TacticCalc,
TacticExpr, TacticExpr,
) )
from pantograph.compiler import TacticInvocation from pantograph.data import CompilationUnit
def _get_proc_cwd(): def _get_proc_cwd():
return Path(__file__).parent return Path(__file__).parent
@ -173,7 +173,11 @@ class Server:
if "error" in result: if "error" in result:
print(f"Cannot start goal: {expr}") print(f"Cannot start goal: {expr}")
raise ServerError(result["desc"]) raise ServerError(result["desc"])
return GoalState(state_id=result["stateId"], goals=[Goal.sentence(expr)], _sentinel=self.to_remove_goal_states) return GoalState(
state_id=result["stateId"],
goals=[Goal.sentence(expr)],
_sentinel=self.to_remove_goal_states,
)
def goal_tactic(self, state: GoalState, goal_id: int, tactic: Tactic) -> GoalState: def goal_tactic(self, state: GoalState, goal_id: int, tactic: Tactic) -> GoalState:
""" """
@ -231,7 +235,7 @@ class Server:
raise ServerError(result["parseError"]) raise ServerError(result["parseError"])
return GoalState.parse(result, self.to_remove_goal_states) return GoalState.parse(result, self.to_remove_goal_states)
def tactic_invocations(self, file_name: Union[str, Path]) -> tuple[list[str], list[TacticInvocation]]: def tactic_invocations(self, file_name: Union[str, Path]) -> list[CompilationUnit]:
""" """
Collect tactic invocation points in file, and return them. Collect tactic invocation points in file, and return them.
""" """
@ -244,26 +248,16 @@ class Server:
if "error" in result: if "error" in result:
raise ServerError(result["desc"]) raise ServerError(result["desc"])
with open(file_name, 'rb') as f: units = [CompilationUnit.parse(payload) for payload in result['units']]
content = f.read() return units
units = [
content[unit["boundary"][0]:unit["boundary"][1]].decode('utf-8')
for unit in result['units']
]
invocations = [
invocation
for unit in result['units']
for invocation in [TacticInvocation.parse(i) for i in unit['invocations']]
]
return units, invocations
def load_sorry(self, command: str) -> list[GoalState | list[str]]: def load_sorry(self, content: str) -> list[CompilationUnit]:
""" """
Executes the compiler on a Lean file. For each compilation unit, either Executes the compiler on a Lean file. For each compilation unit, either
return the gathered `sorry` s, or a list of messages indicating error. return the gathered `sorry` s, or a list of messages indicating error.
""" """
result = self.run('frontend.process', { result = self.run('frontend.process', {
'file': command, 'file': content,
'invocations': False, 'invocations': False,
"sorrys": True, "sorrys": True,
"newConstants": False, "newConstants": False,
@ -271,19 +265,11 @@ class Server:
if "error" in result: if "error" in result:
raise ServerError(result["desc"]) raise ServerError(result["desc"])
def parse_unit(unit: dict): units = [
state_id = unit.get("goalStateId") CompilationUnit.parse(payload, goal_state_sentinel=self.to_remove_goal_states)
if state_id is None: for payload in result['units']
# NOTE: `state_id` maybe 0.
# Maybe error has occurred
return unit["messages"]
state = GoalState.parse_inner(state_id, unit["goals"], self.to_remove_goal_states)
return state
states = [
parse_unit(unit) for unit in result['units']
] ]
return states return units
def get_version(): def get_version():
@ -447,9 +433,9 @@ class TestServer(unittest.TestCase):
def test_load_sorry(self): def test_load_sorry(self):
server = Server() server = Server()
state0, = server.load_sorry("example (p: Prop): p → p := sorry") unit, = server.load_sorry("example (p: Prop): p → p := sorry")
if isinstance(state0, list): self.assertIsNotNone(unit.goal_state, f"{unit.messages}")
print(state0) state0 = unit.goal_state
self.assertEqual(state0.goals, [ self.assertEqual(state0.goals, [
Goal( Goal(
[Variable(name="p", t="Prop")], [Variable(name="p", t="Prop")],