from dataclasses import dataclass, field
from typing import Optional
import math
import cadquery as Cq
from nhf.parts.springs import TorsionSpring
from nhf import Role
import nhf.utils

# Absolute tolerance for floating-point comparisons; `TorsionJoint.rider` uses
# it to decide whether the rider slot span is a whole multiple of 360 degrees.
TOL = 1e-6

@dataclass(frozen=True)
class HirthJoint:
    """
    A Hirth joint attached to a cylindrical base.

    Each half of the joint is a cylinder of radius `radius`. The lower
    `base_height` of it is solid. On top sits a toothed region of height
    `tooth_height`, whose core (inside `radius_inner`) is removed and into
    which `n_tooth` triangular grooves are cut in a polar array.
    """

    # r: outer radius of the base and of the teeth
    radius: float = 60
    # r_i: radius of the core removed from the toothed region
    radius_inner: float = 40
    # Height of the solid cylindrical base below the teeth
    base_height: float = 20
    # Number of teeth around the circumference (an integer count)
    n_tooth: int = 16
    # h_o: height of the toothed region
    tooth_height: float = 16

    def __post_init__(self):
        # Ensures tangent doesn't blow up
        assert self.n_tooth >= 5
        assert self.radius > self.radius_inner

    @property
    def tooth_angle(self):
        """
        Angular pitch of one tooth, in degrees.
        """
        return 360 / self.n_tooth

    @property
    def total_height(self):
        """
        Height of a single half: base plus teeth.
        """
        return self.base_height + self.tooth_height

    @property
    def joint_height(self):
        """
        Two bases plus one tooth height. This is presumably the combined height
        of two fully meshed halves (confirm).
        """
        return 2 * self.base_height + self.tooth_height

    def generate(self, is_mated=False, tol=0.01):
        """
        Generate one half of the joint as a CadQuery workplane.

        is_mated: If set to true, offset the teeth by half a tooth angle so
            they line up at 0 degrees with an unmated half.
        tol: Currently unused. It is kept only for backward compatibility.

        The returned workplane carries these tags:
        - "bore": the top face of the uncut cylinder.
        - "mate": a construction line along Z spanning the tooth region.
        - "dirX"/"dirY": construction lines along the X/Y axes from the origin.

        The "mate", "dirX" and "dirY" tags are consumed by `add_constraints`.

        FIXME: Mate is not exact when number of tooth is low
        """
        phi = math.radians(self.tooth_angle)
        alpha = 2 * math.atan(self.radius / self.tooth_height * math.tan(phi/2))
        gamma = math.radians(90 / self.n_tooth)
        # Tooth half height
        l = self.radius * math.cos(gamma)
        a = self.radius * math.sin(gamma)
        t = a / math.tan(alpha / 2)
        beta = math.asin(t / l)
        dx = self.tooth_height * math.tan(alpha / 2)
        # Build the triangular groove profile (apex at the origin, base at
        # tooth_height). Extrude it along -X by `radius`, tilt it about Y by
        # beta, then lift it to the top of the base.
        profile = (
            Cq.Workplane('YZ')
            .polyline([
                (0, 0),
                (dx, self.tooth_height),
                (-dx, self.tooth_height),
            ])
            .close()
            .extrude(-self.radius)
            .val()
            .rotate((0, 0, 0), (0, 1, 0), math.degrees(beta))
            .moved(Cq.Location((0, 0, self.base_height)))
        )
        # Core removed from the toothed region
        core = Cq.Solid.makeCylinder(
            radius=self.radius_inner,
            height=self.tooth_height,
            pnt=(0, 0, self.base_height),
        )
        angle_offset = self.tooth_angle / 2 if is_mated else 0
        result = (
            Cq.Workplane('XY')
            .cylinder(
                radius=self.radius,
                height=self.total_height,
                centered=(True, True, False))
            .faces(">Z")
            .tag("bore")
            .cut(core)
            .polarArray(
                radius=self.radius,
                startAngle=angle_offset,
                angle=360,
                count=self.n_tooth)
            .cutEach(
                lambda loc: profile.moved(loc),
            )
        )
        # Construction geometry used only as tagged constraint anchors
        (
            result
            .polyline([
                (0, 0, self.base_height),
                (0, 0, self.base_height + self.tooth_height)
            ], forConstruction=True)
            .tag("mate")
        )
        (
            result
            .polyline([(0, 0, 0), (1, 0, 0)], forConstruction=True)
            .tag("dirX")
        )
        (
            result
            .polyline([(0, 0, 0), (0, 1, 0)], forConstruction=True)
            .tag("dirY")
        )
        return result

    def add_constraints(self,
                        assembly: Cq.Assembly,
                        parent: str,
                        child: str,
                        offset: int = 0):
        """
        Add the constraints that mate `child` onto `parent` in `assembly`.

        Both parts are expected to carry the "mate", "dirX" and "dirY" tags
        produced by `generate`.

        offset: Number of whole teeth by which the child is rotated relative
            to the parent.
        """
        angle = offset * self.tooth_angle
        (
            assembly
            .constrain(f"{parent}?mate", f"{child}?mate", "Plane")
            .constrain(f"{parent}?dirX", f"{child}?dirX",
                       "Axis", param=angle)
            .constrain(f"{parent}?dirY", f"{child}?dirX",
                       "Axis", param=90 - angle)
        )

    def assembly(self, offset: int = 1):
        """
        Generate an example assembly of two halves and solve it.

        Each half is fused with a 100 x 10 x 2 tab. The parent half also gets
        a counterbored hole through its "bore" face.

        offset: Tooth offset passed to `add_constraints`.
        """
        tab = (
            Cq.Workplane('XY')
            .box(100, 10, 2, centered=False)
        )
        obj1 = (
            self.generate()
            .faces(tag="bore")
            .cboreHole(
                diameter=10,
                cboreDiameter=20,
                cboreDepth=3)
            .union(tab)
        )
        obj2 = (
            self.generate(is_mated=True)
            .union(tab)
        )
        result = (
            Cq.Assembly()
            .addS(obj1, name="obj1", role=Role.PARENT)
            .addS(obj2, name="obj2", role=Role.CHILD)
        )
        self.add_constraints(
            result,
            parent="obj1",
            child="obj2",
            offset=offset)
        return result.solve()

