import marsilea as ma
import marsilea.plotter as mp
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from anndata import AnnData
from matplotlib.figure import Figure

from decoupler._docs import docs
from decoupler._Plotter import Plotter


def _input(
    adata: AnnData,
    uns_key: str,
    names: str | list | None = None,
    nvar: int | str | list | None = 10,
) -> tuple[pd.DataFrame, pd.DataFrame, list]:
    """
    Extract and align the data plotted by :func:`obsm`.

    Parameters
    ----------
    adata
        Annotated data object with ranking results in ``adata.uns[uns_key]``.
    uns_key
        Key in ``adata.uns`` holding the output of ``decoupler.tl.rankby_obsm``.
    names
        Metadata covariate(s) (values of the ``obs`` column of the stats) to keep.
        ``None`` keeps every covariate present in the stats.
    nvar
        Features of ``adata.obsm`` to keep. A number keeps that many features (the most
        variable ones when the ``obsm`` entry is a DataFrame, otherwise the first ones by
        sorted name); a str or list selects features by name; ``None`` keeps all of them.

    Returns
    -------
    obsm
        DataFrame with one row per row of the ``obsm`` entry and one column per selected feature.
    stats
        Covariates x features DataFrame of ``-log10(padj)`` whose columns follow ``obsm.columns``.
    names
        Covariates to plot.
    """
    assert isinstance(adata, AnnData), "adata must be adata.AnnData"
    assert isinstance(uns_key, str) and uns_key in adata.uns, "uns_key must be str and in adata.uns"
    assert isinstance(names, str | list) or names is None, "names must be str, list or None"
    assert isinstance(nvar, int | float | str | list) or nvar is None, "nvar must be numeric, str, list or None"
    # Filter stats by obs names
    stats = adata.uns[uns_key]
    assert hasattr(stats, "key"), "adata.uns[key] must be generated by decoupler.tl.rankby_obsm"
    # ``key`` names the adata.obsm entry that was ranked; read it before sorting builds a new frame
    obsm_key = stats.key
    stats = stats.sort_values("obsm")
    if isinstance(names, str):
        names_lst = [names]
    elif names is None:
        names_lst = []
    else:
        names_lst = names
    if isinstance(nvar, str):
        nvar = [nvar]
    if names_lst:
        stats = stats[stats["obs"].isin(names_lst)]
    # Filter stats by obsm nvar
    obsm = adata.obsm[obsm_key]
    if isinstance(obsm, pd.DataFrame):
        # Rank features by their standard deviation across observations
        var_names = obsm.std(ddof=1, axis=0).sort_values(ascending=False).index
        if isinstance(nvar, int | float):
            var_names = var_names[: int(nvar)]
        elif isinstance(nvar, list):
            var_names = nvar
        obsm = obsm.loc[:, var_names]
    else:
        # Plain arrays carry no column names: this branch assumes that the sorted feature
        # names in the stats follow the array's column order (NOTE(review): confirm against
        # the naming used by rankby_obsm).
        var_names = sorted(stats["obsm"].unique())
        if isinstance(nvar, int | float):
            n = int(nvar)  # floats are accepted by validation; slices need ints
            var_names = var_names[:n]
            obsm = obsm[:, :n]
        elif isinstance(nvar, list):
            known = set(var_names)
            missing = [v for v in nvar if v not in known]
            # searchsorted would otherwise silently select a wrong (or out of range) column
            assert not missing, f"nvar contains features not found in adata.uns[{uns_key!r}]: {missing}"
            idx = np.searchsorted(var_names, nvar)
            obsm = obsm[:, idx]
            var_names = nvar
    # Copy so the in-place padj edits below act on an independent frame, not a filtered slice
    stats = stats[stats["obsm"].isin(var_names)].copy()
    # Extract obsm
    obsm = pd.DataFrame(obsm, columns=var_names)
    # Transform stats: exact zeros are replaced by the smallest positive padj so -log10 stays finite
    positive = stats.loc[stats["padj"] > 0, "padj"]
    min_p = positive.min() if not positive.empty else np.finfo(float).tiny
    stats.loc[stats["padj"] == 0, "padj"] = min_p
    stats["padj"] = -np.log10(stats["padj"])
    stats = stats.pivot(index="obs", columns="obsm", values="padj")
    # pivot orders columns by name; align them with obsm so both heatmaps share column order
    stats = stats.reindex(columns=obsm.columns)
    stats.index.name = None
    stats.columns.name = None
    if names is None:
        names_lst = list(stats.index)
    return obsm, stats, names_lst


def _obs_palette(cats, cmap) -> dict:
    """
    Map each category in ``cats`` to a color taken from the colormap ``cmap``.

    Small qualitative colormaps with at least one color per category are indexed by
    position, so categories get the colormap's first colors. Otherwise the colormap is
    sampled at evenly spaced points in [0, 1]. Continuous colormaps (e.g. ``"viridis"``)
    therefore yield distinguishable colors, and categories beyond the colormap size do not
    all collapse onto its out-of-range color.
    """
    c_cmap = plt.get_cmap(cmap)
    n_cats = len(cats)
    # 256 is matplotlib's default lookup-table size used by continuous colormaps
    if n_cats <= c_cmap.N < 256:
        return {k: c_cmap(i) for i, k in enumerate(cats)}
    positions = np.linspace(0, 1, n_cats)
    return {k: c_cmap(p) for k, p in zip(cats, positions, strict=True)}


