"""
SAXS/Models/FormFactors.py
"""
import numpy as np
import scipy.integrate as integrate

# from https://github.com/scipy/scipy/issues/12209
def nquad_vec(f, limits):
    """
    Integrate a (possibly vector-valued) function over a hyper-rectangle.

    This is a nested application of ``scipy.integrate.quad_vec``.

    Parameters
    ----------
    f : callable
        Called as ``f(x_0, x_1, ..., x_{n-1})``. It may return a scalar or an
        array. ``x_0`` is integrated over ``limits[0]`` (the outermost
        integral) and ``x_{n-1}`` over ``limits[-1]`` (the innermost).
    limits : sequence of (lower, upper) pairs
        Integration bounds, one pair per variable, in the order described
        above. Must contain at least one pair.

    Returns
    -------
    (result, error) : tuple
        The tuple returned by the outermost ``quad_vec`` call. The error
        estimate covers only the outermost quadrature; the error estimates of
        the inner integrals are discarded.

    Raises
    ------
    ValueError
        If ``limits`` is empty.
    """
    n_vars = len(limits)
    if n_vars == 0:
        raise ValueError("limits must contain at least one (lower, upper) pair")
    # Values of the outer variables currently being integrated over. The
    # buffer is shared and mutated in place, so this relies on quad_vec
    # evaluating the integrand sequentially (its default behaviour).
    z = np.zeros(n_vars - 1)

    def integrand(var, i):
        if i > 0:
            z[i - 1] = var
        if i == n_vars - 1:
            # Innermost level: evaluate f at the fixed outer variables.
            def inner(x):
                return f(*z, x)
        else:
            # Intermediate level: integrate over the next variable.
            def inner(x):
                return integrand(x, i + 1)
        res = integrate.quad_vec(inner, *limits[i])
        # Only the outermost call returns the full (result, error) tuple.
        # This is also true when there is a single variable; the original
        # code returned a bare value in that case.
        return res if i == 0 else res[0]

    return integrand(None, 0)

def homogeneous_sphere(q, R):
    """
    Calculate the normalised form-factor amplitude of a homogeneous sphere.

    Returns ``Phi(qR) = 3 [sin(qR) - qR cos(qR)] / (qR)**3``. This is
    normalised so that ``Phi -> 1`` as ``qR -> 0``; the ``qR == 0`` limit is
    returned exactly instead of ``nan``.

    Parameters
    ----------
    q : float or array-like
        The scattering vector magnitude.
    R : float
        The radius of the sphere.

    Returns
    -------
    F : float or ndarray
        The amplitude. A scalar is returned for scalar input, otherwise an
        array broadcast from ``q * R``.
    """
    qR = np.asarray(q, dtype=float) * R
    # For small |qR| the closed form is 0/0 at qR == 0 and suffers
    # catastrophic cancellation in sin(x) - x cos(x) close to it, so use the
    # Taylor series 1 - x^2/10 + x^4/280 there. The next term is x^6/15120,
    # which is below 1e-16 for |x| < 1e-2.
    small = np.abs(qR) < 1e-2
    # Substitute a harmless placeholder so the unused closed-form branch
    # never divides by zero.
    x = np.where(small, 1.0, qR)
    closed_form = (3 * (np.sin(x) - x * np.cos(x))) / x**3
    x2 = qR * qR
    series = 1.0 - x2 / 10.0 + x2 * x2 / 280.0
    F = np.where(small, series, closed_form)
    # np.where always returns an ndarray; unwrap 0-d results to a scalar.
    return F[()] if F.ndim == 0 else F

def sphere_volume(R):
    """
    Return the volume of a sphere, ``4/3 * pi * R**3``.

    Parameters
    ----------
    R : float
        Sphere radius.

    Returns
    -------
    V : float
        Volume enclosed by a sphere of radius ``R``.
    """
    four_thirds = 4 / 3
    return four_thirds * np.pi * R ** 3