@dataclass
class TorsionJoint:
    """
    This joint consists of a rider puck on a track puck. It is best suited if
    the radius has to be small and vertical space is abundant.

    The rider part consists of:
    1. A cylindrical base
    2. An annular extrusion with the same radius as the base, but with slots
        carved in
    3. An annular rider

    The track part consists of:
    1. A cylindrical base
    2. A slotted annular extrusion where the slot allows the spring to rest
    3. An outer and an inner annulus which form a track the rider can move on
    """
    spring: TorsionSpring = field(default_factory=lambda: TorsionSpring(
        mass=float('nan'),
        radius=10.0,
        thickness=2.0,
        height=15.0,
        tail_length=35.0,
        right_handed=False,
    ))

    # Radius limit for rotating components
    radius_track: float = 40
    radius_rider: float = 38
    track_disk_height: float = 10
    rider_disk_height: float = 8

    radius_axle: float = 6

    # If true, cover the spring hole. May make it difficult to insert the spring
    # considering the stiffness of torsion spring steel.
    spring_hole_cover_track: bool = False
    spring_hole_cover_rider: bool = False

    groove_radius_outer: float = 35
    groove_radius_inner: float = 20
    # Gap on inner groove to ease movement
    groove_inner_gap: float = 0.2
    groove_depth: float = 5
    rider_gap: float = 1
    # Number of slots cut into the rider (an integer count; used with range())
    rider_n_slots: int = 4

    # Angle (degrees) of the first rider slot, and the angular span from the
    # first to the last slot. A span that is a multiple of 360 distributes the
    # slots evenly around a full turn instead.
    rider_slot_begin: float = 0
    rider_slot_span: float = 90


    def __post_init__(self):
        assert self.radius_track > self.groove_radius_outer
        assert self.radius_rider > self.groove_radius_outer > self.groove_radius_inner + self.groove_inner_gap
        assert self.groove_radius_inner > self.spring.radius > self.radius_axle
        assert self.spring.height > self.groove_depth, "Groove is too deep"
        assert self.groove_depth < self.spring.height - self.spring.thickness * 2
        if self.rider_n_slots == 1:
            assert self.rider_slot_span == 0.0, "Non-zero span is impossible with a single rider slot"

    @property
    def total_height(self):
        """
        Total height counting from bottom to top
        """
        return self.track_disk_height + self.rider_disk_height + self.spring.height

    @property
    def radius(self):
        """
        Maximum radius of this joint
        """
        return max(self.radius_rider, self.radius_track)

    def _slot_polygon(self, flip: bool=False):
        """
        Return the 2D corners of a rectangular slot.

        The slot runs from x=0 to x=`spring.tail_length` and lies between
        `spring.radius_inner` and `spring.radius` in y. The y values are
        negated when `flip` differs from `spring.right_handed`.
        """
        r1 = self.spring.radius_inner
        r2 = self.spring.radius
        flip = flip != self.spring.right_handed
        if flip:
            r1 = -r1
            r2 = -r2
        return [
            (0, r2),
            (self.spring.tail_length, r2),
            (self.spring.tail_length, r1),
            (0, r1),
        ]

    def _directrix(self, height, theta=0):
        """
        Return the two 3D endpoints of a construction line at z=`height`.

        The line runs along the slot's outer edge and is rotated about Z by
        `theta` radians.
        """
        c, s = math.cos(theta), math.sin(theta)
        r2 = self.spring.radius
        l = self.spring.tail_length
        if self.spring.right_handed:
            r2 = -r2
        # This is (0, r2) and (l, r2) transformed by right handed rotation
        # matrix `[[c, -s], [s, c]]`
        return [
            (-s * r2,         c * r2, height),
            (c * l - s * r2, s * l + c * r2, height),
        ]

    def track(self):
        """
        Generate the track puck.

        Tags: "spring" (top face of the base disk), "bot" (bottom face) and
        "dir" (directrix construction line at `track_disk_height`).
        """
        # TODO: Cover outer part of track only. Can we do this?
        groove_profile = (
            Cq.Sketch()
            .circle(self.radius_track)
            .circle(self.groove_radius_outer, mode='s')
            .circle(self.groove_radius_inner, mode='a')
            .circle(self.spring.radius, mode='s')
        )
        spring_hole_profile = (
            Cq.Sketch()
            .circle(self.radius_track)
            .circle(self.spring.radius, mode='s')
        )
        slot_height = self.spring.thickness
        if not self.spring_hole_cover_track:
            slot_height += self.groove_depth
        slot = (
            Cq.Workplane('XY')
            .sketch()
            .polygon(self._slot_polygon(flip=False))
            .finalize()
            .extrude(slot_height)
            .val()
        )
        result = (
            Cq.Workplane('XY')
            .cylinder(
                radius=self.radius_track,
                height=self.track_disk_height,
                centered=(True, True, False))
            .faces('>Z')
            .tag("spring")
            .placeSketch(spring_hole_profile)
            .extrude(self.spring.thickness)
            # If the spring hole profile is not simply connected, this workplane
            # will have to be created from the `spring-mate` face.
            .faces('>Z')
            .placeSketch(groove_profile)
            .extrude(self.groove_depth)
            .faces('>Z')
            .hole(self.radius_axle * 2)
            .cut(slot.moved(Cq.Location((0, 0, self.track_disk_height))))
        )
        result.faces("<Z").tag("bot")
        # Insert directrix
        result.polyline(self._directrix(self.track_disk_height),
                        forConstruction=True).tag("dir")
        return result

    def rider(self, rider_slot_begin=None, reverse_directrix_label=False):
        """
        Generate the rider puck.

        rider_slot_begin: Angle (degrees) of the first slot. When None, fall
            back to `self.rider_slot_begin`. An explicit 0 is honoured.
        reverse_directrix_label: If true, number the directrix tags `dir{j}`
            in reverse slot order.

        Tags: "spring" (top face of the base disk) and "dir0" ..
        "dir{rider_n_slots - 1}" (one directrix construction line per slot, at
        `rider_disk_height`).
        """
        # `is None` rather than falsiness, so that an explicit 0 degrees
        # is not replaced by the dataclass default.
        if rider_slot_begin is None:
            rider_slot_begin = self.rider_slot_begin
        def slot(loc):
            wire = Cq.Wire.makePolygon(self._slot_polygon(flip=False))
            face = Cq.Face.makeFromWires(wire)
            return face.located(loc)
        wall_profile = (
            Cq.Sketch()
            .circle(self.radius_rider, mode='a')
            .circle(self.spring.radius, mode='s')
            .parray(
                r=0,
                a1=rider_slot_begin,
                da=self.rider_slot_span,
                n=self.rider_n_slots)
            .each(slot, mode='s')
        )
        contact_profile = (
            Cq.Sketch()
            .circle(self.groove_radius_outer, mode='a')
            .circle(self.groove_radius_inner + self.groove_inner_gap, mode='s')
        )
        if not self.spring_hole_cover_rider:
            contact_profile = (
                contact_profile
                .parray(
                    r=0,
                    a1=rider_slot_begin,
                    da=self.rider_slot_span,
                    n=self.rider_n_slots)
                .each(slot, mode='s')
                .reset()
            )
        middle_height = self.spring.height - self.groove_depth - self.rider_gap - self.spring.thickness
        result = (
            Cq.Workplane('XY')
            .cylinder(
                radius=self.radius_rider,
                height=self.rider_disk_height,
                centered=(True, True, False))
            .faces('>Z')
            .tag("spring")
            .workplane()
            .placeSketch(wall_profile)
            .extrude(middle_height)
            .faces(tag="spring")
            .workplane()
            # The top face might not be in one piece.
            .workplane(offset=middle_height)
            .placeSketch(contact_profile)
            .extrude(self.groove_depth + self.rider_gap)
            .faces(tag="spring")
            .workplane()
            .circle(self.spring.radius_inner)
            .extrude(self.spring.height)
            .faces("<Z")
            .workplane()
            .hole(self.radius_axle * 2)
        )
        # Angular step between slot directrices. This mirrors the spacing
        # logic of the parray calls above: a full-turn span divides by n,
        # otherwise by n - 1.
        theta_begin = -math.radians(rider_slot_begin)
        theta_span = math.radians(self.rider_slot_span)
        if self.rider_n_slots <= 1:
            theta_step = 0
        elif abs(math.remainder(self.rider_slot_span, 360)) < TOL:
            theta_step = theta_span / self.rider_n_slots
        else:
            theta_step = theta_span / (self.rider_n_slots - 1)
        for i in range(self.rider_n_slots):
            theta = theta_begin - i * theta_step
            j = self.rider_n_slots - i - 1 if reverse_directrix_label else i
            result.polyline(self._directrix(self.rider_disk_height, theta),
                            forConstruction=True).tag(f"dir{j}")
        return result

    def rider_track_assembly(self, directrix: int = 0, deflection: float = 0):
        """
        Build and solve an example assembly of spring, track and rider.

        directrix: Index of the rider slot directrix the spring is attached to.
        deflection: Passed to `self.spring.assembly`.
        """
        rider = self.rider()
        track = self.track()
        spring = self.spring.assembly(deflection=deflection)
        result = (
            Cq.Assembly()
            .addS(spring, name="spring", role=Role.DAMPING)
            .addS(track, name="track", role=Role.PARENT)
            .addS(rider, name="rider", role=Role.CHILD)
        )
        TorsionJoint.add_constraints(
            result,
            rider="rider", track="track", spring="spring",
            directrix=directrix)
        return result.solve()

    @staticmethod
    def add_constraints(assembly: Cq.Assembly,
                        spring: str,
                        rider: Optional[str] = None,
                        track: Optional[str] = None,
                        directrix: int = 0):
        """
        Add the necessary constraints to a RT assembly.

        When `track` is given, its "spring" face and "dir" line are bound to the
        spring's "top" and "dir_top" tags. When `rider` is given, its "spring"
        face and "dir{directrix}" line are bound to the spring's "bot" and
        "dir_bot" tags. Parts that are None are left unconstrained.
        """
        if track:
            (
                assembly
                .constrain(f"{track}?spring", f"{spring}?top", "Plane")
                .constrain(f"{track}?dir", f"{spring}?dir_top",
                           "Axis", param=0)
            )
        if rider:
            (
                assembly
                .constrain(f"{rider}?spring", f"{spring}?bot", "Plane")
                .constrain(f"{rider}?dir{directrix}", f"{spring}?dir_bot",
                           "Axis", param=0)
            )