# -*- coding: utf-8 -*-
from scipy.cluster.hierarchy import dendrogram, linkage, to_tree
from scipy.spatial.distance import squareform
import seaborn as sns
import pandas as pd
import numpy as np
import os
import json

from molmap.utils.logtools import print_info

# highcharts-core is optional: when its serializer can be imported it is used
# to emit the chart options, otherwise the HTML writer uses json.dumps.
# NOTE(review): confirm this module path; the import failing for any reason
# (not only ImportError) simply disables the optional path.
try:
    from highcharts_core.converter import serialize_to_js_literal
except Exception:
    _USE_HC_CORE = False
else:
    _USE_HC_CORE = True


def _json_default(obj):
    """Convert NumPy values that the ``json`` module cannot encode natively."""
    if isinstance(obj, np.generic):
        # np.int64, np.bool_, ... -> the equivalent Python scalar.
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError("Object of type %s is not JSON serializable" % type(obj).__name__)


def _dump_options_json(options):
    """Serialize *options* as JSON that is safe to inline inside a <script> element.

    ``<``, ``>`` and ``&`` only ever occur inside JSON strings, where they are
    rewritten as ``\\u003c`` / ``\\u003e`` / ``\\u0026`` escapes; JavaScript
    decodes these back to the original characters, but a value such as
    ``"</script>"`` can no longer terminate the script element early.
    U+2028/U+2029 are escaped as well for older JavaScript engines.
    """
    text = json.dumps(options, ensure_ascii=False, default=_json_default)
    return (
        text.replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
        .replace("\u2028", "\\u2028")
        .replace("\u2029", "\\u2029")
    )


def _write_highcharts_html(options: dict, out_path: str, width: int = 1000, height: int = 850, need_heatmap: bool = False):
    """Write Highcharts *options* to a standalone HTML file that loads Highcharts from the official CDN.

    Args:
        options: Highcharts chart options. With the JSON path, NumPy scalars
            and arrays are converted to plain Python values.
        out_path: Destination HTML file (overwritten if it exists, written as UTF-8).
        width: Chart container width in pixels.
        height: Chart container height in pixels.
        need_heatmap: Also load the ``heatmap`` module, needed for heatmap series.

    Raises:
        TypeError: if *options* contains a value that cannot be serialized.
    """
    container_id = "container"
    if _USE_HC_CORE:
        # NOTE(review): confirm that highcharts-core accepts these keyword
        # arguments for serialize_to_js_literal.
        js_options = serialize_to_js_literal(options, allow_vml=False, encoding='utf-8')
    else:
        js_options = _dump_options_json(options)

    scripts = [
        "https://code.highcharts.com/highcharts.js",
        "https://code.highcharts.com/modules/exporting.js",
        "https://code.highcharts.com/modules/accessibility.js",
    ]
    if need_heatmap:
        scripts.append("https://code.highcharts.com/modules/heatmap.js")

    html = [
        "<!doctype html>",
        "<html><head><meta charset='utf-8'><title>Highcharts</title></head><body>",
        f"<div id='{container_id}' style='width:{width}px;height:{height}px;margin:0 auto;'></div>",
    ]
    for s in scripts:
        html.append(f"<script src='{s}'></script>")
    html.append("<script>")
    # Build the chart once the DOM (including the container div) is ready.
    html.append("document.addEventListener('DOMContentLoaded', function(){")
    html.append(f"  Highcharts.chart('{container_id}', {js_options});")
    html.append("});")
    html.append("</script>")
    html.append("</body></html>")

    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(html))


