# -*- coding: utf-8 -*-
from scipy.cluster.hierarchy import dendrogram, linkage, to_tree
from scipy.spatial.distance import squareform
import seaborn as sns
import pandas as pd
import numpy as np
import os
import json
from molmap.utils.logtools import print_info

# Optional dependency: when highcharts-core is importable, _USE_HC_CORE is set so
# the HTML writer can serialize the options dict with its JavaScript-literal
# serializer. Without it (or if importing it fails for any reason) the code falls
# back to plain json serialization.
try:
    from highcharts_core.converter import serialize_to_js_literal
except Exception:
    _USE_HC_CORE = False
else:
    _USE_HC_CORE = True


# Characters that may legally appear inside JSON strings but are unsafe once the
# JSON text is inlined into an HTML <script> element:
#   * "<", ">" and "&" can terminate the element early (a feature name that
#     contains "</script>") or open an HTML comment ("<!--");
#   * U+2028 / U+2029 are line terminators to pre-ES2019 JavaScript engines.
# In JSON they can only occur inside string literals, so replacing each with its
# \uXXXX escape leaves the decoded options object unchanged.
_SCRIPT_UNSAFE_JSON_CHARS = {
    ord('<'): '\\u003c',
    ord('>'): '\\u003e',
    ord('&'): '\\u0026',
    0x2028: '\\u2028',
    0x2029: '\\u2029',
}


def _options_to_js(options):
    """Return ``options`` as JavaScript object-literal source safe to inline in <script>.

    The optional highcharts-core serializer is tried first when it was imported;
    any failure (or a non-string / empty result) falls back to escaped JSON.
    """
    if _USE_HC_CORE:
        try:
            literal = serialize_to_js_literal(options, allow_vml=False, encoding='utf-8')
        except Exception:
            # NOTE(review): the import path and keyword arguments of the optional
            # serializer are unverified; deliberately fall back to JSON on failure.
            literal = None
        if isinstance(literal, str) and literal:
            return literal
    return json.dumps(options, ensure_ascii=False).translate(_SCRIPT_UNSAFE_JSON_CHARS)


def _write_highcharts_html(options: dict, out_path: str, width: int = 1000, height: int = 850, need_heatmap: bool = False) -> None:
    """Write a standalone HTML page that renders ``options`` with Highcharts.

    The Highcharts scripts are referenced from the official CDN
    (code.highcharts.com), so the page needs network access when opened.

    Args:
        options: Highcharts options built from plain dicts/lists/str/numbers/bools/None.
        out_path: path of the HTML file to (over)write; written as UTF-8.
        width: width of the chart container in pixels.
        height: height of the chart container in pixels.
        need_heatmap: also load the heatmap module (required by heatmap series).
    """
    container_id = "container"
    js_options = _options_to_js(options)

    # Scatter only needs the core bundle; heatmap series need an extra module.
    scripts = [
        "https://code.highcharts.com/highcharts.js",
        "https://code.highcharts.com/modules/exporting.js",
        "https://code.highcharts.com/modules/accessibility.js",
    ]
    if need_heatmap:
        scripts.append("https://code.highcharts.com/modules/heatmap.js")

    html = [
        "<!doctype html>",
        "<html><head><meta charset='utf-8'><title>Highcharts</title></head><body>",
        f"<div id='{container_id}' style='width:{width}px;height:{height}px; margin:0 auto;'></div>"
    ]
    html.extend(f"<script src='{s}'></script>" for s in scripts)
    html.append("<script>")
    html.append("document.addEventListener('DOMContentLoaded', function(){")
    html.append(f"  Highcharts.chart('{container_id}', {js_options});")
    html.append("});")
    html.append("</script>")
    html.append("</body></html>")

    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(html))


