from inspect import signature
from math import sqrt

import numpy as np
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import GNNExplainer
from torch_geometric.utils import k_hop_subgraph, to_networkx


class GNNExplainerWrapper(torch.nn.Module):
    """
    Wrapper module for formatting GNN models into GNNExplainer
    applicable modules.

    The wrapped model's forward result is indexed as ``res["node_pred"][0]``
    and only that value is returned, so the explainer sees a single
    prediction output instead of the model's full result mapping.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, edge_index, edge_features=None, xbatch=None):
        """Run the wrapped model and return its first node prediction.

        Args:
            x: Node features, passed through unchanged.
            edge_index: Edge indices, passed through unchanged.
            edge_features: Optional edge features, passed through unchanged.
            xbatch: Optional batch vector, passed through unchanged.

        Returns:
            ``res["node_pred"][0]`` where ``res`` is the wrapped model's output.
        """
        # Call the module instance (not ``.forward``) so that any hooks
        # registered on the wrapped model are executed.
        res = self.model(x, edge_index, edge_features, xbatch)
        return res["node_pred"][0]


class MetaLayerModelExplainer(GNNExplainer):
    """
    GNN Explainer reformatted to handle MetaLayerModels

    This implementation is largely a copy from Pytorch Geometric:
    https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/gnn_explainer.html
    © Copyright 2021, Matthias Fey. Revision 788a4c05.

    This patchwork is only temporary.
    """

    def __init__(self, model, **kwargs):
        super().__init__(model, **kwargs)

    def visualize_subgraph(
        self,
        node_idx,
        edge_index,
        edge_mask,
        y=None,
        threshold=None,
        edge_y=None,
        node_alpha=None,
        seed=10,
        **kwargs,
    ):
        r"""Visualizes the subgraph given an edge mask :attr:`edge_mask`.

        Args:
            node_idx (int): The node id to explain.
                Set to :obj:`-1` to explain graph.
            edge_index (LongTensor): The edge indices.
            edge_mask (Tensor): The edge mask.
            y (Tensor, optional): The ground-truth node-prediction labels used
                as node colorings. All nodes will have the same color
                if :attr:`node_idx` is :obj:`-1`. If the maximum label is
                zero, labels are used unscaled. (default: :obj:`None`).
            threshold (float, optional): Sets a threshold for visualizing
                important edges. If set to :obj:`None`, will visualize all
                edges with transparency indicating the importance of edges.
                (default: :obj:`None`)
            edge_y (Tensor, optional): The edge labels used as edge colorings.
            node_alpha (Tensor, optional): Tensor of floats (0 - 1) indicating
                transparency of each node. Takes precedence over an ``alpha``
                entry in :attr:`kwargs`.
            seed (int, optional): Random seed of the :obj:`networkx` node
                placement algorithm. (default: :obj:`10`)
            **kwargs (optional): Additional arguments passed to
                :func:`nx.draw_networkx_nodes` / :func:`nx.draw_networkx_labels`
                (only keys accepted by those functions are forwarded).
        :rtype: :class:`matplotlib.axes.Axes`, :class:`networkx.DiGraph`
        """
        import matplotlib.pyplot as plt
        import networkx as nx

        assert edge_mask.size(0) == edge_index.size(1)

        if node_idx == -1:
            # Explain the whole graph: keep every edge. ``torch.ones`` is used
            # rather than the legacy ``torch.BoolTensor(list, device=...)``
            # constructor, which rejects non-CPU devices.
            hard_edge_mask = torch.ones(
                edge_index.size(1), dtype=torch.bool, device=edge_mask.device
            )
            subset = torch.arange(edge_index.max().item() + 1, device=edge_index.device)
            y = None
        else:
            # Only operate on a k-hop subgraph around `node_idx`.
            subset, edge_index, _, hard_edge_mask = k_hop_subgraph(
                node_idx,
                self.num_hops,
                edge_index,
                relabel_nodes=True,
                num_nodes=None,
                flow=self.__flow__(),
            )

        edge_mask = edge_mask[hard_edge_mask]

        if threshold is not None:
            edge_mask = (edge_mask >= threshold).to(torch.float)

        # One entry per node of ``subset`` so node attributes line up with the
        # relabelled ids; ``edge_index.max()`` fails on an edgeless subgraph.
        num_nodes = subset.numel()
        if y is None:
            y = torch.zeros(num_nodes, device=edge_index.device)
        else:
            y_max = y.max().item()
            y = y[subset].to(torch.float)
            # Avoid 0/0 -> NaN colours when every label is zero.
            if y_max != 0:
                y = y / y_max

        if edge_y is None:
            edge_color = ["black"] * edge_index.size(1)
        else:
            colors = list(plt.rcParams["axes.prop_cycle"])
            edge_color = [
                colors[i % len(colors)]["color"]
                for i in edge_y[hard_edge_mask].tolist()
            ]

        graph_data = Data(
            edge_index=edge_index,
            att=edge_mask,
            edge_color=edge_color,
            y=y,
            num_nodes=num_nodes,
        ).to("cpu")
        G = to_networkx(graph_data, node_attrs=["y"], edge_attrs=["att", "edge_color"])
        # Map the local ids 0..num_nodes-1 back to the ids stored in ``subset``.
        G = nx.relabel_nodes(G, dict(enumerate(subset.tolist())))

        node_args = set(signature(nx.draw_networkx_nodes).parameters.keys())
        node_kwargs = {k: v for k, v in kwargs.items() if k in node_args}
        node_kwargs["node_size"] = kwargs.get("node_size") or 800
        node_kwargs["cmap"] = kwargs.get("cmap") or "cool"

        label_args = set(signature(nx.draw_networkx_labels).parameters.keys())
        label_kwargs = {k: v for k, v in kwargs.items() if k in label_args}
        label_kwargs["font_size"] = kwargs.get("font_size") or 10

        pos = nx.spring_layout(G, seed=seed)
        ax = plt.gca()
        # Shrink arrows at both ends so they stop near the node markers;
        # the value depends only on node_size, so compute it once.
        shrink = sqrt(node_kwargs["node_size"]) / 2.0
        for source, target, edge_attrs in G.edges(data=True):
            ax.annotate(
                "",
                xy=pos[target],
                xycoords="data",
                xytext=pos[source],
                textcoords="data",
                arrowprops=dict(
                    arrowstyle="->",
                    alpha=max(edge_attrs["att"], 0.1),
                    color=edge_attrs["edge_color"],
                    shrinkA=shrink,
                    shrinkB=shrink,
                    connectionstyle="arc3,rad=0.1",
                ),
            )

        # Explicit colours / transparencies are merged into node_kwargs so a
        # same-named entry forwarded from kwargs cannot trigger a
        # "multiple values for keyword argument" TypeError.
        node_kwargs["node_color"] = y.tolist()
        if node_alpha is not None:
            node_alpha_subset = node_alpha[subset]
            assert ((node_alpha_subset >= 0) & (node_alpha_subset <= 1)).all()
            node_kwargs["alpha"] = node_alpha_subset.tolist()
        nx.draw_networkx_nodes(G, pos, **node_kwargs)

        nx.draw_networkx_labels(G, pos, **label_kwargs)

        return ax, G
