from collections import defaultdict
from typing import Union

import numpy as np
import openseespy.opensees as ops


def _construct_transform_matrix_beam(ele_tags):
    """
    Gather the local axes of beam elements into global-to-local transformation matrices.

    Parameters
    ----------
    ele_tags : int or array-like of int
        Beam element tag(s). A single tag is treated as a one-element sequence.

    Returns
    -------
    T : numpy.ndarray
        Stacked per-element matrices. The rows of each matrix are the values
        returned by ``ops.eleResponse`` for ``"xaxis"``, ``"yaxis"`` and ``"zaxis"``,
        in that order, so the array has shape ``(n_elements, 3, 3)`` when every
        axis has three components.
    ndim : int
        The largest number of coordinates found on the first node of any element,
        never less than 2.
    ele_tags : list[int]
        The element tags converted to plain Python integers.
    """
    tags = [int(tag) for tag in np.atleast_1d(ele_tags)]
    axis_names = ("xaxis", "yaxis", "zaxis")
    matrices = []
    ndim = 2
    for tag in tags:
        # The model dimension is inferred from the coordinates of the element's first node.
        first_node = ops.eleNodes(tag)[0]
        ndim = max(ndim, len(ops.nodeCoord(first_node)))
        # Each local axis becomes one row of the element's transformation matrix.
        matrices.append([ops.eleResponse(tag, axis) for axis in axis_names])
    return np.array(matrices), ndim, tags


def transform_beam_uniform_load(
    ele_tags: Union[int, list[int], tuple, np.ndarray[int]],
    wx: Union[float, list[float], np.ndarray[float]] = 0.0,
    wy: Union[float, list[float], np.ndarray[float]] = 0.0,
    wz: Union[float, list[float], np.ndarray[float]] = 0.0,
) -> None:
    """
    Transforms a uniformly distributed beam load from the global coordinate system to the local coordinate system.

    .. Note::
        This function will automatically call the
        `EleLoad Command <https://opensees.berkeley.edu/wiki/index.php/EleLoad_Command>`_ to generate element loads.
        However, you need to create ``timeSeries`` and load ``pattern`` objects externally in advance.
        The load generated by this function will belong to the load pattern closest to it.

    Parameters
    -----------
    ele_tags : int, list, tuple, np.ndarray
        Beam element tags.
    wx : float, list, np.ndarray, default=0.0
        Uniformly distributed load in the `global X` direction.
        If a list or numpy array is provided, the length should be the same as the number of elements.
    wy : float, list, np.ndarray, default=0.0
        Uniformly distributed load in the `global Y` direction.
        If a list or numpy array is provided, the length should be the same as the number of elements.
    wz : float, list, np.ndarray, default=0.0
        Uniformly distributed load in the `global Z` direction.
        If a list or numpy array is provided, the length should be the same as the number of elements.

    Raises
    ------
    ValueError
        If a per-element sequence cannot be broadcast to the number of elements.
    """
    T, ndim, ele_tags = _construct_transform_matrix_beam(ele_tags)
    n_ele = len(ele_tags)
    if n_ele == 0:
        # Nothing to load; also avoids running einsum on an empty matrix stack.
        return

    # Assemble one global load vector per element, shape (n_ele, 3).
    # Column-wise assignment lets scalars and per-element sequences be mixed freely:
    # a scalar (or length-1 sequence) is broadcast to every element, while a sequence
    # of the wrong length raises ValueError instead of being silently misaligned.
    q_globals = np.empty((n_ele, 3), dtype=float)
    for col, component in enumerate((wx, wy, wz)):
        q_globals[:, col] = np.ravel(component)

    # Rotate each global vector into its element's local axes: q_local = T @ q_global.
    q_locals = np.einsum("nij,nj->ni", T, q_globals)

    for qlocal, etag in zip(q_locals, ele_tags):
        # Plain Python floats are handed to OpenSees rather than NumPy scalars.
        q_x, q_y, q_z = (float(q) for q in qlocal)
        if ndim == 3:
            # 3D argument order: Wy, Wz, Wx (transverse components first, axial last).
            ops.eleLoad("-ele", etag, "-type", "-beamUniform", q_y, q_z, q_x)
        else:
            # 2D argument order: Wy, Wx.
            ops.eleLoad("-ele", etag, "-type", "-beamUniform", q_y, q_x)


