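"""
Utility functions for cadquery objects.

Importing this module also monkey-patches many of these helpers onto cadquery
classes (`Cq.Location`, `Cq.Sketch`, `Cq.Workplane`, `Cq.Assembly`) and
replaces `Cq.Assembly._subloc`.
"""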
import functools, math
from typing import Optional, Union, Tuple, cast
import cadquery as Cq
from cadquery.occ_impl.solver import ConstraintSpec
from nhf import Role
from nhf.materials import KEY_ITEM, KEY_MATERIAL

# Bug fixes
def _subloc(self, name: str) -> Tuple[Cq.Location, str]:
    """
    Compute the location of object `name` within its top-level ancestor.

    The top-level ancestor is the direct child of `self` that contains the
    object. Returns ``(location, top_level_name)``. For `self` itself or for a
    direct child, the location is the identity and `name` is returned as-is.
    The top-level ancestor's own location is NOT part of the result.
    """
    node = self.objects[name]
    if node in self.children or node is self:
        return (Cq.Location(), name)

    chain = []
    top_name = name
    while node.parent is not self:
        chain.append(node.loc)
        node = cast(Cq.Assembly, node.parent)
        top_name = node.name

    # `chain` runs inner -> outer; each outer location multiplies from the left.
    composed = functools.reduce(lambda inner, outer: outer * inner, chain)
    return (composed, top_name)
Cq.Assembly._subloc = _subloc

### Vector arithmetic

def location_sub(self: Cq.Location, rhs: Cq.Location) -> Cq.Vector:
    """
    Return the translation difference ``self - rhs`` as a `Cq.Vector`.

    Only the positional parts of the two locations are used; rotations are
    ignored.
    """
    lhs_pos = self.toTuple()[0]
    rhs_pos = rhs.toTuple()[0]
    return Cq.Vector(*(a - b for a, b in zip(lhs_pos, rhs_pos)))
Cq.Location.__sub__ = location_sub

def from2d(x: float, y: float, rotate: float = 0.0) -> Cq.Location:
    """
    Build a planar location at ``(x, y, 0)`` rotated by `rotate` degrees
    about the +Z axis.
    """
    return Cq.Location((x, y, 0), (0, 0, 1), rotate)
# `staticmethod` so that both `Cq.Location.from2d(...)` and
# `some_loc.from2d(...)` work; a plain function attribute would bind the
# instance as `x` when accessed through an instance.
Cq.Location.from2d = staticmethod(from2d)

def rot2d(angle: float) -> Cq.Location:
    """
    Build a location at the origin rotated by `angle` degrees about +Z.
    """
    return Cq.Location((0, 0, 0), (0, 0, 1), angle)
# `staticmethod` so the helper also works when accessed through an instance;
# a plain function attribute would bind the instance as `angle`.
Cq.Location.rot2d = staticmethod(rot2d)

def is2d(self: Cq.Location) -> bool:
    """
    Check whether this location lies in the XY plane and rotates only about Z.

    The test is exact: z, rx and ry must all compare equal to zero.
    """
    position, rotation = self.toTuple()
    return position[2] == 0 and rotation[0] == 0 and rotation[1] == 0
Cq.Location.is2d = is2d

def scale(self: Cq.Location, fac: float) -> Cq.Location:
    """
    Return a new location whose translation is multiplied by `fac`.

    The rotation angles reported by `toTuple` are passed through unchanged.
    """
    (x, y, z), (rx, ry, rz) = self.toTuple()
    return Cq.Location(x * fac, y * fac, z * fac, rx, ry, rz)
Cq.Location.scale = scale

def to2d(self: Cq.Location) -> Tuple[Tuple[float, float], float]:
    """
    Decompose a planar location into ``((x, y), angle)``.

    The location must sit in the XY plane and rotate only about Z. This is
    checked with exact-equality assertions, which Python skips under ``-O``.
    The angle is the Z rotation component reported by `toTuple`.
    """
    translation, rotation = self.toTuple()
    x, y, z = translation
    rx, ry, rz = rotation
    assert (z, rx, ry) == (0, 0, 0)
    return (x, y), rz
Cq.Location.to2d = to2d

def to2d_pos(self: Cq.Location) -> Tuple[float, float]:
    """
    Returns the ``(x, y)`` position of a planar location.

    The angle is discarded. Uses `to2d`, so the same planarity assertions
    apply.
    """
    (x, y), _ = self.to2d()
    return x, y
Cq.Location.to2d_pos = to2d_pos

def to2d_rot(self: Cq.Location) -> float:
    """
    Returns the rotation angle about Z of a planar location.

    The position is discarded. Uses `to2d`, so the same planarity assertions
    apply.
    """
    _, r = self.to2d()
    return r
Cq.Location.to2d_rot = to2d_rot


def with_angle_2d(self: Cq.Location, angle: float) -> Cq.Location:
    """
    Returns a new planar location at the same ``(x, y)`` position but rotated
    by `angle` about Z instead of this location's own angle.
    """
    x, y = self.to2d_pos()
    return Cq.Location.from2d(x, y, angle)
Cq.Location.with_angle_2d = with_angle_2d

def flip_x(self: Cq.Location) -> Cq.Location:
    """
    Mirror a planar location across the Y axis.

    The x coordinate is negated and the angle ``a`` becomes ``180 - a``.
    """
    pos, angle = self.to2d()
    mirrored_x = -pos[0]
    return Cq.Location.from2d(mirrored_x, pos[1], 180 - angle)
Cq.Location.flip_x = flip_x
def flip_y(self: Cq.Location) -> Cq.Location:
    """
    Mirror a planar location across the X axis.

    The y coordinate is negated and the angle ``a`` becomes ``-a``.
    """
    pos, angle = self.to2d()
    mirrored_y = -pos[1]
    return Cq.Location.from2d(pos[0], mirrored_y, -angle)
Cq.Location.flip_y = flip_y

def boolean(
        self: Cq.Sketch,
        obj: Union[Cq.Face, Cq.Sketch, Cq.Compound],
        **kwargs) -> Cq.Sketch:
    """
    Combine this sketch with another sketch-like object.

    The sketch selection is reset and `obj` is placed once at the origin
    through `Sketch.each`. Any extra keyword arguments are forwarded to `each`.
    """
    def place_obj(_location):
        return obj

    # The object must be pushed at (0, 0); pushing a translated point does
    # not work.
    origin_only = [(0, 0)]
    return self.reset().push(origin_only).each(place_obj, **kwargs)
Cq.Sketch.boolean = boolean

### Tags

def tagPoint(self, tag: str):
    """
    Adds a vertex that can be used in `Point` constraints.

    A vertex at the local origin is placed at each stack point through
    `eachpoint`, and the resulting workplane is tagged as `tag`.

    Returns the tagged workplane, like `tagPlane`, so calls can be chained.
    """
    vertex = Cq.Vertex.makeVertex(0, 0, 0)
    return self.eachpoint(vertex.moved, useLocalCoordinates=True).tag(tag)

Cq.Workplane.tagPoint = tagPoint

# Lookup tables for the string form of `tagPlane`'s `direction` argument.
_TAG_PLANE_AXES = {
    'x': (1, 0, 0), 'X': (1, 0, 0),
    'y': (0, 1, 0), 'Y': (0, 1, 0),
    'z': (0, 0, 1), 'Z': (0, 0, 1),
}
_TAG_PLANE_SIGNS = {'+': 1, '-': -1}

def tagPlane(self, tag: str,
             direction: Union[str, Cq.Vector, Tuple[float, float, float]] = '+Z'):
    """
    Adds a phantom `Cq.Edge` in the given location which can be referenced in a
    `Axis`, `Point`, or `Plane` constraint.

    :param tag: Name under which the phantom edge is tagged.
    :param direction: Either a two-character string made of a sign (``+`` or
        ``-``) and an axis letter (``x``, ``y`` or ``z``, either case), e.g.
        ``'+Z'``, or an explicit vector. The edge runs from ``-direction`` to
        ``+direction`` and is placed at each stack point via `eachpoint`.
    :returns: The tagged workplane.
    :raises ValueError: If a string `direction` is malformed. Explicit raises
        are used instead of `assert`, which Python strips under ``-O``. That
        would silently produce a zero-length edge.
    """
    if isinstance(direction, str):
        if len(direction) != 2:
            raise ValueError(
                f"Direction must be a sign followed by an axis, e.g. '+Z'; got {direction!r}")
        sign_char, axis_char = direction
        axis = _TAG_PLANE_AXES.get(axis_char)
        if axis is None:
            raise ValueError("Axis must be one of x,y,z")
        sign = _TAG_PLANE_SIGNS.get(sign_char)
        if sign is None:
            raise ValueError("Sign must be one of +/-")
        v = Cq.Vector(*axis) * sign
    else:
        v = Cq.Vector(direction)
    edge = Cq.Edge.makeLine(v * (-1), v)
    return self.eachpoint(edge.located, useLocalCoordinates=True).tag(tag)

Cq.Workplane.tagPlane = tagPlane

def make_sphere(r: float = 2) -> Cq.Solid:
    """
    Build a complete sphere of radius `r`.

    With its default angles `Cq.Solid.makeSphere` only produces a hemisphere,
    so the first angular bound is lowered to -90 degrees to close the shape.
    """
    lower_bound = -90
    return Cq.Solid.makeSphere(r, angleDegrees1=lower_bound)
def make_arrow(size: float = 2) -> Cq.Workplane:
    """
    Build an arrow-shaped marker of the given `size`.

    The shaft is a cylinder of radius ``size / 2`` and height `size`, not
    centred in Z. A cone of base radius `size` and height `size` is moved up
    by `size` and fused on top. The lowest face (``<Z``) is tagged
    ``dir_rev``.
    """
    head = Cq.Solid.makeCone(radius1=size, radius2=0, height=size)
    head_on_shaft = head.located(Cq.Location((0, 0, size)))
    shaft = Cq.Workplane("XY").cylinder(
        radius=size / 2, height=size, centered=(True, True, False))
    arrow = shaft.union(head_on_shaft)
    arrow.faces("<Z").tag("dir_rev")
    return arrow

def to_marker_name(tag: str) -> str:
    """
    Derive the assembly name of a marker object from a constraint tag.

    Every ``?`` becomes ``__T``, every ``/`` becomes ``__Z``, and the suffix
    ``_marker`` is appended.
    """
    escaped = tag.translate(str.maketrans({"?": "__T", "/": "__Z"}))
    return f"{escaped}_marker"

# Default colour of the markers added by `mark_point` / `mark_plane`:
# RGBA (0, 1, 1, 1), i.e. cyan.
COLOR_MARKER = Cq.Color(0, 1, 1, 1)

def mark_point(self: Cq.Assembly,
               tag: str,
               size: float = 2,
               color: Cq.Color = COLOR_MARKER) -> Cq.Assembly:
    """
    Attach a spherical marker to the point referenced by `tag`.

    A sphere of radius `size`, named via `to_marker_name(tag)`, is added in
    `color` and tied to `tag` with a ``Point`` constraint. Returns the result
    of `constrain` so calls can be chained.
    """
    marker_name = to_marker_name(tag)
    assembly = self.add(make_sphere(size), name=marker_name, color=color)
    return assembly.constrain(tag, marker_name, "Point")

Cq.Assembly.markPoint = mark_point

def mark_plane(self: Cq.Assembly,
               tag: str,
               size: float = 2,
               color: Cq.Color = COLOR_MARKER) -> Cq.Assembly:
    """
    Attach an arrow marker to the plane referenced by `tag`.

    An arrow from `make_arrow(size)`, named via `to_marker_name(tag)`, is
    added in `color`. Its ``dir_rev`` face is tied to `tag` with a ``Plane``
    constraint using ``param=180``. Presumably this makes the arrow point
    along the tagged plane's normal; confirm against the constraint solver.
    Returns the result of `constrain` so calls can be chained.
    """
    marker_name = to_marker_name(tag)
    assembly = self.add(make_arrow(size), name=marker_name, color=color)
    return assembly.constrain(tag, f"{marker_name}?dir_rev", "Plane", param=180)

Cq.Assembly.markPlane = mark_plane

def get_abs_location(self: Cq.Assembly,
                     tag: str) -> Cq.Location:
    """
    Gets the location of a tag, expressed in the coordinate frame of `self`.

    The translation of the result is the centre of the tagged shape. Its
    rotation is that of the frame of the object owning the tag.

    `_subloc` only composes the locations *below* the top-level child that
    owns the object, so that child's own `loc` is applied here as well.
    Without it, tags on a child with a non-identity location were reported in
    the child's local frame instead of the assembly's. Locations are read as
    currently stored on the children, e.g. after `solve()`.
    """
    name, shape = self._query(tag)
    loc_self = Cq.Location(shape.Center())
    loc_parent, name_top = self._subloc(name)
    loc = loc_parent * loc_self
    top = self.objects[name_top]
    # When the tag belongs to `self`, `_subloc` returns the identity and
    # `self.loc` (relative to self's own parent) must not be applied.
    if top is not self:
        loc = top.loc * loc
    return loc

Cq.Assembly.get_abs_location = get_abs_location

def get_abs_direction(self: Cq.Assembly,
                      tag: str) -> Cq.Location:
    """
    Gets the direction of a tag, expressed in the coordinate frame of `self`.

    The direction is picked the same way as in cadquery's `Axis` constraint:
    face normal, tangent of a non-circular edge, or normal of a circle. It is
    returned as the translation part of a `Cq.Location`; read it with
    ``.toTuple()[0]``.

    A direction must only be rotated, never translated. The rotation of the
    owning object's frame is applied, including the top-level child's own
    `loc`, which `_subloc` omits. The frame's translation is cancelled out.

    :raises ValueError: if the tagged shape is neither a face nor an edge.
    """
    name, shape = self._query(tag)
    # Must match `cadquery.occ_impl.solver.ConstraintSpec._getAxis`
    if isinstance(shape, Cq.Face):
        vec_dir = shape.normalAt()
    elif isinstance(shape, Cq.Edge) and shape.geomType() != "CIRCLE":
        vec_dir = shape.tangentAt()
    elif isinstance(shape, Cq.Edge) and shape.geomType() == "CIRCLE":
        vec_dir = shape.normal()
    else:
        raise ValueError(f"Cannot construct Axis for {shape}")
    frame, name_top = self._subloc(name)
    top = self.objects[name_top]
    # `self.loc` is relative to self's parent, so it is never applied.
    if top is not self:
        frame = top.loc * frame
    # frame * Location(d) has translation R*d + t; subtracting the frame's own
    # translation t leaves the rotated direction R*d.
    (x0, y0, z0), _ = frame.toTuple()
    (x1, y1, z1), _ = (frame * Cq.Location(vec_dir)).toTuple()
    return Cq.Location(Cq.Vector(x1 - x0, y1 - y0, z1 - z0))
Cq.Assembly.get_abs_direction = get_abs_direction


# Tallying functions

def assembly_this_mass(self: Cq.Assembly) -> Optional[float]:
    """
    Gets the mass of a single assembly node, ignoring any mass metadata on its
    sub-assemblies.

    Resolution order:

    1. If the metadata holds a truthy `KEY_ITEM` entry, its `.mass` is used.
    2. Otherwise, if it holds a truthy `KEY_MATERIAL` entry, the volume of
       ``self.toCompound()`` divided by 1000 is multiplied by the material's
       `density`. Presumably this converts mm^3 to cm^3 with density in
       g/cm^3; confirm the units.
    3. Otherwise `None` is returned.

    NOTE(review): `toCompound()` may include child geometry; confirm that is
    intended for nodes that carry a material and also have children.
    """
    item = self.metadata.get(KEY_ITEM)
    if item:
        return item.mass
    material = self.metadata.get(KEY_MATERIAL)
    if not material:
        return None
    scaled_volume = self.toCompound().Volume() / 1000
    return scaled_volume * material.density

def total_mass(self: Cq.Assembly) -> float:
    """
    Calculates the total mass in units of g.

    Sums `assembly_this_mass` over every node yielded by `traverse()`. Nodes
    whose mass is `None` or zero contribute nothing.
    """
    masses = (assembly_this_mass(node) for _, node in self.traverse())
    return sum((m for m in masses if m), 0.0)
Cq.Assembly.total_mass = total_mass

def centre_of_mass(self: Cq.Assembly) -> Optional[Cq.Vector]:
    """
    Mass-weighted centre of all nodes that have a known, non-zero mass.

    Each node's geometry (``node.toCompound()``) is weighted by
    `assembly_this_mass(node)`. `toCompound()` only applies the node's own
    `loc`. The locations of all its ancestors, up to and including `self`,
    are therefore applied as well, so every node is expressed in the same
    frame as ``self.toCompound()``. Previously nested nodes ignored their
    ancestors' locations.

    Returns `None` if no node has a mass.
    """
    moment = Cq.Vector()
    total = 0.0
    for _, node in self.traverse():
        m = assembly_this_mass(node)
        if not m:
            continue
        compound = node.toCompound()
        ancestor = node
        while ancestor is not self and ancestor.parent is not None:
            ancestor = ancestor.parent
            # `moved` composes on the left: ancestor.loc * current location.
            compound = compound.moved(ancestor.loc)
        moment += compound.Center() * m
        total += m
    if total == 0.0:
        return None
    return moment / total
Cq.Assembly.centre_of_mass = centre_of_mass