@docs.dedent
def obsm(
    adata: AnnData,
    key: str = "rank_obsm",
    names: str | list | None = None,
    nvar: int | str | list | None = 10,
    dendrogram: bool = True,
    thr_sign: float = 0.05,
    titles: list | None = None,
    cmap_stat: str = "Purples",
    cmap_obsm: str = "BrBG",
    cmap_obs: dict | None = None,
    **kwargs,
) -> None | Figure:
    """
    Plot metadata associations with features in ``adata.obsm``.

    Parameters
    ----------
    %(adata)s
    key
        Name of ``adata.uns`` key storing ``decoupler.tl.rankby_obsm`` results.
    names
        Which metadata covariates to show.
    nvar
        How many features from ``adata.obsm`` to show, or which ones by name (str or list).
    dendrogram
        Whether to sort and plot samples using a dendrogram.
    thr_sign
        Threshold of significance for the adjusted p-values.
    titles
        List of two titles: the first for the ``obsm`` features heatmap and the second for the metadata statistics heatmap.
    cmap_stat
        Colormap for metadata statistics.
    cmap_obsm
        Colormap for ``obsm`` features.
    cmap_obs
        Dictionary mapping metadata covariates to colormaps. Covariates not in it use ``"viridis"`` when numeric and ``"tab10"`` when categorical.
    %(plot)s

    Example
    -------
    .. code-block:: python

        import decoupler as dc
        import scanpy as sc

        adata, net = dc.ds.toy()
        sc.pp.scale(adata)
        sc.tl.pca(adata)
        dc.tl.rankby_obsm(adata, key="X_pca")
        dc.pl.obsm(adata=adata, nvar=5)
    """
    # Validate
    assert isinstance(dendrogram, bool), "dendrogram must be bool"
    assert isinstance(thr_sign, float) and 1 >= thr_sign >= 0, "thr_sign must be float and between 0 and 1"
    if titles is None:
        titles = ["Scores", "Stats"]
    assert isinstance(titles, list) and len(titles) == 2, "titles must be list and with 2 elements"
    if cmap_obs is None:
        cmap_obs = {}
    assert isinstance(cmap_obs, dict), "cmap_obs must be dict"
    # Extract
    obsm, stats, names = _input(adata=adata, uns_key=key, names=names, nvar=nvar)
    # Instance: the Plotter only supplies size/dpi/saving; marsilea draws on its own figure
    kwargs["ax"] = None
    bp = Plotter(**kwargs)
    bp.fig.delaxes(bp.ax)
    plt.close(bp.fig)
    # Plot stats
    h1 = ma.Heatmap(stats, cmap=cmap_stat, name="h1", width=4, height=1, label=r"$-\log_{10}(padj)$")
    h1.add_title(top=titles[1], align="center")
    if dendrogram:
        h1.add_dendrogram("left")
    h1.add_right(
        mp.Labels(
            stats.index,
        )
    )
    # Stats hold -log10(padj), so padj < thr_sign is equivalent to stat > -log10(thr_sign)
    sign_msk = stats.values > -np.log10(thr_sign)
    layer = mp.MarkerMesh(sign_msk, marker="*", label=f"padj < {thr_sign}", color="red")
    h1.add_layer(layer, name="sign")
    # Plot obsm
    h2 = ma.Heatmap(obsm, cmap=cmap_obsm, name="h2", width=0.4, height=4, label="Score")
    h2.add_title(top=titles[0], align="center")
    if dendrogram:
        h2.add_dendrogram("left")
    h2.add_bottom(mp.Labels(obsm.columns))
    # Add obs legends
    for name in names:
        values = adata.obs[name]
        if pd.api.types.is_numeric_dtype(values):
            colors = mp.ColorMesh(values, cmap=cmap_obs.get(name, "viridis"), label=name)
        else:
            cats = values.sort_values().unique()
            palette = _obs_palette(cats, cmap_obs.get(name, "tab10"))
            colors = mp.Colors(values, palette=palette, label=name)
        h2.add_right(colors, pad=0.1, size=0.1)
    # Build plot
    c = h1 / 0.05 / h2
    c.add_legends(side="right", stack_by="row", stack_size=3, align_legends="top")
    c.render()
    if bp.return_fig or bp.save is not None:
        plt.close()
    # Add borders
    hax = c.get_ax(board_name="h1", ax_name="h1")
    border = matplotlib.patches.Rectangle((0, 0), 1, 1, fill=False, ec=".1", lw=2, transform=hax.transAxes)
    hax.add_artist(border)
    hax = c.get_ax(board_name="h2", ax_name="h2")
    border = matplotlib.patches.Rectangle((0, 0), 1, 1, fill=False, ec=".1", lw=2, transform=hax.transAxes)
    hax.add_artist(border)
    bp.fig = c.figure
    bp.fig.set_figwidth(bp.figsize[0])
    bp.fig.set_figheight(bp.figsize[1])
    bp.fig.set_dpi(bp.dpi)
    return bp._return()