def transform_beam_point_load(
    ele_tags: Union[int, list[int], tuple, np.ndarray[int]],
    px: Union[float, list[float], np.ndarray[float]] = 0.0,
    py: Union[float, list[float], np.ndarray[float]] = 0.0,
    pz: Union[float, list[float], np.ndarray[float]] = 0.0,
    xl: Union[float, list[float], np.ndarray[float]] = 0.5,
) -> None:
    """
    Transforms point loads for beam elements from global to local coordinates.

    The transformed loads are applied through ``ops.eleLoad`` with ``-beamPoint``,
    so a load ``pattern`` must already be open when this function is called.

    Parameters
    ----------
    ele_tags : int, list, tuple, np.ndarray
        Beam element tags.
    px : float, list, np.ndarray, default=0.0
        Point load in the `global X` direction.
        If a list or numpy array is provided, the length should be the same as the number of elements.
    py : float, list, np.ndarray, default=0.0
        Point load in the `global Y` direction.
        If a list or numpy array is provided, the length should be the same as the number of elements.
    pz : float, list, np.ndarray, default=0.0
        Point load in the `global Z` direction.
        If a list or numpy array is provided, the length should be the same as the number of elements.
    xl : float, list, np.ndarray, default=0.5
        The position of the point load along the beam element in a local coord system.
        It is forwarded unchanged as the ``xL`` argument of ``-beamPoint``.
        If a list or numpy array is provided, the length should be the same as the number of elements.

    Raises
    ------
    ValueError
        If a per-element sequence cannot be broadcast to the number of elements.
    """
    T, ndim, ele_tags = _construct_transform_matrix_beam(ele_tags)
    n_ele = len(ele_tags)
    if n_ele == 0:
        # Nothing to load; also avoids running einsum on an empty matrix stack.
        return

    # Assemble one global load vector per element, shape (n_ele, 3).
    # Column-wise assignment broadcasts scalars and accepts per-element sequences;
    # a sequence of the wrong length raises ValueError.
    p_globals = np.empty((n_ele, 3), dtype=float)
    for col, component in enumerate((px, py, pz)):
        p_globals[:, col] = np.ravel(component)

    # One load position per element (a scalar is shared by all elements).
    positions = np.empty(n_ele, dtype=float)
    positions[:] = np.ravel(xl)

    # Rotate each global vector into its element's local axes: p_local = T @ p_global.
    p_locals = np.einsum("nij,nj->ni", T, p_globals)

    for plocal, etag, position in zip(p_locals, ele_tags, positions):
        # Plain Python floats are handed to OpenSees rather than NumPy scalars.
        p_x, p_y, p_z = (float(p) for p in plocal)
        if ndim == 3:
            # 3D argument order: Py, Pz, xL, Px.
            ops.eleLoad("-ele", etag, "-type", "-beamPoint", p_y, p_z, float(position), p_x)
        else:
            # 2D argument order: Py, xL, Px.
            ops.eleLoad("-ele", etag, "-type", "-beamPoint", p_y, float(position), p_x)


