import pandas as pd
import scanpy
from collections import defaultdict
import numpy as np
from scipy.optimize import nnls
import scanpy as sc
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from ..utils import pyomic_palette
import matplotlib
from typing import Union,Tuple
import anndata

def bulk2single_data_prepare(bulk_data:pd.DataFrame,
                             single_data:anndata.AnnData,celltype_key:str)->dict:
    r"""
    Prepare and align data for bulk-to-single-cell deconvolution analysis.

    Formats and aligns bulk RNA-seq and single-cell data by finding common genes
    and organizing metadata for downstream bulk2single analysis.

    Arguments:
        bulk_data: Bulk RNA-seq expression data with genes as rows and samples as columns
        single_data: Single-cell RNA-seq reference data as AnnData object
        celltype_key: Column name in single_data.obs containing cell type annotations

    Returns:
        dict: Dictionary with the keys
            - ``input_sc_meta``: DataFrame indexed by cell name with columns
              ``Cell`` (cell name) and ``Cell_type`` (annotation)
            - ``sc_gene`` / ``bulk_gene``: gene lists of each input
            - ``intersect_gene``: genes present in both inputs, in single-cell order
            - ``input_sc_data``: single-cell expression, genes x cells, restricted
              to ``intersect_gene``
            - ``input_bulk``: bulk expression restricted to ``intersect_gene``
    """
    print("...loading data")
    input_data = {}

    ism=pd.DataFrame(index=single_data.obs.index)
    ism['Cell']=single_data.obs.index
    ism['Cell_type']=single_data.obs[celltype_key].values
    input_data["input_sc_meta"] = ism

    input_data["sc_gene"] = single_data.var.index.values.tolist()
    input_data["bulk_gene"] = bulk_data.index.values.tolist()

    # Set lookup keeps this O(n + m) instead of scanning the bulk gene list for
    # every single-cell gene; single-cell order (and any duplicates) is preserved.
    bulk_gene_set = set(input_data["bulk_gene"])
    input_data["intersect_gene"] = [
        gene for gene in input_data["sc_gene"] if gene in bulk_gene_set
    ]
    # to_df() is cells x genes; transpose to genes x cells.
    input_data["input_sc_data"] = single_data[:,input_data["intersect_gene"]].to_df().T
    input_data["input_bulk"] = bulk_data.loc[input_data["intersect_gene"]]

    return input_data

def bulk2single_plot_cellprop(generate_single_data:anndata.AnnData,
                              celltype_key:str,figsize:tuple=(4,4))->matplotlib.axes.Axes:
    r"""
    Plot cell-type proportions in generated single-cell data.

    Visualizes the distribution of cell types in data generated by bulk2single
    analysis as a bar plot with cell counts per type. The input object is not
    modified.

    Arguments:
        generate_single_data: Generated single-cell data from bulk2single analysis
        celltype_key: Column name in obs containing cell type annotations
        figsize: Figure dimensions as (width, height) (4, 4)

    Returns:
        matplotlib.axes.Axes: Axes object containing the cell proportion bar plot
    """
    celltypes = generate_single_data.obs[celltype_key]
    counts = celltypes.value_counts()
    ct_name = list(counts.index)
    ct_num = list(counts.values)

    color_key = '{}_colors'.format(celltype_key)
    colors = None
    if color_key in generate_single_data.uns.keys():
        # Stored colors follow the category order (scanpy convention). Derive the
        # categories without writing the cast back into the caller's obs.
        categories = list(celltypes.astype('category').cat.categories)
        stored_colors = list(generate_single_data.uns[color_key])
        # A stale/short color list would leave some types without a color.
        if len(stored_colors) >= len(categories):
            color_dict = dict(zip(categories, stored_colors))
            colors = [color_dict[name] for name in ct_name]
    if colors is None:
        colors = pyomic_palette()

    _, ax = plt.subplots(figsize=figsize)
    ax.bar(ct_name, ct_num, color=colors)
    ax.set_xticks(ct_name)
    ax.set_xticklabels(ct_name, rotation=90)
    ax.set_title("The number of cells per cell type in bulk-seq data")
    ax.set_xlabel("Cell type")
    ax.set_ylabel("Cell number")
    return ax

