import jax.numpy as jnp
import jax.lax as lax
import numpy as np

from typing import List, Optional, Type

from autoarray.inversion.inversion.settings import SettingsInversion

from autoarray import numba_util
from autoarray import exc


def curvature_matrix_via_w_tilde_from(
    w_tilde: np.ndarray, mapping_matrix: np.ndarray
) -> np.ndarray:
    """
    Compute the curvature matrix `F` (see Warren & Dye 2003) as `M^T (W M)`, where `W` is `w_tilde` and `M` is
    the mapping matrix.

    `w_tilde` has dimensions [image_pixels, image_pixels], so for datasets with many image pixels it can be
    very large in memory and this dense product correspondingly expensive.

    Parameters
    ----------
    w_tilde
        A matrix of dimensions [image_pixels, image_pixels] that encodes the convolution or NUFFT of every image
        pixel pair on the noise map.
    mapping_matrix
        The matrix representing the mappings between sub-grid pixels and pixelization pixels.

    Returns
    -------
    ndarray
        The curvature matrix `F`, computed via `jnp.dot`.
    """
    # Apply `w_tilde` to the mapping matrix first, then project back with the transpose.
    w_tilde_mapped = jnp.dot(w_tilde, mapping_matrix)
    mapping_matrix_transposed = mapping_matrix.T

    return jnp.dot(mapping_matrix_transposed, w_tilde_mapped)


def curvature_matrix_with_added_to_diag_from(
    curvature_matrix: np.ndarray,
    value: float,
    no_regularization_index_list: Optional[List] = None,
) -> np.ndarray:
    """
    It is common for the `curvature_matrix` computed to not be positive-definite, leading for the inversion
    via `np.linalg.solve` to fail and raise a `LinAlgError`.

    In many circumstances, adding a small numerical value of `1.0e-8` to the diagonal of the `curvature_matrix`
    makes it positive definite, such that the inversion is performed without raising an error.

    This function adds `value` to the diagonal entries `[i, i]` of the curvature matrix for every index `i` in
    `no_regularization_index_list`. The input is not modified in place; a new array is returned.

    Parameters
    ----------
    curvature_matrix
        The curvature matrix which is being constructed in order to solve a linear system of equations. It is
        converted with `jnp.asarray`, so both NumPy and JAX arrays are accepted.
    value
        The numerical value added to each selected diagonal entry.
    no_regularization_index_list
        The indexes of the diagonal entries that `value` is added to. If `None` or empty, no entry is changed.

    Returns
    -------
    ndarray
        The curvature matrix (as a JAX array) with `value` added to the selected diagonal entries.
    """
    curvature_matrix = jnp.asarray(curvature_matrix)

    # Indexing with `None` means "new axis", which would add `value` to every element of the matrix, and an
    # empty list converts to a float index array that cannot be used for indexing. Both cases mean there is
    # nothing to add.
    if (
        no_regularization_index_list is None
        or len(no_regularization_index_list) == 0
    ):
        return curvature_matrix

    diag_indices = jnp.asarray(no_regularization_index_list)

    return curvature_matrix.at[diag_indices, diag_indices].add(value)


def curvature_matrix_mirrored_from(
    curvature_matrix: np.ndarray,
) -> np.ndarray:
    """
    Mirror a curvature matrix across its diagonal by filling zero entries from the transpose.

    Every entry `[i, j]` that is exactly zero is replaced by the entry `[j, i]`; every non-zero entry is kept
    as it is.

    Parameters
    ----------
    curvature_matrix
        The curvature matrix whose zero entries are filled from its transpose.

    Returns
    -------
    ndarray
        The mirrored curvature matrix.
    """
    # Where the matrix itself is zero fall back on the transposed entry, otherwise keep the original entry.
    return jnp.where(curvature_matrix == 0, curvature_matrix.T, curvature_matrix)