def plot_scatter(mp, htmlpath='./', htmlname=None, radius=2, enabled_data_labels=False):
    """Render the 2-D embedding of ``mp``'s features as a Highcharts scatter HTML page.

    One scatter series is emitted per subtype in ``mp.colormaps`` (subtypes without
    features are skipped); the tooltip of a point shows its feature ID.

    Args:
        mp: fitted map object. Reads ``ftype``, ``flist``, ``metric``, ``emb_method``,
            ``embedded.embedding_`` (2-D coordinates, assumed row-aligned with
            ``flist``), ``colormaps`` (subtype -> colour) and ``bitsinfo`` (table
            with ``IDs`` and ``Subtypes`` columns).
        htmlpath: output directory; created when missing ("" means the current dir).
        htmlname: optional prefix for the output file name.
        radius: marker radius in pixels; fractional values are kept.
        enabled_data_labels: whether to print each feature ID next to its point.

    Returns:
        ``(df, filename)``: the per-feature table (x, y, bitsinfo columns, colors)
        and the path of the written HTML file.
    """
    title = '2D emmbedding of %s based on %s method' % (mp.ftype, mp.emb_method)
    subtitle = 'number of %s: %s, metric method: %s' % (mp.ftype, len(mp.flist), mp.metric)
    name = '%s_%s_%s_%s_%s' % (mp.ftype, len(mp.flist), mp.metric, mp.emb_method, 'scatter')

    # An empty htmlpath denotes the current directory and must not reach
    # os.makedirs (which raises for ""); exist_ok avoids a check-then-create race.
    if htmlpath:
        os.makedirs(htmlpath, exist_ok=True)
    if htmlname:
        name = htmlname + '_' + name

    filename = os.path.join(htmlpath, name) + ".html"
    print_info('generate file: %s' % filename)

    colormaps = mp.colormaps
    df = pd.DataFrame(mp.embedded.embedding_, columns=['x', 'y'])
    bitsinfo = mp.bitsinfo.set_index('IDs')
    # Positional join: row i of the embedding is paired with feature mp.flist[i].
    df = df.join(bitsinfo.loc[mp.flist].reset_index())
    df['colors'] = df['Subtypes'].map(colormaps)

    series = []
    for subtype, color in colormaps.items():
        dfi = df[df['Subtypes'] == subtype]
        if dfi.empty:
            continue
        data = [
            {"x": float(x), "y": float(y), "IDs": str(fid)}
            for x, y, fid in zip(dfi['x'], dfi['y'], dfi['IDs'])
        ]
        series.append({
            "type": "scatter",
            "name": str(subtype),
            "data": data,
            "color": color,
            # float() rather than int(): int() silently truncated e.g. 2.5 to 2.
            "marker": {"radius": float(radius)}
        })

    options = {
        "chart": {"type": "scatter", "zoomType": "xy"},
        "title": {"text": title},
        "subtitle": {"text": subtitle},
        "xAxis": {
            "title": {"enabled": True, "text": "X", "style": {"fontSize": 20}},
            "labels": {"style": {"fontSize": 20}},
            "gridLineWidth": 1,
            "startOnTick": True,
            "endOnTick": True,
            "showLastLabel": True
        },
        "yAxis": {
            "title": {"text": "Y", "style": {"fontSize": 20}},
            "labels": {"style": {"fontSize": 20}},
            "gridLineWidth": 1
        },
        "legend": {
            "align": "right", "layout": "vertical",
            "margin": 1, "verticalAlign": "top", "y": 40,
            "symbolHeight": 12, "floating": False
        },
        "plotOptions": {
            "scatter": {
                "marker": {"states": {"hover": {"enabled": True, "lineColor": "rgb(100,100,100)"}}},
                "states": {"hover": {"marker": {"enabled": False}}},
                "tooltip": {"headerFormat": "<b>{series.name}</b><br>", "pointFormat": "{point.IDs}"}
            },
            "series": {
                "turboThreshold": 5000,
                "dataLabels": {"enabled": bool(enabled_data_labels), "format": "{point.IDs}"}
            }
        },
        "series": series,
        "credits": {"enabled": False},
        "exporting": {"enabled": True}
    }

    _write_highcharts_html(options, filename, width=1000, height=850, need_heatmap=False)
    print_info('save html file to %s' % filename)
    return df, filename


