refactor: Into CompilationUnit objects

This commit is contained in:
Leni Aniva 2024-12-11 16:25:45 -08:00
parent d3c14f321f
commit 47b2fbe38d
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
4 changed files with 88 additions and 53 deletions

View File

@ -1,21 +0,0 @@
from dataclasses import dataclass
@dataclass(frozen=True)
class TacticInvocation:
"""
One tactic invocation with the before/after goals extracted from Lean source
code.
"""
before: str
after: str
tactic: str
used_constants: list[str]
@staticmethod
def parse(payload: dict):
return TacticInvocation(
before=payload["goalBefore"],
after=payload["goalAfter"],
tactic=payload["tactic"],
used_constants=payload.get('usedConstants', []),
)

69
pantograph/data.py Normal file
View File

@ -0,0 +1,69 @@
from typing import Optional, Tuple
from dataclasses import dataclass, field
from pantograph.expr import GoalState
@dataclass(frozen=True)
class TacticInvocation:
"""
One tactic invocation with the before/after goals extracted from Lean source
code.
"""
before: str
after: str
tactic: str
used_constants: list[str]
@staticmethod
def parse(payload: dict):
return TacticInvocation(
before=payload["goalBefore"],
after=payload["goalAfter"],
tactic=payload["tactic"],
used_constants=payload.get('usedConstants', []),
)
@dataclass(frozen=True)
class CompilationUnit:
i_begin: int
i_end: int
messages: list[str] = field(default_factory=list)
invocations: Optional[list[TacticInvocation]] = None
# If `goal_state` is none, maybe error has occurred. See `messages`
goal_state: Optional[GoalState] = None
goal_src_boundaries: Optional[list[Tuple[int, int]]] = None
new_constants: Optional[list[str]] = None
@staticmethod
def parse(payload: dict, goal_state_sentinel=None):
i_begin = payload["boundary"][0]
i_end = payload["boundary"][1]
messages = payload["messages"]
if (invocation_payload := payload.get("invocations")) is not None:
invocations = [
TacticInvocation.parse(i) for i in invocation_payload
]
else:
invocations = None
if (state_id := payload.get("goalStateId")) is not None:
goal_state = GoalState.parse_inner(int(state_id), payload["goals"], goal_state_sentinel)
goal_src_boundaries = payload["goalSrcBoundaries"]
else:
goal_state = None
goal_src_boundaries = None
new_constants = payload.get("newConstants")
return CompilationUnit(
i_begin,
i_end,
messages,
invocations,
goal_state,
goal_src_boundaries,
new_constants
)

View File

@ -98,6 +98,7 @@ class GoalState:
return GoalState(state_id, goals, _sentinel)
@staticmethod
def parse(payload: dict, _sentinel: list[int]):
assert _sentinel is not None
return GoalState.parse_inner(payload["nextStateId"], payload["goals"], _sentinel)
def __str__(self):

View File

@ -17,7 +17,7 @@ from pantograph.expr import (
TacticCalc,
TacticExpr,
)
from pantograph.compiler import TacticInvocation
from pantograph.data import CompilationUnit
def _get_proc_cwd():
return Path(__file__).parent
@ -173,7 +173,11 @@ class Server:
if "error" in result:
print(f"Cannot start goal: {expr}")
raise ServerError(result["desc"])
return GoalState(state_id=result["stateId"], goals=[Goal.sentence(expr)], _sentinel=self.to_remove_goal_states)
return GoalState(
state_id=result["stateId"],
goals=[Goal.sentence(expr)],
_sentinel=self.to_remove_goal_states,
)
def goal_tactic(self, state: GoalState, goal_id: int, tactic: Tactic) -> GoalState:
"""
@ -231,7 +235,7 @@ class Server:
raise ServerError(result["parseError"])
return GoalState.parse(result, self.to_remove_goal_states)
def tactic_invocations(self, file_name: Union[str, Path]) -> tuple[list[str], list[TacticInvocation]]:
def tactic_invocations(self, file_name: Union[str, Path]) -> list[CompilationUnit]:
"""
Collect tactic invocation points in file, and return them.
"""
@ -244,26 +248,16 @@ class Server:
if "error" in result:
raise ServerError(result["desc"])
with open(file_name, 'rb') as f:
content = f.read()
units = [
content[unit["boundary"][0]:unit["boundary"][1]].decode('utf-8')
for unit in result['units']
]
invocations = [
invocation
for unit in result['units']
for invocation in [TacticInvocation.parse(i) for i in unit['invocations']]
]
return units, invocations
units = [CompilationUnit.parse(payload) for payload in result['units']]
return units
def load_sorry(self, command: str) -> list[GoalState | list[str]]:
def load_sorry(self, content: str) -> list[CompilationUnit]:
"""
Executes the compiler on a Lean file. For each compilation unit, either
return the gathered `sorry` s, or a list of messages indicating error.
"""
result = self.run('frontend.process', {
'file': command,
'file': content,
'invocations': False,
"sorrys": True,
"newConstants": False,
@ -271,19 +265,11 @@ class Server:
if "error" in result:
raise ServerError(result["desc"])
def parse_unit(unit: dict):
state_id = unit.get("goalStateId")
if state_id is None:
# NOTE: `state_id` maybe 0.
# Maybe error has occurred
return unit["messages"]
state = GoalState.parse_inner(state_id, unit["goals"], self.to_remove_goal_states)
return state
states = [
parse_unit(unit) for unit in result['units']
units = [
CompilationUnit.parse(payload, goal_state_sentinel=self.to_remove_goal_states)
for payload in result['units']
]
return states
return units
def get_version():
@ -447,9 +433,9 @@ class TestServer(unittest.TestCase):
def test_load_sorry(self):
server = Server()
state0, = server.load_sorry("example (p: Prop): p → p := sorry")
if isinstance(state0, list):
print(state0)
unit, = server.load_sorry("example (p: Prop): p → p := sorry")
self.assertIsNotNone(unit.goal_state, f"{unit.messages}")
state0 = unit.goal_state
self.assertEqual(state0.goals, [
Goal(
[Variable(name="p", t="Prop")],