from nhf.build import Model, TargetKind, target, assembly, submodel
from nhf.materials import Role, Material
import nhf.utils

import math
from dataclasses import dataclass, field
import cadquery as Cq

@dataclass
class Onbashira(Model):
    """
    Parametric model of the Onbashira: an `n_side`-gonal prism made of four
    sections of side panels joined by three rings of angle joints, with a
    ball-bearing rotor assembly inside that carries the barrels.
    """

    n_side: int = 6
    # Dimensions of each side panel
    side_width: float = 170.0

    # Side panels have different lengths (one length per section)
    side_length1: float = 200.0
    side_length2: float = 350.0
    side_length3: float = 400.0
    side_length4: float = 400.0

    side_thickness: float = 25.4 / 8

    # Joints between two sets of side panels
    angle_joint_thickness: float = 10.0
    # Z-axis size of each angle joint
    angle_joint_depth: float = 60.0
    # Height of the central band of each angle joint that is not slotted for
    # the panels; this band connects the outside of the joint to the inside
    angle_joint_gap: float = 10.0
    angle_joint_bolt_length: float = 50.0
    angle_joint_bolt_diam: float = 8.0
    # Position (x, y) of each bolt hole: x is the lateral offset from the
    # centre line of a side, y is the distance from the end of a side panel
    # (on the angle joint: the distance from the edge of the central gap)
    angle_joint_bolt_position: list[tuple[float, float]] = field(default_factory=lambda: [
        (20, 10),
        (60, 10),
    ])
    angle_joint_flange_thickness: float = 7.8
    angle_joint_flange_radius: float = 40.0

    # Dimensions of gun barrels
    barrel_diam: float = 25.4 * 1.5
    barrel_length: float = 300.0
    # Radius from barrel centre to axis
    rotation_radius: float = 75.0
    n_bearing_balls: int = 24
    # Size of ball bearings
    bearing_ball_diam: float = 25.4 * 1/2
    # Clearance added to the ball diameter (see `bearing_diam`)
    bearing_ball_gap: float = .5
    # Radial allowance for the bearing; in this class it is only used by the
    # clearance check in `__post_init__`
    bearing_thickness: float = 20.0
    # Radius of the circle on which the bearing balls sit
    bearing_track_radius: float = 110.0
    # Radial gap between the inner (rotor) and outer (stator) bearing disks
    bearing_gap: float = 10.0
    # Thickness of each stator, rotor, and gasket disk
    bearing_disk_thickness: float = 25.4 / 8

    # Radius of the central hole of the rotor disks
    rotor_inner_radius: float = 40.0
    # Diameter of the binding bolts (rotor, stator, spacer, and flange holes)
    rotor_bind_bolt_diam: float = 8.0
    # Radius of the bolt circle on the rotor disks
    rotor_bind_radius: float = 85.0
    rotor_spacer_outer_diam: float = 15.0
    # Radius of the bolt circle on the stator disks (and the flange hole)
    stator_bind_radius: float = 140.0

    material_side: Material = Material.WOOD_BIRCH
    material_bearing: Material = Material.PLASTIC_PLA
    material_bearing_ball: Material = Material.ACRYLIC_TRANSPARENT
    material_brace: Material = Material.PLASTIC_PLA

    def __post_init__(self):
        """
        Validate that the parameters describe a buildable model.

        Raises AssertionError when a constraint is violated.
        """
        # NOTE(review): `Model`'s own initialiser is not invoked here; confirm
        # that `Model` requires no set-up from this dataclass.
        assert self.n_side >= 3, "The prism needs at least 3 sides"
        # Bulk must be large enough for the barrel + bearing to rotate
        assert (
            self.bulk_radius
            - self.side_thickness
            - self.bearing_thickness
            - self.bearing_diam
            > self.rotation_radius + self.barrel_diam / 2
        ), "Bulk is too small for the barrels and bearing to rotate"
        # The balls must be able to span the radial gap between rotor and
        # stator (see `bearing_disk_gap`)
        assert self.bearing_gap < 0.95 * self.bearing_ball_diam, \
            "Bearing gap is too wide for the bearing balls"
        assert self.rotor_bind_bolt_diam < self.rotor_bind_radius < self.bearing_track_radius
        assert self.rotor_inner_radius < self.bearing_track_radius < self.stator_bind_radius
        assert self.angle_joint_thickness > self.side_thickness

        for _, y in self.angle_joint_bolt_position:
            # Bolt holes must fit within half of the joint depth
            assert y < self.angle_joint_depth / 2

    @property
    def angle_side(self) -> float:
        """Central angle (degrees) subtended by each side."""
        return 360 / self.n_side

    @property
    def side_width_inner(self) -> float:
        """
        Interior side width

        If outer width is `wo`, inner width is `wi`, each side's cross section
        is a trapezoid with sides `wo`, `wi`, and height `h` (side thickness),
        so `wi = wo - 2 * h * tan(pi / n_side)`.
        """
        theta = math.pi / self.n_side
        dt = self.side_thickness * math.tan(theta)
        return self.side_width - dt * 2

    @property
    def angle_joint_extra_width(self) -> float:
        """
        Change in side width when a side is moved outwards (or inwards) by
        `angle_joint_thickness`: `2 * angle_joint_thickness * tan(pi / n_side)`.
        """
        theta = math.pi / self.n_side
        dt = self.angle_joint_thickness * math.tan(theta)
        return dt * 2

    @property
    def angle_dihedral(self) -> float:
        """Interior angle (degrees) between two adjacent sides."""
        return 180 - self.angle_side

    @property
    def bulk_radius(self) -> float:
        """
        Radius of the bulk (surface of each side) to the centre, i.e. the
        inradius of the polygon whose edges have length `side_width`.
        """
        return self.side_width / 2 / math.tan(math.radians(self.angle_side / 2))

    @property
    def bearing_diam(self) -> float:
        """Ball diameter plus the `bearing_ball_gap` clearance."""
        return self.bearing_ball_diam + self.bearing_ball_gap

    @property
    def bearing_disk_gap(self) -> float:
        """
        Gap between two bearing disks to touch the bearing balls

        The disk edges are `bearing_gap` apart radially; a ball of diameter `d`
        touching edges that far apart has them `sqrt(d**2 - bearing_gap**2)`
        apart axially.
        """
        diag = self.bearing_ball_diam
        dx = self.bearing_gap
        return math.sqrt(diag ** 2 - dx ** 2)

    @target(name="bearing-stator", kind=TargetKind.DXF)
    def profile_bearing_stator(self) -> Cq.Sketch:
        """
        Outline of a stator disk: a polygon with a central opening just outside
        the bearing track, and bolt holes on the `stator_bind_radius` circle.
        """
        return (
            Cq.Sketch()
            .regularPolygon(self.side_width - self.side_thickness, self.n_side)
            .circle(self.bearing_track_radius + self.bearing_gap/2, mode="s")
            .reset()
            .regularPolygon(
                self.stator_bind_radius, self.n_side,
                mode="c", tag="bolt")
            .vertices(tag="bolt")
            .circle(self.rotor_bind_bolt_diam/2, mode="s")
        )

    def bearing_stator(self) -> Cq.Workplane:
        """Stator disk extruded to `bearing_disk_thickness`."""
        return (
            Cq.Workplane()
            .placeSketch(self.profile_bearing_stator())
            .extrude(self.bearing_disk_thickness)
        )

    @target(name="bearing-rotor", kind=TargetKind.DXF)
    def profile_bearing_rotor(self) -> Cq.Sketch:
        """
        Outline of a rotor disk: an annulus just inside the bearing track, with
        `n_side` barrel holes on the `rotation_radius` circle and bolt holes on
        the `rotor_bind_radius` circle (offset by half a step from the barrels).
        """
        bolt_angle = 180 / self.n_side
        return (
            Cq.Sketch()
            .circle(self.bearing_track_radius - self.bearing_gap/2)
            .circle(self.rotor_inner_radius, mode="s")
            .reset()
            .regularPolygon(
                self.rotation_radius, self.n_side,
                mode="c", tag="corners")
            .vertices(tag="corners")
            .circle(self.barrel_diam/2, mode="s")
            .reset()
            .regularPolygon(
                self.rotor_bind_radius, self.n_side,
                mode="c", tag="bolt", angle=bolt_angle)
            .vertices(tag="bolt")
            .circle(self.rotor_bind_bolt_diam/2, mode="s")
        )

    def bearing_rotor(self) -> Cq.Workplane:
        """Rotor disk extruded to `bearing_disk_thickness`."""
        return (
            Cq.Workplane()
            .placeSketch(self.profile_bearing_rotor())
            .extrude(self.bearing_disk_thickness)
        )

    @target(name="bearing-gasket", kind=TargetKind.DXF)
    def profile_bearing_gasket(self) -> Cq.Sketch:
        """
        Outline of the ball cage: an annulus around the bearing track with one
        slightly oversized (by `eps`) hole per bearing ball.
        """
        dr = self.bearing_ball_diam
        eps = 0.05
        return (
            Cq.Sketch()
            .circle(self.bearing_track_radius + dr)
            .circle(self.bearing_track_radius - dr, mode="s")
            .reset()
            .regularPolygon(
                self.bearing_track_radius, self.n_bearing_balls,
                mode="c", tag="corners")
            .vertices(tag="corners")
            .circle(self.bearing_ball_diam/2 * (1+eps), mode="s")
        )

    def bearing_gasket(self) -> Cq.Workplane:
        """Ball cage extruded to `bearing_disk_thickness`."""
        return (
            Cq.Workplane()
            .placeSketch(self.profile_bearing_gasket())
            .extrude(self.bearing_disk_thickness)
        )

    @target(name="pipe", kind=TargetKind.DXF)
    def pipe(self) -> Cq.Sketch:
        """
        The rotating pipes. Purely for decoration

        TODO: not implemented yet; this target currently returns None.
        """
        return None

    def bearing_ball(self) -> Cq.Solid:
        """
        A single bearing ball. `angleDegrees1=-90` starts the latitude sweep at
        the south pole so the result is a full sphere rather than the default
        upper hemisphere.
        """
        return Cq.Solid.makeSphere(radius=self.bearing_ball_diam/2, angleDegrees1=-90)

    @target(name="rotor-spacer")
    def rotor_spacer(self) -> Cq.Solid:
        """
        Tube of height `bearing_disk_gap` whose bore fits a rotor bind bolt.
        """
        outer = Cq.Solid.makeCylinder(
            radius=self.rotor_spacer_outer_diam/2,
            height=self.bearing_disk_gap,
        )
        inner = Cq.Solid.makeCylinder(
            radius=self.rotor_bind_bolt_diam/2,
            height=self.bearing_disk_gap
        )
        return outer - inner

    def assembly_rotor(self) -> Cq.Assembly:
        """
        Bearing assembly: a stator/rotor pair above and below the mid-plane,
        separated by `bearing_disk_gap`, a gasket centred on z = 0, and
        `n_bearing_balls` balls evenly spaced on the bearing track.
        """
        z_lower = -self.bearing_disk_gap/2 - self.bearing_disk_thickness
        a = (
            Cq.Assembly()
            .addS(
                self.bearing_stator(),
                name="stator1",
                material=self.material_bearing,
                role=Role.STATOR,
                loc=Cq.Location(0, 0, self.bearing_disk_gap/2)
            )
            .addS(
                self.bearing_rotor(),
                name="rotor1",
                material=self.material_bearing,
                role=Role.ROTOR,
                loc=Cq.Location(0, 0, self.bearing_disk_gap/2)
            )
            .addS(
                self.bearing_stator(),
                name="stator2",
                material=self.material_bearing,
                role=Role.STATOR,
                loc=Cq.Location(0, 0, z_lower)
            )
            .addS(
                self.bearing_rotor(),
                name="rotor2",
                material=self.material_bearing,
                role=Role.ROTOR,
                loc=Cq.Location(0, 0, z_lower)
            )
            .addS(
                self.bearing_gasket(),
                name="gasket",
                material=self.material_bearing,
                role=Role.ROTOR,
                loc=Cq.Location(0, 0, -self.bearing_disk_thickness/2)
            )
        )
        # Every ball is the same solid; only its location differs.
        ball = self.bearing_ball()
        for i in range(self.n_bearing_balls):
            loc = Cq.Location.rot2d(i * 360/self.n_bearing_balls) * Cq.Location(self.bearing_track_radius, 0, 0)
            a = a.addS(
                ball,
                name=f"bearing_ball{i}",
                material=self.material_bearing_ball,
                role=Role.BEARING,
                loc=loc,
            )
        return a

    def profile_side_panel(
            self,
            length: float,
            hasFrontHole: bool = False,
            hasBackHole: bool = True) -> Cq.Sketch:
        """
        Flat outline of a side panel: a `side_width` x `length` rectangle with
        bolt holes (mirrored in x) near the front (+Y) and/or back (-Y) end.

        Raises AssertionError if neither end has holes.
        """
        assert hasFrontHole or hasBackHole
        signs = ([1] if hasFrontHole else []) + ([-1] if hasBackHole else [])
        return (
            Cq.Sketch()
            .rect(self.side_width, length)
            .push([
                (sx * x, sy * (length/2 - y))
                for (x, y) in self.angle_joint_bolt_position
                for sx in [1, -1]
                for sy in signs
            ])
            .circle(self.angle_joint_bolt_diam/2, mode="s")
        )

    def side_panel(self, length: float, hasFrontHole: bool = True, hasBackHole: bool = True) -> Cq.Workplane:
        """
        A side panel extruded to `side_thickness` with its long edges bevelled
        to the polygon's corner angle.

        Tags hole(F/B)(P/S)(I/O)(i) mark every bolt position (whether or not the
        hole is actually cut):
        - F/B: front (+Y) or back (-Y) end,
        - P/S: the +X or -X hole,
        - I/O: the face at z = `side_thickness` or the face at z = 0,
        - i: index into `angle_joint_bolt_position`.
        """
        w = self.side_width
        sketch = self.profile_side_panel(
            length=length,
            hasFrontHole=hasFrontHole,
            hasBackHole=hasBackHole,
        )
        result = (
            Cq.Workplane()
            .placeSketch(sketch)
            .extrude(self.side_thickness)
        )
        # Bevel the edges: intersect with a triangular prism whose apex lies on
        # the polygon's axis (height `bulk_radius` above the z = 0 face).
        intersector = (
            Cq.Workplane('XZ')
            .polyline([
                (-w/2, 0),
                (w/2, 0),
                (0, self.bulk_radius),
            ])
            .close()
            .extrude(length)
            .translate(Cq.Vector(0, length/2, 0))
        )
        # Intersect the side panel
        result = result * intersector

        # Mark all attachment points
        t = self.side_thickness
        for i, (x, y) in enumerate(self.angle_joint_bolt_position):
            px = x
            py = length / 2 - y
            result.tagAbsolute(f"holeFPI{i}", (+px,  py, t), direction="+Z")
            result.tagAbsolute(f"holeFSI{i}", (-px,  py, t), direction="+Z")
            result.tagAbsolute(f"holeFPO{i}", (+px,  py, 0), direction="-Z")
            result.tagAbsolute(f"holeFSO{i}", (-px,  py, 0), direction="-Z")
            result.tagAbsolute(f"holeBPI{i}", (+px, -py, t), direction="+Z")
            result.tagAbsolute(f"holeBSI{i}", (-px, -py, t), direction="+Z")
            result.tagAbsolute(f"holeBPO{i}", (+px, -py, 0), direction="-Z")
            result.tagAbsolute(f"holeBSO{i}", (-px, -py, 0), direction="-Z")

        return result

    @target(name="angle-joint")
    def angle_joint(self) -> Cq.Workplane:
        """
        Angle joint: one of the `n_side` pieces of a ring that joins two
        sections of side panels. The piece spans 1/n of the ring and covers two
        facets, each holding a panel above and below the central gap, so it
        sits at the intersection of 4 side panels to provide compressive,
        shear, and tensile strength.

        To provide tensile strength along the Z-axis, the panels must be bolted
        onto the angle joint.

        The holes are tagged hole(L/R)(P/S)(O/M/I)(i), where
        - L/R: the two sections being joined (L at +Z, R at -Z),
        - P/S: the primary facet (normal along +X) or the secondary facet
          (the primary rotated by 360/n_side about Z),
        - O/M/I: the outer surface (`angle_joint_thickness` outside the bulk),
          the bulk surface (`bulk_radius`), or `angle_joint_thickness` inside
          the bulk,
        - i: index into `angle_joint_bolt_position`.
        """

        # Create the slot carving (cut away above and below the central gap)
        slot = (
            Cq.Sketch()
            .regularPolygon(
                self.side_width,
                self.n_side
            )
        )
        slot = (
            Cq.Workplane()
            .placeSketch(slot)
            .extrude(self.angle_joint_depth)
        )

        # Construct the overall shape of the joint, and divide it into sections for printing later.
        sketch = (
            Cq.Sketch()
            .regularPolygon(
                self.side_width + self.angle_joint_extra_width,
                self.n_side
            )
            .regularPolygon(
                self.side_width - self.angle_joint_extra_width,
                self.n_side, mode="s"
            )
        )

        # Radius comfortably enclosing the whole ring
        h = (self.bulk_radius + self.angle_joint_extra_width) * 2
        # Intersector for 1/n of the ring: a triangular wedge between the
        # angles 0 and 2*pi/n. Both far vertices sit at h / cos(pi/n), so the
        # chord joining them stays at distance h from the axis. This holds for
        # every n_side >= 3 (a vertex at (h, h * tan(2*pi/n)) would point the
        # wrong way for n = 3 and diverge for n = 4).
        wedge = 2 * math.pi / self.n_side
        r_wedge = h / math.cos(wedge / 2)
        intersector = (
            Cq.Workplane()
            .sketch()
            .polygon([
                (0, 0),
                (r_wedge, 0),
                (r_wedge * math.cos(wedge), r_wedge * math.sin(wedge)),
            ])
            .finalize()
            .extrude(self.angle_joint_depth*4)
            .translate((0, 0, -self.angle_joint_depth*2))
        )
        result = (
            Cq.Workplane()
            .placeSketch(sketch)
            .extrude(self.angle_joint_depth)
            .translate((0, 0, -self.angle_joint_depth/2))
            .cut(slot.translate((0, 0, self.angle_joint_gap/2)))
            .cut(slot.translate((0, 0, -self.angle_joint_depth-self.angle_joint_gap/2)))
            .intersect(intersector)
        )
        # Radial drill, long enough to pass through the ring
        hole_negative = Cq.Solid.makeCylinder(
            radius=self.angle_joint_bolt_diam/2,
            height=h,
            pnt=(0,0,0),
            dir=(1,0,0),
        )
        dy = self.angle_joint_gap / 2
        # Rotation from the primary facet to the secondary facet
        locrot = Cq.Location(0, 0, 0, 0, 0, 360/self.n_side)
        for (x, y) in self.angle_joint_bolt_position:
            p1 = Cq.Location((0, x, dy+y))
            p2 = Cq.Location((0, x, -dy-y))
            p1r = locrot * Cq.Location((0, -x, dy+y))
            p2r = locrot * Cq.Location((0, -x, -dy-y))
            result = result \
                - hole_negative.moved(p1) \
                - hole_negative.moved(p2) \
                - hole_negative.moved(p1r) \
                - hole_negative.moved(p2r)
        # Mark the absolute locations of the mount points
        dr = self.bulk_radius + self.angle_joint_thickness
        dr0 = self.bulk_radius
        dri = self.bulk_radius - self.angle_joint_thickness
        for i, (x, y) in enumerate(self.angle_joint_bolt_position):
            py = dy + y
            result.tagAbsolute(f"holeLPO{i}", (dr,  x,  py), direction="+X")
            result.tagAbsolute(f"holeRPO{i}", (dr,  x, -py), direction="+X")
            result.tagAbsolute(f"holeLPM{i}", (dr0, x,  py), direction="-X")
            result.tagAbsolute(f"holeRPM{i}", (dr0, x, -py), direction="-X")
            result.tagAbsolute(f"holeLPI{i}", (dri, x,  py), direction="-X")
            result.tagAbsolute(f"holeRPI{i}", (dri, x, -py), direction="-X")
            result.tagAbsolute(f"holeLSO{i}", locrot * Cq.Location(dr,  -x,  py), direction="+X")
            result.tagAbsolute(f"holeRSO{i}", locrot * Cq.Location(dr,  -x, -py), direction="+X")
            result.tagAbsolute(f"holeLSM{i}", locrot * Cq.Location(dr0, -x,  py), direction="-X")
            result.tagAbsolute(f"holeRSM{i}", locrot * Cq.Location(dr0, -x, -py), direction="-X")
            result.tagAbsolute(f"holeLSI{i}", locrot * Cq.Location(dri, -x,  py), direction="-X")
            result.tagAbsolute(f"holeRSI{i}", locrot * Cq.Location(dri, -x, -py), direction="-X")
        return result

    @target(name="angle-joint-flanged")
    def angle_joint_flanged(self) -> Cq.Workplane:
        """
        Angle joint with an extra flange: a disk of radius
        `angle_joint_flange_radius` centred on the polygon corner at angle
        pi/n, clipped to the inner polygon. The flange carries a bolt hole on
        the `stator_bind_radius` circle at the same angle (tags `holeStatorL`
        on top, `holeStatorR` underneath).
        """
        result = self.angle_joint()
        th = math.pi / self.n_side
        r = self.bulk_radius
        flange = (
            Cq.Sketch()
            .push([
                (r, r * math.tan(th))
            ])
            .circle(self.angle_joint_flange_radius)
            .reset()
            .regularPolygon(self.side_width_inner, self.n_side, mode="i")
        )
        flange = (
            Cq.Workplane()
            .placeSketch(flange)
            .extrude(self.angle_joint_flange_thickness)
            .translate((0, 0, -self.angle_joint_flange_thickness/2))
        )
        ri = self.stator_bind_radius
        h = self.angle_joint_flange_thickness
        cyl = Cq.Solid.makeCylinder(
            radius=self.rotor_bind_bolt_diam/2,
            height=h,
            pnt=(ri * math.cos(th), ri * math.sin(th), -h/2),
        )
        result = result + flange - cyl
        result.tagAbsolute("holeStatorL", (ri * math.cos(th), ri * math.sin(th), h/2), direction="+Z")
        result.tagAbsolute("holeStatorR", (ri * math.cos(th), ri * math.sin(th), -h/2), direction="-Z")
        return result

    def assembly_section(self, **kwargs) -> Cq.Assembly:
        """
        One section of the prism: `n_side` copies of the same side panel, each
        rotated by 360/n_side about Z. `kwargs` are forwarded to `side_panel`.
        """
        a = Cq.Assembly()
        side = self.side_panel(**kwargs)
        r = self.bulk_radius
        for i in range(self.n_side):
            a = a.addS(
                side,
                name=f"side{i}",
                material=self.material_side,
                role=Role.STRUCTURE | Role.DECORATION,
                loc=Cq.Location.rot2d(i*360/self.n_side) * Cq.Location(-r,0,0,90,0,90),
            )
        return a

    def assembly_ring(self, flanged=False) -> Cq.Assembly:
        """
        One ring of `n_side` angle joints (flanged if `flanged`), each rotated
        by 360/n_side about Z.
        """
        a = Cq.Assembly()
        side = self.angle_joint_flanged() if flanged else self.angle_joint()
        for i in range(self.n_side):
            a = a.addS(
                side,
                name=f"side{i}",
                material=self.material_brace,
                role=Role.CASING | Role.DECORATION,
                loc=Cq.Location.rot2d(i*360/self.n_side),
            )
        return a

    @assembly()
    def assembly(self) -> Cq.Assembly:
        """
        Full model: four sections separated by three rings (the first ring is
        flanged), with each section bolted to its neighbouring rings through
        "Plane" constraints, plus the rotor assembly. Returns the solved
        assembly.
        """
        a = (
            Cq.Assembly()
            .add(
                self.assembly_section(length=self.side_length1, hasFrontHole=False, hasBackHole=True),
                name="section1",
            )
            .add(
                self.assembly_ring(flanged=True),
                name="ring1",
            )
            .add(
                self.assembly_section(length=self.side_length2, hasFrontHole=True, hasBackHole=True),
                name="section2",
            )
            .add(
                self.assembly_ring(),
                name="ring2",
            )
            .add(
                self.assembly_section(length=self.side_length3, hasFrontHole=True, hasBackHole=True),
                name="section3",
            )
            .add(
                self.assembly_ring(),
                name="ring3",
            )
            .add(
                self.assembly_section(length=self.side_length4, hasFrontHole=True, hasBackHole=False),
                name="section4",
            )
        )
        n_bolts = len(self.angle_joint_bolt_position)
        # For every ring, bolt the back end of the preceding section to the
        # primary facets (L side) and the front end of the following section
        # to the secondary facets (R side). The secondary facet of ring piece i
        # is its primary facet rotated one step (360/n_side), so the following
        # section's side i lands one facet further round; all sides of a
        # section are identical, so this only shifts the side index.
        for (nl, nc, nr) in [
                ("section1", "ring1", "section2"),
                ("section2", "ring2", "section3"),
                ("section3", "ring3", "section4"),
        ]:
            for i in range(self.n_side):
                for ih in range(n_bolts):
                    a = a.constrain(
                        f"{nl}/side{i}?holeBSO{ih}",
                        f"{nc}/side{i}?holeLPM{ih}",
                        "Plane",
                    )
                    a = a.constrain(
                        f"{nr}/side{i}?holeFPO{ih}",
                        f"{nc}/side{i}?holeRSM{ih}",
                        "Plane",
                    )

        a = a.add(self.assembly_rotor(), name="rotor")
        return a.solve()