refactor: Into CompilationUnit objects
This commit is contained in:
parent
d3c14f321f
commit
47b2fbe38d
|
@ -1,21 +0,0 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TacticInvocation:
|
||||
"""
|
||||
One tactic invocation with the before/after goals extracted from Lean source
|
||||
code.
|
||||
"""
|
||||
before: str
|
||||
after: str
|
||||
tactic: str
|
||||
used_constants: list[str]
|
||||
|
||||
@staticmethod
|
||||
def parse(payload: dict):
|
||||
return TacticInvocation(
|
||||
before=payload["goalBefore"],
|
||||
after=payload["goalAfter"],
|
||||
tactic=payload["tactic"],
|
||||
used_constants=payload.get('usedConstants', []),
|
||||
)
|
|
@ -0,0 +1,69 @@
|
|||
from typing import Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from pantograph.expr import GoalState
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TacticInvocation:
|
||||
"""
|
||||
One tactic invocation with the before/after goals extracted from Lean source
|
||||
code.
|
||||
"""
|
||||
before: str
|
||||
after: str
|
||||
tactic: str
|
||||
used_constants: list[str]
|
||||
|
||||
@staticmethod
|
||||
def parse(payload: dict):
|
||||
return TacticInvocation(
|
||||
before=payload["goalBefore"],
|
||||
after=payload["goalAfter"],
|
||||
tactic=payload["tactic"],
|
||||
used_constants=payload.get('usedConstants', []),
|
||||
)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CompilationUnit:
|
||||
|
||||
i_begin: int
|
||||
i_end: int
|
||||
messages: list[str] = field(default_factory=list)
|
||||
|
||||
invocations: Optional[list[TacticInvocation]] = None
|
||||
# If `goal_state` is none, maybe error has occurred. See `messages`
|
||||
goal_state: Optional[GoalState] = None
|
||||
goal_src_boundaries: Optional[list[Tuple[int, int]]] = None
|
||||
|
||||
new_constants: Optional[list[str]] = None
|
||||
|
||||
@staticmethod
|
||||
def parse(payload: dict, goal_state_sentinel=None):
|
||||
i_begin = payload["boundary"][0]
|
||||
i_end = payload["boundary"][1]
|
||||
messages = payload["messages"]
|
||||
|
||||
if (invocation_payload := payload.get("invocations")) is not None:
|
||||
invocations = [
|
||||
TacticInvocation.parse(i) for i in invocation_payload
|
||||
]
|
||||
else:
|
||||
invocations = None
|
||||
|
||||
if (state_id := payload.get("goalStateId")) is not None:
|
||||
goal_state = GoalState.parse_inner(int(state_id), payload["goals"], goal_state_sentinel)
|
||||
goal_src_boundaries = payload["goalSrcBoundaries"]
|
||||
else:
|
||||
goal_state = None
|
||||
goal_src_boundaries = None
|
||||
|
||||
new_constants = payload.get("newConstants")
|
||||
|
||||
return CompilationUnit(
|
||||
i_begin,
|
||||
i_end,
|
||||
messages,
|
||||
invocations,
|
||||
goal_state,
|
||||
goal_src_boundaries,
|
||||
new_constants
|
||||
)
|
|
@ -98,6 +98,7 @@ class GoalState:
|
|||
return GoalState(state_id, goals, _sentinel)
|
||||
@staticmethod
|
||||
def parse(payload: dict, _sentinel: list[int]):
|
||||
assert _sentinel is not None
|
||||
return GoalState.parse_inner(payload["nextStateId"], payload["goals"], _sentinel)
|
||||
|
||||
def __str__(self):
|
||||
|
|
|
@ -17,7 +17,7 @@ from pantograph.expr import (
|
|||
TacticCalc,
|
||||
TacticExpr,
|
||||
)
|
||||
from pantograph.compiler import TacticInvocation
|
||||
from pantograph.data import CompilationUnit
|
||||
|
||||
def _get_proc_cwd():
|
||||
return Path(__file__).parent
|
||||
|
@ -173,7 +173,11 @@ class Server:
|
|||
if "error" in result:
|
||||
print(f"Cannot start goal: {expr}")
|
||||
raise ServerError(result["desc"])
|
||||
return GoalState(state_id=result["stateId"], goals=[Goal.sentence(expr)], _sentinel=self.to_remove_goal_states)
|
||||
return GoalState(
|
||||
state_id=result["stateId"],
|
||||
goals=[Goal.sentence(expr)],
|
||||
_sentinel=self.to_remove_goal_states,
|
||||
)
|
||||
|
||||
def goal_tactic(self, state: GoalState, goal_id: int, tactic: Tactic) -> GoalState:
|
||||
"""
|
||||
|
@ -231,7 +235,7 @@ class Server:
|
|||
raise ServerError(result["parseError"])
|
||||
return GoalState.parse(result, self.to_remove_goal_states)
|
||||
|
||||
def tactic_invocations(self, file_name: Union[str, Path]) -> tuple[list[str], list[TacticInvocation]]:
|
||||
def tactic_invocations(self, file_name: Union[str, Path]) -> list[CompilationUnit]:
|
||||
"""
|
||||
Collect tactic invocation points in file, and return them.
|
||||
"""
|
||||
|
@ -244,26 +248,16 @@ class Server:
|
|||
if "error" in result:
|
||||
raise ServerError(result["desc"])
|
||||
|
||||
with open(file_name, 'rb') as f:
|
||||
content = f.read()
|
||||
units = [
|
||||
content[unit["boundary"][0]:unit["boundary"][1]].decode('utf-8')
|
||||
for unit in result['units']
|
||||
]
|
||||
invocations = [
|
||||
invocation
|
||||
for unit in result['units']
|
||||
for invocation in [TacticInvocation.parse(i) for i in unit['invocations']]
|
||||
]
|
||||
return units, invocations
|
||||
units = [CompilationUnit.parse(payload) for payload in result['units']]
|
||||
return units
|
||||
|
||||
def load_sorry(self, command: str) -> list[GoalState | list[str]]:
|
||||
def load_sorry(self, content: str) -> list[CompilationUnit]:
|
||||
"""
|
||||
Executes the compiler on a Lean file. For each compilation unit, either
|
||||
return the gathered `sorry` s, or a list of messages indicating error.
|
||||
"""
|
||||
result = self.run('frontend.process', {
|
||||
'file': command,
|
||||
'file': content,
|
||||
'invocations': False,
|
||||
"sorrys": True,
|
||||
"newConstants": False,
|
||||
|
@ -271,19 +265,11 @@ class Server:
|
|||
if "error" in result:
|
||||
raise ServerError(result["desc"])
|
||||
|
||||
def parse_unit(unit: dict):
|
||||
state_id = unit.get("goalStateId")
|
||||
if state_id is None:
|
||||
# NOTE: `state_id` maybe 0.
|
||||
# Maybe error has occurred
|
||||
return unit["messages"]
|
||||
state = GoalState.parse_inner(state_id, unit["goals"], self.to_remove_goal_states)
|
||||
return state
|
||||
states = [
|
||||
parse_unit(unit) for unit in result['units']
|
||||
units = [
|
||||
CompilationUnit.parse(payload, goal_state_sentinel=self.to_remove_goal_states)
|
||||
for payload in result['units']
|
||||
]
|
||||
return states
|
||||
|
||||
return units
|
||||
|
||||
|
||||
def get_version():
|
||||
|
@ -447,9 +433,9 @@ class TestServer(unittest.TestCase):
|
|||
|
||||
def test_load_sorry(self):
|
||||
server = Server()
|
||||
state0, = server.load_sorry("example (p: Prop): p → p := sorry")
|
||||
if isinstance(state0, list):
|
||||
print(state0)
|
||||
unit, = server.load_sorry("example (p: Prop): p → p := sorry")
|
||||
self.assertIsNotNone(unit.goal_state, f"{unit.messages}")
|
||||
state0 = unit.goal_state
|
||||
self.assertEqual(state0.goals, [
|
||||
Goal(
|
||||
[Variable(name="p", t="Prop")],
|
||||
|
|
Loading…
Reference in New Issue