def bulk2single_plot_correlation(single_data:anndata.AnnData,
                                 generate_single_data:anndata.AnnData,
                                 celltype_key:str,return_table:bool=False,
                                 figsize:tuple=(6,6),cmap:str='RdBu_r',
                                 top_marker_num:int=200,
                                 )->Union[np.ndarray,Tuple[matplotlib.figure.Figure,matplotlib.axes._axes.Axes]]:
    r"""
    Plot correlation matrix between reference and generated single-cell data.

    Compares expression patterns of cell types between original single-cell reference
    and bulk2single generated data using marker gene correlations.

    Note: ``sc.tl.rank_genes_groups`` is run on ``single_data``, so its
    ``uns['rank_genes_groups']`` entry is (re)written.

    Arguments:
        single_data: Original single-cell reference data
        generate_single_data: Generated single-cell data from bulk2single analysis
        celltype_key: Column name in obs containing cell type annotations
        return_table: Whether to return correlation matrix instead of plot (False)
        figsize: Figure dimensions as (width, height) (6, 6)
        cmap: Colormap for correlation heatmap ('RdBu_r')
        top_marker_num: Number of top-ranked marker genes taken per cell type (200)

    Returns:
        numpy.ndarray: If ``return_table`` is True, the Pearson correlation matrix
            with rows = reference cell types and columns = generated cell types
            (same types, same order on both axes).
        matplotlib.figure.Figure, matplotlib.axes.Axes: Otherwise, the figure and
            axes of the heatmap (reference types on x, generated types on y).

    Raises:
        ValueError: If no marker gene is present in ``generate_single_data``.
    """
    # Union of the top marker genes of each reference cell type that also
    # exist in the generated data; sorted for a reproducible gene order.
    sc.tl.rank_genes_groups(single_data, celltype_key, method='wilcoxon')
    marker_df = pd.DataFrame(single_data.uns['rank_genes_groups']['names']).head(top_marker_num)
    generated_genes = set(generate_single_data.var.index.tolist())
    marker = sorted(set(np.ravel(np.array(marker_df))) & generated_genes)
    if not marker:
        raise ValueError("No marker genes are shared between the reference and the generated data.")

    def _mean_by_celltype(adata):
        # Mean marker expression per cell type (rows = types, cols = genes).
        # observed=True skips empty categories instead of yielding NaN rows.
        expr = adata[:,marker].to_df()
        expr[celltype_key] = adata.obs[celltype_key]
        return expr.groupby(celltype_key, observed=True)[marker].mean()

    sc_marker_mean = _mean_by_celltype(single_data)
    generate_sc_marker_mean = _mean_by_celltype(generate_single_data)

    # Keep only types present in both, in the reference's (deterministic) order.
    generated_types = set(generate_sc_marker_mean.index)
    shared_types = [ct for ct in sc_marker_mean.index if ct in generated_types]

    # genes x cell types
    ref_mean = sc_marker_mean.loc[shared_types].T
    gen_mean = generate_sc_marker_mean.loc[shared_types].T

    coeffmat = np.zeros((ref_mean.shape[1], gen_mean.shape[1]))
    for i, ref_type in enumerate(ref_mean.columns):
        for j, gen_type in enumerate(gen_mean.columns):
            coeffmat[i,j] = pearsonr(ref_mean[ref_type], gen_mean[gen_type])[0]
    if return_table:
        return coeffmat
    rf_ct = list(ref_mean.columns)
    generate_ct = list(gen_mean.columns)

    fig, ax = plt.subplots(figsize=figsize)
    # coeffmat rows are reference types; imshow draws rows along y, so transpose
    # to put the reference on the x-axis as the axis labels state.
    im = ax.imshow(coeffmat.T, cmap=cmap)
    ax.set_xticks(np.arange(len(rf_ct)))
    ax.set_xticklabels(rf_ct)
    ax.set_yticks(np.arange(len(generate_ct)))
    ax.set_yticklabels(generate_ct)
    ax.set_xlabel("scRNA-seq reference")
    ax.set_ylabel("deconvoluted bulk-seq")
    plt.setp(ax.get_xticklabels(), rotation=90, ha="right", rotation_mode="anchor")
    fig.colorbar(im, ax=ax)
    ax.set_title("Expression correlation")
    fig.tight_layout()
    return fig,ax