def plot_scatter(molmap, htmlpath='./', htmlname=None, radius=3):
    """Write an interactive 2D scatter plot of the MolMap embedding as an HTML file.

    One scatter series is emitted per entry of ``molmap.extract.colormaps``
    (in that dict's order). Subtypes with no features are skipped. Hovering
    a point shows its feature ID.

    Args:
        molmap: Fitted MolMap-like object. Reads ``ftype``, ``method``,
            ``metric``, ``flist``, ``embedded.embedding_``,
            ``extract.bitsinfo`` (with ``IDs``/``Subtypes`` columns) and
            ``extract.colormaps``.
        htmlpath: Output directory, created if missing. An empty string means
            the current directory.
        htmlname: Optional prefix for the generated file name.
        radius: Marker radius in pixels.

    Returns:
        tuple: ``(df, filename)``. ``df`` has one row per feature with columns
        ``x``, ``y``, the bitsinfo columns and ``colors``. ``filename`` is the
        path of the written HTML file.
    """
    title = '2D emmbedding of %s based on %s method' % (molmap.ftype, molmap.method)
    subtitle = 'number of %s: %s, metric method: %s' % (molmap.ftype, len(molmap.flist), molmap.metric)
    name = '%s_%s_%s_%s_%s' % (molmap.ftype, len(molmap.flist), molmap.metric, molmap.method, 'scatter')

    if htmlpath:
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(htmlpath, exist_ok=True)
    if htmlname:
        name = htmlname + '_' + name

    filename = os.path.join(htmlpath, name) + ".html"
    print_info('generate file: %s' % filename)

    xy = molmap.embedded.embedding_
    colormaps = molmap.extract.colormaps

    # Row k of the embedding is joined (by position) with the bitsinfo row of
    # flist[k]; presumably embedding_ rows follow flist order — confirm.
    df = pd.DataFrame(xy, columns=['x', 'y'])
    bitsinfo = molmap.extract.bitsinfo.set_index('IDs')
    df = df.join(bitsinfo.loc[molmap.flist].reset_index())
    df['colors'] = df['Subtypes'].map(colormaps)

    series = []
    for subtype, color in colormaps.items():
        dfi = df[df['Subtypes'] == subtype]
        if dfi.empty:
            continue
        data = [
            {"x": float(x), "y": float(y), "IDs": str(fid)}
            for x, y, fid in zip(dfi["x"], dfi["y"], dfi["IDs"])
        ]
        series.append({
            "type": "scatter",
            "name": str(subtype),
            "data": data,
            "color": color,
            "marker": {"radius": int(radius)},
        })

    options = {
        "chart": {"type": "scatter", "zoomType": "xy"},
        "title": {"text": title},
        "subtitle": {"text": subtitle},
        "xAxis": {
            "title": {"enabled": True, "text": "X", "style": {"fontSize": 20}},
            "labels": {"style": {"fontSize": 20}},
            "gridLineWidth": 1,
            "startOnTick": True,
            "endOnTick": True,
            "showLastLabel": True
        },
        "yAxis": {
            "title": {"text": "Y", "style": {"fontSize": 20}},
            "labels": {"style": {"fontSize": 20}},
            "gridLineWidth": 1
        },
        "legend": {
            "align": "right", "layout": "vertical",
            "margin": 1, "verticalAlign": "top", "y": 40,
            "symbolHeight": 12, "floating": False
        },
        "plotOptions": {
            "scatter": {
                "marker": {"states": {"hover": {"enabled": True, "lineColor": "rgb(100,100,100)"}}},
                "states": {"hover": {"marker": {"enabled": False}}},
                "tooltip": {"headerFormat": "<b>{series.name}</b><br>", "pointFormat": "{point.IDs}"}
            },
            "series": {
                # Points are objects (they carry IDs), so turbo mode can never
                # apply. A finite threshold would only make series larger
                # than it fail to render (Highcharts error #12); 0 disables it.
                "turboThreshold": 0
            }
        },
        "series": series,
        "credits": {"enabled": False},
        "exporting": {"enabled": True}
    }

    _write_highcharts_html(options, filename, width=1000, height=850, need_heatmap=False)
    print_info('save html file to %s' % filename)
    return df, filename