def transform_surface_uniform_load(
    ele_tags: Union[int, list[int], tuple, np.ndarray[int]],
    p: Union[float, list[float], np.ndarray[float]] = 0.0,
) -> None:
    """
    Converts uniform surface loads into equivalent nodal forces in the global coordinate system.
    According to the static equivalence principle, the distributed load is equivalent to the node load.

    .. Note::
        This function will automatically call the
        `NodalLoad Command <https://opensees.berkeley.edu/wiki/index.php?title=NodalLoad_Command>`_ to
        generate nodal loads.
        However, you need to create ``timeSeries`` and load ``pattern`` objects externally in advance.
        The load generated by this function will belong to the load pattern closest to it.

    Parameters
    ----------
    ele_tags : int, list, tuple, np.ndarray
        Surface element tags. Elements with three or four nodes are supported.
    p : float, list, np.ndarray, default=0.0
        Uniform surface load magnitude (per unit area) along the surface normal direction.
        The positive direction of the normal is obtained by the cross-product of the I-J and J-K edges.
        If a list or numpy array is provided, the length should be the same as the number of elements.

    Raises
    ------
    ValueError
        If an element has a node count other than 3 or 4, or if ``p`` is a sequence
        whose length differs from the number of elements.
    """
    ele_tags = [int(tag) for tag in np.atleast_1d(ele_tags)]
    n_ele = len(ele_tags)

    # np.ndim() == 0 recognises Python numbers, NumPy scalars and 0-d arrays alike.
    if np.ndim(p) == 0:
        uniform_loads = [float(p)] * n_ele
    else:
        uniform_loads = [float(value) for value in np.ravel(p)]
        if len(uniform_loads) != n_ele:
            # zip() would silently leave the trailing elements unloaded.
            raise ValueError(  # noqa: TRY003
                f"Expected {n_ele} load values (one per element), got {len(uniform_loads)}."
            )

    # Forces are accumulated per node first because adjacent elements share nodes;
    # each node then receives a single load command with the summed force.
    nodal_forces = defaultdict(lambda: np.zeros(3))
    nodal_dofs = {}

    for etag, load in zip(ele_tags, uniform_loads):
        node_ids = ops.eleNodes(etag)
        vertices = np.array([ops.nodeCoord(node_id) for node_id in node_ids])
        # Compute area and normal based on an element type
        if len(node_ids) == 3:  # Triangle
            area, normal = _compute_tri_area_and_normal(vertices)
        elif len(node_ids) == 4:  # Quadrilateral
            area, normal = _compute_quad_area_and_normal(vertices)
        else:
            raise ValueError(f"Unsupported element type with {len(node_ids)} nodes.")  # noqa: TRY003
        # Total force vector on the element, split equally among its nodes.
        force_per_node = load * area * normal / len(node_ids)

        for node_id in node_ids:
            nodal_forces[node_id] += force_per_node
            if node_id not in nodal_dofs:
                # Query the DOF count only once per node.
                nodal_dofs[node_id] = ops.getNDF(node_id)[0]

    for node_id, force in nodal_forces.items():
        ndf = nodal_dofs[node_id]
        # Keep as many force components as the node has DOFs (at most three) and
        # pad any remaining DOFs with zero, so nodes with an ndf other than
        # 2, 3 or 6 are no longer skipped without notice.
        values = [float(component) for component in force[:ndf]]
        values.extend([0.0] * (ndf - len(values)))
        ops.load(node_id, *values)


def _compute_tri_area_and_normal(vertices):
    """
    Compute the area and normal vector of a triangular element.

    Parameters
    ----------
    vertices : numpy.ndarray
        Coordinates of the triangle's vertices, shape (3, 3).

    Returns
    -------
    area : float
        Area of the triangle.
    normal : numpy.ndarray
        Unit normal vector of the triangle, shape (3,).
    """
    # Edges: IJ and JK
    edge_ij = vertices[1] - vertices[0]
    edge_jk = vertices[2] - vertices[1]

    # Compute cross product
    cross_product = np.cross(edge_ij, edge_jk)
    norm = np.linalg.norm(cross_product)
    area = 0.5 * norm
    # Normalize the normal vector
    normal = cross_product / norm
    return area, normal


def _compute_quad_area_and_normal(vertices):
    """
    Compute the area and normal vector of a quadrilateral element.

    The quadrilateral I-J-K-L is split along the I-K diagonal into the
    triangles (I, J, K) and (I, K, L).

    Parameters
    ----------
    vertices : numpy.ndarray
        Coordinates of the quadrilateral's vertices, shape (4, 3).

    Returns
    -------
    area : float
        Sum of the areas of the two triangles.
    normal : numpy.ndarray
        Arithmetic mean of the two triangles' unit normals, shape (3,).
        It is not re-normalized, so it has unit length only when both
        triangles share the same normal (a planar quadrilateral).
    """
    corner_i, _, corner_k, corner_l = vertices
    halves = (vertices[:3], np.array([corner_i, corner_k, corner_l]))

    (area_a, normal_a), (area_b, normal_b) = (
        _compute_tri_area_and_normal(half) for half in halves
    )

    # NOTE: the mean is intentionally left un-normalized.
    mean_normal = (normal_a + normal_b) / 2.0
    return area_a + area_b, mean_normal