def load_data(input_bulk_path,
              input_sc_data_path,
              input_sc_meta_path,):
    r"""
    Load bulk and single-cell data from CSV files.

    Loads bulk RNA-seq and single-cell data from file paths and prepares
    them for bulk2single analysis by finding common genes.

    Arguments:
        input_bulk_path: Path to bulk RNA-seq CSV file (genes as rows)
        input_sc_data_path: Path to single-cell expression CSV file (genes as rows)
        input_sc_meta_path: Path to single-cell metadata CSV file

    Returns:
        dict: Dictionary with the keys ``input_sc_meta``, ``sc_gene``, ``bulk_gene``,
            ``intersect_gene`` (shared genes in single-cell file order),
            ``input_sc_data`` and ``input_bulk`` (both restricted to ``intersect_gene``).
    """
    print("loading data......")
    input_data = {}
    # sc metadata: cell name and cell type per cell; first column is the index
    input_data["input_sc_meta"] = pd.read_csv(input_sc_meta_path, index_col=0)
    # sc expression: genes as rows, cells as columns
    input_sc_data = pd.read_csv(input_sc_data_path, index_col=0)
    input_data["sc_gene"] = input_sc_data.index.values.tolist()
    # bulk expression: genes as rows
    input_bulk = pd.read_csv(input_bulk_path, index_col=0)
    input_data["bulk_gene"] = input_bulk.index.values.tolist()

    # Shared genes, keeping single-cell order; set lookup avoids an O(n*m) scan.
    bulk_gene_set = set(input_data["bulk_gene"])
    input_data["intersect_gene"] = [
        gene for gene in input_data["sc_gene"] if gene in bulk_gene_set
    ]
    input_data["input_sc_data"] = input_sc_data.loc[input_data["intersect_gene"]]
    input_data["input_bulk"] = input_bulk.loc[input_data["intersect_gene"]]
    print("load data done!")

    return input_data


def data_process(data, top_marker_num, ratio_num):
    r"""
    Process data and calculate cell-type target numbers for generation.

    Identifies marker genes, performs deconvolution to estimate cell-type
    proportions, and calculates target cell numbers for each type.

    Arguments:
        data: Dictionary containing aligned bulk and single-cell data (as built by
            ``load_data`` / ``bulk2single_data_prepare``: ``input_sc_data`` genes x
            cells, ``input_sc_meta`` with ``Cell``/``Cell_type`` columns,
            ``input_bulk`` with a single bulk sample column)
        top_marker_num: Number of top marker genes to select per cell type
        ratio_num: Multiplier for scaling target cell numbers

    Returns:
        dict: Dictionary mapping cell types (sorted) to target generation numbers,
            ``round(proportion * n_reference_cells * ratio_num)``

    Raises:
        ValueError: If the bulk input has more than one sample column, or if NNLS
            assigns zero weight to every cell type.
    """
    sc_meta = data["input_sc_meta"]
    sc_data = data["input_sc_data"]

    # Marker genes: union of the top-ranked genes of every cell type.
    # (``adata`` rather than ``sc`` to avoid shadowing the module-level scanpy alias.)
    adata = scanpy.AnnData(sc_data.T)
    adata.obs = sc_meta[['Cell_type']]
    scanpy.tl.rank_genes_groups(adata, 'Cell_type', method='wilcoxon')
    marker_df = pd.DataFrame(adata.uns['rank_genes_groups']['names']).head(top_marker_num)
    marker = list(np.unique(np.ravel(np.array(marker_df))))
    sc_marker = sc_data.loc[marker, :]
    bulk_marker = data["input_bulk"].loc[marker]

    # Sorted cell-type labels; their order defines the signature-matrix columns.
    id2label = sorted(set(sc_meta['Cell_type'].values))
    label2id = {label: idx for idx, label in enumerate(id2label)}
    label2cell = defaultdict(set)  # label id -> (de-duplicated) cell names
    for cell_name, cell_type in zip(sc_meta['Cell'], sc_meta['Cell_type']):
        label2cell[label2id[cell_type]].add(cell_name)

    # Signature matrix (gene_num, label_num): mean marker expression per type.
    single_cell_matrix = np.column_stack([
        sc_marker[list(label2cell[idx])].values.mean(axis=1)
        for idx in range(len(id2label))
    ])

    bulk_values = bulk_marker.values  # expected (gene_num, 1)
    if bulk_values.ndim == 2 and bulk_values.shape[1] != 1:
        raise ValueError(
            "data_process expects a single bulk sample, got {} columns".format(bulk_values.shape[1]))
    bulk_rep = bulk_values.reshape(bulk_values.shape[0], )

    # Cell-type proportions of the bulk sample via non-negative least squares.
    ratio = nnls(single_cell_matrix, bulk_rep)[0]
    total = ratio.sum()
    if total == 0:
        raise ValueError("NNLS assigned zero weight to every cell type; "
                         "cannot derive cell-type proportions from this bulk sample.")
    ratio = ratio / total

    ratio_array = np.round(ratio * sc_meta.shape[0] * ratio_num)
    cell_target_num = dict(zip(id2label, list(ratio_array)))

    return cell_target_num



