2024-10-04 21:55:47 -07:00
|
|
|
import json
|
|
|
|
from typing import Union, Optional
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
@dataclass
class Datum:
    """
    Represents one theorem proving datapoint.

    Every field is optional and defaults to ``None``; the ``load_*``
    static methods build instances from plain dictionaries.
    """

    id: Optional[str] = None

    # Problem and solution in natural language
    nl_problem: Optional[Union[str, list[str]]] = None
    nl_solution: Optional[Union[str, list[str]]] = None

    # Problem in formal language
    fl_problem: Optional[str] = None

    def __str__(self) -> str:
        """
        Return a short human-readable label for this datum.

        Preference order: ``id``, the natural-language problem, the formal
        problem, and finally the dataclass ``repr``. ``__str__`` must always
        return a ``str``; returning ``None`` makes ``str(datum)`` raise
        ``TypeError``, so every branch here yields a string.
        """
        if self.id:
            # Coerce so a non-string id (e.g. an int from JSON) cannot
            # break str().
            return str(self.id)
        nl_text = self.nl_problem_str
        if nl_text is not None:
            return nl_text
        if self.fl_problem:
            return self.fl_problem
        return repr(self)

    @property
    def nl_problem_str(self) -> Optional[str]:
        """
        The natural-language problem as a single string.

        Returns ``None`` when ``nl_problem`` is unset or empty. A list of
        lines is joined with newlines; a plain string is returned as is.
        """
        if not self.nl_problem:
            return None
        if isinstance(self.nl_problem, list):
            return "\n".join(self.nl_problem)
        return self.nl_problem

    @staticmethod
    def load_default(obj: dict) -> "Datum":
        """
        Loads data in the "default" format

        Reads the ``fl_problem``, ``nl_problem`` and ``nl_solution`` keys,
        each of which may be absent. An ``fl_problem`` given as a list of
        lines is joined with newlines.
        """
        # NOTE(review): an "id" key is not read here, so ``id`` stays None
        # for data in this format - confirm that this is intended.
        fl_problem = obj.get("fl_problem")
        if isinstance(fl_problem, list):
            fl_problem = "\n".join(fl_problem)
        return Datum(
            nl_problem=obj.get("nl_problem"),
            nl_solution=obj.get("nl_solution"),
            fl_problem=fl_problem,
        )

    @staticmethod
    def load_minif2f(obj: dict) -> Optional["Datum"]:
        """
        Loads minif2f data

        Requires the keys ``formal_statement``, ``id``, ``informal_stmt``
        and ``informal_proof``; a missing key raises ``KeyError``.

        Returns ``None`` when the stripped formal statement starts with
        ``--``, so callers must handle a missing datum.
        """
        fl_problem = obj["formal_statement"].strip()
        if fl_problem.startswith("--"):
            return None
        return Datum(
            id=obj["id"],
            fl_problem=fl_problem,
            nl_problem=obj["informal_stmt"],
            nl_solution=obj["informal_proof"],
        )

    @staticmethod
    def load(obj: dict, data_format: str) -> Optional["Datum"]:
        """
        Build a datum from ``obj`` using the loader named by ``data_format``.

        Supported formats are ``"default"`` and ``"minif2f"``. The result
        can be ``None`` for ``"minif2f"`` (see ``load_minif2f``).

        Raises:
            ValueError: if ``data_format`` is not a supported format.
        """
        if data_format == "default":
            return Datum.load_default(obj)
        elif data_format == "minif2f":
            return Datum.load_minif2f(obj)
        else:
            raise ValueError(f"Invalid data format {data_format}")
|