def plot_grid(mp, htmlpath='./', htmlname=None, enabled_data_labels=False):
    """Render the grid assignment of ``mp``'s features as a Highcharts heatmap HTML page.

    One heatmap series is emitted per subtype in ``mp.colormaps`` (empty subtypes
    are skipped). Every cell carries its own ``color``, so cells are coloured by
    subtype; the numeric heatmap ``value`` only records occupancy.

    Args:
        mp: fitted map object. Reads ``ftype``, ``flist``, ``metric``, ``emb_method``,
            ``fmap_shape`` (rows, cols), ``colormaps`` (subtype -> colour),
            ``bitsinfo`` (table with ``IDs`` and ``Subtypes`` columns) and
            ``_S.col_asses`` (flat row-major cell index for each feature in ``flist``).
        htmlpath: output directory; created when missing ("" means the current dir).
        htmlname: optional prefix for the output file name.
        enabled_data_labels: whether to print the feature ID inside each cell.

    Returns:
        ``(df, filename)``: the per-cell table (x, y, v, Subtypes, colors) and the
        path of the written HTML file. Empty cells have ``v == 0``.
    """
    # An empty htmlpath denotes the current directory and must not reach
    # os.makedirs (which raises for ""); exist_ok avoids a check-then-create race.
    if htmlpath:
        os.makedirs(htmlpath, exist_ok=True)

    title = 'Assignment of %s by %s emmbedding result' % (mp.ftype, mp.emb_method)
    subtitle = 'number of %s: %s, metric method: %s' % (mp.ftype, len(mp.flist), mp.metric)

    name = '%s_%s_%s_%s_%s' % (mp.ftype, len(mp.flist), mp.metric, mp.emb_method, 'mp')
    if htmlname:
        name = htmlname + '_' + name
    filename = os.path.join(htmlpath, name) + ".html"
    print_info('generate file: %s' % filename)

    m, n = mp.fmap_shape
    colormaps = mp.colormaps
    # Scatter the feature IDs into a flat row-major grid; unassigned cells stay 0.
    position = np.zeros(mp.fmap_shape, dtype='O').reshape(m * n,)
    position[mp._S.col_asses] = mp.flist
    position = position.reshape(m, n)

    # Flatten column by column: flat cell k is column k // m (x), row k % m (y).
    xs = np.repeat(np.arange(n), m)
    ys = np.tile(np.arange(m), n)
    vs = position.reshape(m * n, order='F')

    df = pd.DataFrame({"x": xs, "y": ys, "v": vs})
    subtypedict = mp.bitsinfo.set_index('IDs')['Subtypes'].to_dict()
    subtypedict.update({0: 'NaN'})  # empty cells get the 'NaN' subtype
    df['Subtypes'] = df['v'].map(subtypedict)
    df['colors'] = df['Subtypes'].map(colormaps)

    series = []
    for subtype, color in colormaps.items():
        dfi = df[df['Subtypes'] == subtype]
        if dfi.empty:
            continue
        data = []
        for x, y, v in zip(dfi['x'], dfi['y'], dfi['v']):
            is_empty = v == 0
            # Keep JSON-native IDs as-is; stringify anything else (e.g. numpy ints)
            # so json.dumps cannot fail on it.
            label = v if isinstance(v, (str, int, float)) else str(v)
            data.append({
                "x": int(x),
                "y": int(y),
                # Feature IDs may be non-numeric strings, for which int(ID) raised
                # ValueError; the heatmap value therefore only marks occupancy.
                "value": 0 if is_empty else 1,
                "v": label,
                "color": color,
            })
        series.append({
            "type": "heatmap",
            "name": str(subtype),
            "data": data,
            "borderWidth": 0,
            # Keep the per-subtype legend entries even though a colorAxis exists.
            "showInLegend": True,
        })

    options = {
        "chart": {"type": "heatmap", "zoomType": "xy"},
        "title": {"text": title},
        "subtitle": {"text": subtitle},
        "xAxis": {
            "title": None,
            "min": 0, "max": n - 1,
            "startOnTick": False, "endOnTick": False,
            "allowDecimals": False,
            "labels": {"style": {"fontSize": 20}}
        },
        "yAxis": {
            "title": {"text": " ", "style": {"fontSize": 20}},
            "startOnTick": False, "endOnTick": False,
            "gridLineWidth": 0, "reversed": True,
            "min": 0, "max": m - 1,
            "allowDecimals": False,
            "labels": {"style": {"fontSize": 20}}
        },
        "legend": {
            "align": "right", "layout": "vertical", "margin": 1,
            "verticalAlign": "top", "y": 60, "symbolHeight": 12, "floating": False
        },
        "tooltip": {"headerFormat": "<b>{series.name}</b><br>", "pointFormat": "{point.v}"},
        "plotOptions": {
            "series": {
                "turboThreshold": 5000,
                "dataLabels": {
                    "enabled": bool(enabled_data_labels),
                    "format": "{point.v}",
                    "style": {"textOutline": False, "color": "black"}
                }
            }
        },
        # Heatmap points carry a value, so a colorAxis is provided; its gradient
        # is hidden from the legend because cell colours come from each point.
        "colorAxis": {"min": 0, "showInLegend": False},
        "series": series,
        "credits": {"enabled": False},
        "exporting": {"enabled": True}
    }

    _write_highcharts_html(options, filename, width=1000, height=850, need_heatmap=True)
    print_info('save html file to %s' % filename)
    return df, filename