def plot_grid(molmap, htmlpath='./', htmlname=None):
    """Write the MolMap grid assignment as an interactive HTML heatmap.

    Every cell of the ``fmap_shape`` grid becomes one heatmap point. Cells are
    grouped into one series per entry of ``molmap.extract.colormaps``, and each
    point carries that subtype's colour explicitly. Empty cells get the
    ``'NaN'`` subtype, so they are drawn only if ``colormaps`` has a ``'NaN'``
    entry. Hovering a cell shows its feature ID (``0`` for empty cells).

    Args:
        molmap: Fitted MolMap-like object. Reads ``ftype``, ``method``,
            ``metric``, ``flist``, ``fmap_shape``, ``_S.col_asses``,
            ``extract.bitsinfo`` and ``extract.colormaps``.
        htmlpath: Output directory, created if missing. An empty string means
            the current directory.
        htmlname: Optional prefix for the generated file name.

    Returns:
        tuple: ``(df, filename)``. ``df`` has one row per grid cell with columns
        ``x`` (column), ``y`` (row), ``v`` (feature ID or ``0``), ``Subtypes``
        and ``colors``. ``filename`` is the path of the written HTML file.
    """
    if htmlpath:
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(htmlpath, exist_ok=True)

    title = 'Assignment of %s by %s emmbedding result' % (molmap.ftype, molmap.method)
    subtitle = 'number of %s: %s, metric method: %s' % (molmap.ftype, len(molmap.flist), molmap.metric)

    name = '%s_%s_%s_%s_%s' % (molmap.ftype, len(molmap.flist), molmap.metric, molmap.method, 'molmap')
    if htmlname:
        name = htmlname + '_' + name
    filename = os.path.join(htmlpath, name) + ".html"
    print_info('generate file: %s' % filename)

    m, n = molmap.fmap_shape
    colormaps = molmap.extract.colormaps
    # Feature flist[k] is placed in flat cell col_asses[k]; 0 marks an empty cell.
    position = np.zeros(molmap.fmap_shape, dtype='O').reshape(m * n,)
    position[molmap._S.col_asses] = molmap.flist
    position = position.reshape(m, n)

    # Column-major flattening: flat index k holds position[k % m, k // m],
    # which matches xs (column) and ys (row) below.
    xs = np.repeat(np.arange(n), m)
    ys = np.tile(np.arange(m), n)
    vs = position.reshape(m * n, order='f')

    df = pd.DataFrame({"x": xs, "y": ys, "v": vs})
    bitsinfo = molmap.extract.bitsinfo
    subtypedict = bitsinfo.set_index('IDs')['Subtypes'].to_dict()
    subtypedict.update({0: 'NaN'})
    df['Subtypes'] = df['v'].map(subtypedict)
    df['colors'] = df['Subtypes'].map(colormaps)

    series = []
    for subtype, color in colormaps.items():
        dfi = df[df['Subtypes'] == subtype]
        if dfi.empty:
            continue
        data = []
        for x, y, v in zip(dfi["x"], dfi["y"], dfi["v"]):
            is_empty = (v == 0)
            data.append({
                "x": int(x),
                "y": int(y),
                # v is a feature ID, which need not be numeric, so int(v) is
                # not safe. The heatmap only needs some numeric value because
                # the point colour is set explicitly: 1 = occupied, 0 = empty.
                "value": 0 if is_empty else 1,
                "v": 0 if is_empty else str(v),
                "color": color,
            })
        series.append({
            "type": "heatmap",
            "name": str(subtype),
            "data": data,
            "borderWidth": 0,
            # Series bound to a colorAxis are hidden from the legend unless
            # this is set; the legend should list the subtypes.
            "showInLegend": True,
        })

    options = {
        "chart": {"type": "heatmap", "zoomType": "xy"},
        "title": {"text": title},
        "subtitle": {"text": subtitle},
        "xAxis": {
            "title": None,
            "min": 0, "max": n - 1,
            "startOnTick": False, "endOnTick": False,
            "allowDecimals": False,
            "labels": {"style": {"fontSize": 20}}
        },
        "yAxis": {
            "title": {"text": " ", "style": {"fontSize": 20}},
            "startOnTick": False, "endOnTick": False,
            "gridLineWidth": 0, "reversed": True,
            "min": 0, "max": m - 1,
            "allowDecimals": False,
            "labels": {"style": {"fontSize": 20}}
        },
        "legend": {
            "align": "right", "layout": "vertical",
            "margin": 1, "verticalAlign": "top",
            "y": 60, "symbolHeight": 12, "floating": False
        },
        "tooltip": {"headerFormat": "<b>{series.name}</b><br>", "pointFormat": "{point.v}"},
        "plotOptions": {
            "series": {
                # Points are objects, so turbo mode can never apply. A finite
                # threshold would only make large series fail to render
                # (Highcharts error #12); 0 disables it.
                "turboThreshold": 0
            }
        },
        # Colours come from each point, so the gradient would be meaningless
        # in the legend.
        "colorAxis": {"min": 0, "showInLegend": False},
        "series": series,
        "credits": {"enabled": False},
        "exporting": {"enabled": True}
    }

    _write_highcharts_html(options, filename, width=1000, height=850, need_heatmap=True)
    print_info('save html file to %s' % filename)
    return df, filename