def curvature_matrix_via_mapping_matrix_from(
    mapping_matrix: np.ndarray,
    noise_map: np.ndarray,
    add_to_curvature_diag: bool = False,
    no_regularization_index_list: Optional[List] = None,
    settings: SettingsInversion = SettingsInversion(),
) -> np.ndarray:
    """
    Returns the curvature matrix `F` from a blurred mapping matrix `f` and the 1D noise-map sigma
    (see Warren & Dye 2003).

    Every row of the mapping matrix is divided by the corresponding noise-map value and the curvature matrix
    is the product of the transpose of this scaled matrix with itself.

    Parameters
    ----------
    mapping_matrix
        The matrix representing the mappings (these could be blurred or transformed) between sub-grid pixels and
        pixelization pixels.
    noise_map
        Flattened 1D array of the noise-map used by the inversion during the fit.
    add_to_curvature_diag
        If `True`, and `no_regularization_index_list` contains at least one index, a value is added to the
        diagonal of the curvature matrix at those indexes.
    no_regularization_index_list
        The indexes of the diagonal entries the value is added to. If `None` or empty nothing is added.
    settings
        Provides `no_regularization_add_to_curvature_diag_value`, the value added to the diagonal.

    Returns
    -------
    ndarray
        The curvature matrix `F`.
    """
    array = mapping_matrix / noise_map[:, None]
    curvature_matrix = jnp.dot(array.T, array)

    # The index list defaults to `None`, so guard before calling `len` to avoid a `TypeError` when the
    # diagonal addition is requested without any indexes.
    has_indexes = (
        no_regularization_index_list is not None
        and len(no_regularization_index_list) > 0
    )

    if add_to_curvature_diag and has_indexes:
        curvature_matrix = curvature_matrix_with_added_to_diag_from(
            curvature_matrix=curvature_matrix,
            value=settings.no_regularization_add_to_curvature_diag_value,
            no_regularization_index_list=no_regularization_index_list,
        )

    return curvature_matrix


def mapped_reconstructed_data_via_mapping_matrix_from(
    mapping_matrix: np.ndarray, reconstruction: np.ndarray
) -> np.ndarray:
    """
    Map a reconstruction (solution vector *S*) to the data by taking its dot product with the blurred mapping
    matrix `f`.

    Parameters
    ----------
    mapping_matrix
        The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels.
    reconstruction
        The solution vector which is mapped to the data.

    Returns
    -------
    ndarray
        The reconstructed data vector, `jnp.dot(mapping_matrix, reconstruction)`.
    """
    mapped_reconstructed_data = jnp.dot(mapping_matrix, reconstruction)

    return mapped_reconstructed_data


def mapped_reconstructed_data_via_w_tilde_from(
    w_tilde: np.ndarray, mapping_matrix: np.ndarray, reconstruction: np.ndarray
) -> np.ndarray:
    """
    Map a reconstruction `s` to the data using the unblurred mapping matrix `M` and the PSF convolution
    operator `w_tilde` (`W`).

    The result is `W @ (M @ s)`, which equals `(W @ M) @ s` but only ever forms matrix-vector products.

    Parameters
    ----------
    w_tilde
        Array of shape [image_pixels, image_pixels], the PSF convolution operator.
    mapping_matrix
        Array of shape [image_pixels, source_pixels], unblurred mapping matrix.
    reconstruction
        Array of shape [source_pixels], solution vector.

    Returns
    -------
    ndarray
        The reconstructed data vector of shape [image_pixels].
    """
    # Map the solution to the image plane first, then apply the operator to the resulting vector.
    mapped_unblurred = mapping_matrix @ reconstruction

    return w_tilde @ mapped_unblurred


def reconstruction_positive_negative_from(
    data_vector: np.ndarray,
    curvature_reg_matrix: np.ndarray,
):
    """
    Solve the linear system [F + reg_coeff*H] S = D -> S = [F + reg_coeff*H]^-1 D given by equation (12)
    of https://arxiv.org/pdf/astro-ph/0302587.pdf

    S is the vector of reconstructed inversion values.

    This reconstruction uses a linear algebra solver (`jnp.linalg.solve`) that allows for negative and positive
    values in the solution. By allowing negative values, the solver is efficient, but there are many inference
    problems where negative values are nonphysical or undesirable.

    No checks are performed on the solution by this function.

    Parameters
    ----------
    data_vector
        The `data_vector` D which is solved for.
    curvature_reg_matrix
        The sum of the curvature and regularization matrices.

    Returns
    -------
    The reconstruction S, the solution of the linear system.
    """
    reconstruction = jnp.linalg.solve(curvature_reg_matrix, data_vector)

    return reconstruction