# ====== 树的 newick 导出（修正原代码里的函数名不一致问题） ======
def _getNewick(node, newick, parentdist, leaf_names):
    """Serialize the ``ClusterNode`` tree rooted at ``node`` to Newick text.

    Leaves are labelled ``leaf_names[node.id]``. Every branch length is the parent's
    ``dist`` minus the child's ``dist``, formatted with two decimals. At each
    internal node the right child is written before the left child.

    Args:
        node: root of the (sub)tree, e.g. from ``scipy.cluster.hierarchy.to_tree``.
        newick: text appended after this node's own output. An internal node
            given an empty string is treated as the outermost node and closes
            the tree with ``");"``.
        parentdist: ``dist`` of the parent node, used for this node's branch length.
        leaf_names: sequence indexed by leaf id.

    Returns:
        The Newick string.
    """
    # Explicit stack instead of recursion: recursion depth equalled tree depth
    # and hit Python's recursion limit on deep trees. Pieces are joined once.
    pieces = []
    # (subtree, parent dist, suffix) triples, popped LIFO so a whole subtree is
    # written before anything pushed beneath it.
    stack = [(node, parentdist, newick)]
    while stack:
        current, pdist, suffix = stack.pop()
        branch = pdist - current.dist
        if current.is_leaf():
            pieces.append("%s:%.2f%s" % (leaf_names[current.id], branch, suffix))
            continue
        closing = "):%.2f%s" % (branch, suffix) if suffix else ");"
        pieces.append("(")
        # Emitted order: right subtree + "," then left subtree + closing.
        stack.append((current.get_left(), current.dist, closing))
        stack.append((current.get_right(), current.dist, ","))
    return "".join(pieces)


def _mp2newick(mp, treefile='mytree'):
    """Export a complete-linkage clustering of ``mp``'s features to Newick and Excel.

    Writes two files:
      * ``<treefile>.nwk``: the tree in Newick format; leaf ``i`` is labelled
        ``mp.flist[i]``;
      * ``<treefile>.xlsx``: ``mp.df_embedding[['colors', 'Subtypes']]``.

    Args:
        mp: map object exposing ``dist_matrix`` (presumably a square, symmetric
            distance matrix whose rows follow ``flist`` -- confirm), ``flist``
            and ``df_embedding``.
        treefile: output path prefix, without extension.
    """
    dist_matrix = mp.dist_matrix
    leaf_names = mp.flist
    df = mp.df_embedding[['colors', 'Subtypes']]

    # squareform condenses the square matrix into the vector form linkage() takes.
    dists = squareform(dist_matrix)
    linkage_matrix = linkage(dists, 'complete')
    tree = to_tree(linkage_matrix, rd=False)
    newick = _getNewick(tree, "", tree.dist, leaf_names=leaf_names)

    # Explicit UTF-8 so non-ASCII feature names do not depend on the platform's
    # default encoding.
    with open(treefile + '.nwk', 'w', encoding='utf-8') as f:
        f.write(newick)
    df.to_excel(treefile + '.xlsx')


def plot_tree(mp, htmlpath='./', htmlname=None):
    """Placeholder for a tree (dendrogram) plot; not implemented yet.

    All arguments are currently ignored: nothing is written and ``None`` is returned.
    TODO: render the dendrogram when this is needed.
    """
    return None