def spherical_shell(q, R, r):
    """
    Calculate the amplitude of a uniform spherical shell.

    The shell is modelled as a solid sphere of radius ``R`` minus a solid
    sphere of radius ``r``. Each sphere's amplitude is weighted by its volume,
    and the difference is divided by the shell volume.

    Parameters
    ----------
    q : float or array-like
        The scattering vector magnitude.
    R : float
        Outer radius of the shell.
    r : float
        Inner radius of the shell. It must differ from ``R``; otherwise the
        shell volume in the denominator is zero.

    Returns
    -------
    F : float or array-like
        The volume-normalised shell amplitude.
    """
    outer_volume = sphere_volume(R)
    inner_volume = sphere_volume(r)
    outer_term = outer_volume * homogeneous_sphere(q, R)
    inner_term = inner_volume * homogeneous_sphere(q, r)
    return (outer_term - inner_term) / (outer_volume - inner_volume)

def ellipsoid_of_revolution(q, R, epsilon):
    """
    Calculate the orientation-averaged amplitude of an ellipsoid of revolution.

    The ellipsoid has semi-axes ``(R, R, epsilon * R)``. For an angle
    ``alpha`` between the symmetry axis and the scattering vector, it
    scatters like a sphere of equivalent radius::

        r(alpha) = R * sqrt(sin(alpha)**2 + epsilon**2 * cos(alpha)**2)

    The sphere amplitude is averaged over orientations with the
    ``sin(alpha)`` Jacobian::

        F(q) = integral_0^{pi/2} Phi(q r(alpha)) sin(alpha) d alpha

    The weight integrates to 1, so for ``epsilon == 1`` this reduces to
    ``homogeneous_sphere(q, R)``. The original code omitted the
    ``sin(alpha)`` weight, which scaled the q -> 0 limit by pi/2.

    NOTE(review): the usual SAXS definition (e.g. Pedersen 1997) averages the
    squared amplitude ``Phi**2`` to obtain the intensity P(q). Like
    ``tri_axial_ellipsoid``, this function averages ``Phi`` itself. Confirm
    that this is the intended quantity.

    Parameters
    ----------
    q : float or array-like
        The scattering vector magnitude.
    R : float
        The equatorial semi-axis of the ellipsoid.
    epsilon : float
        Ratio of the polar semi-axis to ``R``.

    Returns
    -------
    F : float or ndarray
        The orientation-averaged amplitude, shaped like ``q``.
    """
    def r(alpha):
        # Equivalent sphere radius at orientation alpha.
        return R * np.sqrt(np.sin(alpha)**2 + (epsilon**2) * np.cos(alpha)**2)

    # quad_vec (unlike quad) accepts an array-valued integrand, so an
    # array-like q is integrated in a single pass.
    F = integrate.quad_vec(
        lambda a: homogeneous_sphere(q, r(a)) * np.sin(a), 0, np.pi / 2
    )[0]
    return F

def tri_axial_ellipsoid(q, a, b, c):
    """
    Calculate the orientation-averaged amplitude of a tri-axial ellipsoid.

    For polar angle ``alpha`` and azimuth ``beta`` of the scattering vector,
    the ellipsoid scatters like a sphere of equivalent radius::

        r = sqrt((a**2 sin(beta)**2 + b**2 cos(beta)**2) sin(alpha)**2
                 + c**2 cos(alpha)**2)

    The sphere amplitude is then averaged over the octant
    ``alpha, beta in [0, pi/2]`` with weight ``sin(alpha)``. The prefactor
    ``2/pi`` makes the weights integrate to 1.

    Parameters
    ----------
    q : float or array-like
        The scattering vector magnitude.
    a : float
        Equatorial semi-axis (``r == a`` at ``alpha = beta = pi/2``).
    b : float
        Equatorial semi-axis (``r == b`` at ``alpha = pi/2, beta = 0``).
    c : float
        Polar semi-axis (``r == c`` at ``alpha = 0``).

    Returns
    -------
    F : float or array-like
        The orientation-averaged amplitude.
    """
    def effective_radius(alpha, beta):
        equatorial_sq = a**2 * np.sin(beta)**2 + b**2 * np.cos(beta)**2
        return np.sqrt(equatorial_sq * np.sin(alpha)**2 + (c * np.cos(alpha))**2)

    def weighted_amplitude(alpha, beta):
        return homogeneous_sphere(q, effective_radius(alpha, beta)) * np.sin(alpha)

    quarter_turn = [0, np.pi / 2]
    integral = nquad_vec(weighted_amplitude, [quarter_turn, quarter_turn])[0]
    return 2 / np.pi * integral
