import unittest
import cadquery as Cq
from nhf.checks import binary_intersection, pairwise_intersection
from nhf.parts import joints, handle, metric_threads, springs
import nhf.parts.fasteners as fasteners

class TestFasteners(unittest.TestCase):
    """Geometry checks for generators in ``nhf.parts.fasteners``."""

    def test_hex_nut(self):
        """A hex nut (12 mm thread, 1.75 pitch) generates exactly one shape
        whose X extent equals ``width`` and whose Z extent equals ``height``."""
        width = 18.9
        height = 9.8
        item = fasteners.HexNut(
            mass=float('nan'),
            diam_thread=12.0,
            pitch=1.75,
            # Build with the same variable the Z-extent assertion checks, so the
            # expected height cannot silently drift from the constructed one.
            thickness=height,
            width=width,
        )
        obj = item.generate()
        self.assertEqual(len(obj.vals()), 1)
        bbox = obj.val().BoundingBox()
        self.assertAlmostEqual(bbox.xlen, width)
        self.assertAlmostEqual(bbox.zlen, height)

class TestJoints(unittest.TestCase):
    """Geometry and interference checks for joints in ``nhf.parts.joints``."""

    def test_joint_hirth(self):
        """A default Hirth joint must generate a single solid (one piece)."""
        j = joints.HirthJoint()
        obj = j.generate()
        self.assertIsInstance(
            obj.val().solids(), Cq.Solid,
            msg="Hirth joint must be in one piece")

    def test_joints_hirth_assembly(self):
        """Mated Hirth joint parts must not intersect, for several tooth counts."""
        for n_tooth in [16, 20, 24]:
            with self.subTest(n_tooth=n_tooth):
                # Forward the loop variable: without it every subtest built the
                # default joint and only a single tooth count was exercised.
                j = joints.HirthJoint(n_tooth=n_tooth)
                assembly = j.assembly()
                isect = binary_intersection(assembly)
                self.assertLess(isect.Volume(), 1e-6,
                                "Hirth joint assembly must not have intersection")

    def torsion_joint_case(self, joint: joints.TorsionJoint, slot: int):
        """Check the rider/track assembly of ``joint`` placed at ``slot``.

        The assembly's bounding box must be ``joint.total_height`` tall and
        ``2 * joint.radius`` wide in both X and Y. No pair of parts may
        intersect.

        :param joint: torsion joint under test.
        :param slot: rider slot index; must lie in ``range(joint.rider_n_slots)``.
        """
        # Use a unittest assertion rather than a bare ``assert`` so the
        # precondition is still enforced under ``python -O``.
        self.assertIn(slot, range(joint.rider_n_slots))
        assembly = joint.rider_track_assembly(slot)
        bbox = assembly.toCompound().BoundingBox()
        self.assertAlmostEqual(bbox.zlen, joint.total_height)
        self.assertAlmostEqual(bbox.xlen, joint.radius * 2)
        self.assertAlmostEqual(bbox.ylen, joint.radius * 2)
        self.assertEqual(pairwise_intersection(assembly), [])

    def test_torsion_joint(self):
        """Every rider slot of the default torsion joint passes the case check."""
        j = joints.TorsionJoint()
        for slot in range(j.rider_n_slots):
            with self.subTest(slot=slot, right_handed=False):
                self.torsion_joint_case(j, slot)

    def test_torsion_joint_right_handed(self):
        """Every rider slot passes when the joint uses a right-handed spring."""
        j = joints.TorsionJoint(springs.TorsionSpring(mass=float('nan'), right_handed=True))
        for slot in range(j.rider_n_slots):
            with self.subTest(slot=slot, right_handed=True):
                self.torsion_joint_case(j, slot)

    def test_torsion_joint_covered(self):
        """Slot 1 passes with spring-hole covers on both track and rider."""
        j = joints.TorsionJoint(
            spring_hole_cover_track=True,
            spring_hole_cover_rider=True,
        )
        self.torsion_joint_case(j, 1)

    def test_torsion_joint_slot(self):
        """Slot 1 passes when ``rider_slot_begin`` is set to 90."""
        j = joints.TorsionJoint(
            rider_slot_begin=90,
        )
        self.torsion_joint_case(j, 1)



class TestHandle(unittest.TestCase):
    """Connector-assembly checks for ``handle.Handle`` with each mount type."""

    @staticmethod
    def _make(mount_cls):
        """Build a ``handle.Handle`` using a freshly constructed ``mount_cls``."""
        return handle.Handle(mount=mount_cls())

    def _check_no_collision(self, assembly):
        """Fail if any pair of parts in ``assembly`` intersects."""
        self.assertEqual(pairwise_intersection(assembly), [])

    def _check_footprint(self, subject, assembly):
        """Fail unless the assembly's X and Y extents both equal ``subject.diam``."""
        extents = assembly.toCompound().BoundingBox()
        self.assertAlmostEqual(extents.xlen, subject.diam)
        self.assertAlmostEqual(extents.ylen, subject.diam)

    def test_threaded_collision(self):
        subject = self._make(handle.ThreadedMount)
        self._check_no_collision(subject.connector_insertion_assembly())

    def test_threaded_assembly(self):
        subject = self._make(handle.ThreadedMount)
        self._check_footprint(subject, subject.connector_insertion_assembly())

    def test_threaded_one_sided_insertion(self):
        subject = self._make(handle.ThreadedMount)
        inserted = subject.connector_one_side_insertion_assembly()
        self._check_footprint(subject, inserted)
        self._check_no_collision(inserted)

    def test_bayonet_collision(self):
        subject = self._make(handle.BayonetMount)
        self._check_no_collision(subject.connector_insertion_assembly())

    def test_bayonet_assembly(self):
        subject = self._make(handle.BayonetMount)
        self._check_footprint(subject, subject.connector_insertion_assembly())

    def test_bayonet_one_sided_insertion(self):
        subject = self._make(handle.BayonetMount)
        inserted = subject.connector_one_side_insertion_assembly()
        self._check_footprint(subject, inserted)
        self._check_no_collision(inserted)

class TestMetricThreads(unittest.TestCase):
    """Dimensional checks for ``nhf.parts.metric_threads``."""

    def test_major_radius(self):
        """The generated external thread's X and Y extents equal ``major_size``.

        The full bounding-box extent is compared, so ``major_size`` acts as a
        diameter here despite the method name (presumably the first argument
        of ``external_metric_thread`` is the major diameter — confirm).
        """
        major_size = 3.0
        thread = metric_threads.external_metric_thread(
            major_size, 0.5, 4.0, z_start=-0.85, top_lead_in=True,
        )
        extents = thread.val().BoundingBox()
        for span in (extents.xlen, extents.ylen):
            self.assertAlmostEqual(span, major_size, places=3)


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()