def _getNewick(node, newick, parentdist, leaf_names):
    """Render the subtree rooted at *node* in Newick form, followed by *newick*.

    Args:
        node: A ``scipy.cluster.hierarchy.ClusterNode``.
        newick: Text appended after this subtree. Pass ``""`` for the root;
            the tree is then closed with ``");"``.
        parentdist: Merge height of the parent. The branch length written for
            *node* is ``parentdist - node.dist`` with two decimals.
        leaf_names: Labels indexed by leaf id.

    Returns:
        str: For a leaf, ``"name:len" + newick``. For an internal node with a
        non-empty *newick*, ``"(R,L):len" + newick``. For an internal node with
        an empty *newick*, ``"(R,L);"``. Here R and L are the right and left
        subtrees, written right first.
    """
    branch_length = parentdist - node.dist
    if node.is_leaf():
        return "%s:%.2f%s" % (leaf_names[node.id], branch_length, newick)

    # Close this clade with its own branch length plus the caller's suffix.
    # With no suffix (the root call), terminate the whole tree instead.
    closing = "):%.2f%s" % (branch_length, newick) if newick else ");"
    # The left subtree is rendered in front of the closing text, then the
    # right subtree in front of ",<left...>", giving "(R,L...".
    left_part = _getNewick(node.get_left(), closing, node.dist, leaf_names=leaf_names)
    right_part = _getNewick(node.get_right(), "," + left_part, node.dist, leaf_names=leaf_names)
    return "(" + right_part


def _mp2newick(molmap, treefile='mytree'):
    """Export a complete-linkage tree of the MolMap features, plus their colours and subtypes.

    Writes two files:
      * ``<treefile>.nwk``: the tree in Newick format, with leaves labelled from ``molmap.flist``.
      * ``<treefile>.xlsx``: the ``colors`` and ``Subtypes`` columns of ``molmap.df_embedding``.

    Args:
        molmap: MolMap-like object providing ``dist_matrix``, ``flist`` and ``df_embedding``.
        treefile: Output path prefix; the extensions are appended.
    """
    dist_matrix = molmap.dist_matrix
    leaf_names = molmap.flist
    df = molmap.df_embedding[['colors', 'Subtypes']]

    # squareform turns a square distance matrix into the condensed vector
    # that linkage expects. Assumes dist_matrix is the square (n, n) matrix
    # with rows in flist order — TODO confirm.
    dists = squareform(dist_matrix)
    linkage_matrix = linkage(dists, 'complete')
    tree = to_tree(linkage_matrix, rd=False)
    newick = _getNewick(tree, "", tree.dist, leaf_names=leaf_names)

    # Explicit UTF-8: feature names may be non-ASCII, and the platform default
    # encoding is not guaranteed to be able to represent them.
    with open(treefile + '.nwk', 'w', encoding='utf-8') as f:
        f.write(newick)
    df.to_excel(treefile + '.xlsx')


def plot_tree(molmap, htmlpath='./', htmlname=None):
    """Placeholder for an interactive tree plot; it is not implemented yet.

    Kept with the same signature as the other ``plot_*`` helpers so existing
    callers keep working. It ignores its arguments, writes nothing, and
    returns ``None``.
    """
    return None