def reconstruction_positive_only_from(
    data_vector: np.ndarray,
    curvature_reg_matrix: np.ndarray,
):
    """
    Solve the linear system Eq.(2) (in terms of minimizing the quadratic value) of
    https://arxiv.org/pdf/astro-ph/0302587.pdf with the solution restricted to non-negative values. This does
    not find the exact solution of Eq.(3) or Eq.(4).

    By not allowing negative values, the solver is slower than methods which allow negative values, but there
    are many inference problems where negative values are nonphysical or undesirable and removing them improves
    the solution.

    The solve is delegated to `jaxnnls.solve_nnls_primal`, which is passed the `curvature_reg_matrix` and
    `data_vector` directly. These play the role of ZTZ and ZTx respectively for the problem of finding the
    non-negative S that minimizes |Z * S - x|^2; the aim is not to minimize |ZTZ * S - ZTx|^2.

    NOTE(review): earlier documentation of this function described a modified fnnls solver (Bro & Jong 1997)
    with a Cholesky update scheme. If the solver is changed again, check whether it takes Z or ZTZ (x or ZTx)
    as an input.

    Parameters
    ----------
    data_vector
        The `data_vector` D, which happens to be the ZTx.
    curvature_reg_matrix
        The sum of the curvature and regularization matrices. Taken as ZTZ in our problem.

    Returns
    -------
    Non-negative S that minimizes the Eq.(2) of https://arxiv.org/pdf/astro-ph/0302587.pdf.
    """
    # Imported inside the function so that `jaxnnls` is only required when this solver is used.
    import jaxnnls

    reconstruction = jaxnnls.solve_nnls_primal(curvature_reg_matrix, data_vector)

    return reconstruction


def preconditioner_matrix_via_mapping_matrix_from(
    mapping_matrix: np.ndarray,
    regularization_matrix: np.ndarray,
    preconditioner_noise_normalization: float,
) -> np.ndarray:
    """
    Returns the preconditioner matrix from a mapping matrix `f` and the sum of the inverse of the 1D noise-map
    values squared (see Powell et al. 2020).

    A curvature matrix is computed from the mapping matrix using a noise-map of ones, scaled by
    `preconditioner_noise_normalization` and added to the regularization matrix.

    Parameters
    ----------
    mapping_matrix
        The matrix representing the mappings between sub-grid pixels and pixelization pixels.
    regularization_matrix
        The matrix defining how the pixelization's pixels are regularized with one another for smoothing (H).
    preconditioner_noise_normalization
        The sum of (1.0 / noise-map**2.0) every value in the noise-map.

    Returns
    -------
    ndarray
        The preconditioner matrix.
    """
    # One unit noise value per row of the mapping matrix, so the curvature matrix carries no noise weighting.
    total_rows = mapping_matrix.shape[0]
    unit_noise_map = np.ones(shape=(total_rows))

    unweighted_curvature_matrix = curvature_matrix_via_mapping_matrix_from(
        mapping_matrix=mapping_matrix,
        noise_map=unit_noise_map,
    )

    scaled_curvature_matrix = (
        preconditioner_noise_normalization * unweighted_curvature_matrix
    )

    return scaled_curvature_matrix + regularization_matrix


def param_range_list_from(cls: Type, linear_obj_list) -> List[List[int]]:
    """
    Return the `[start, stop]` index range of every linear object of type `cls` in an `Inversion`.

    Each linear object has `params` parameters, which occupy a contiguous range of indexes in the matrices used
    to perform the inversion (e.g. the `curvature_matrix`). The ranges follow the order of `linear_obj_list`,
    with every object (whether or not it is of type `cls`) advancing the running index by its `params`.

    For example, if an `Inversion` has:

    - A `LinearFuncList` linear object with 3 `params`.
    - A `Mapper` with 100 `params`.
    - A `Mapper` with 200 `params`.

    The corresponding matrices have `shape=(303, 303)` where:

    - The `LinearFuncList` values are in the entries `[0:3]`.
    - The first `Mapper` values are in the entries `[3:103]`.
    - The second `Mapper` values are in the entries `[103:303]`.

    For this example, `param_range_list_from(cls=AbstractMapper)` therefore returns the
    list `[[3, 103], [103, 303]]`.

    Parameters
    ----------
    cls
        The type of class that the list of their parameter range index values are returned for.
    linear_obj_list
        The linear objects of the inversion, each of which has a `params` attribute.

    Returns
    -------
    A list of the index range of the parameters of each linear object in the inversion of the input cls type.
    """
    param_range_list = []

    start = 0

    for obj in linear_obj_list:
        stop = start + obj.params

        if isinstance(obj, cls):
            param_range_list.append([start, stop])

        # The next object's range begins where this one ends, regardless of its type.
        start = stop

    return param_range_list
