import json
import numpy as np
from networkx import Graph, dijkstra_path, exception

from .topology import Topology as Topo
from ..kernel.timeline import Timeline
from .node import BSMNode
from ..constants import SPEED_OF_LIGHT
from typing import Dict, List, Type
from .node import Node, DQCNode


class DQCNetTopo(Topo):
    """Class for generating a distributed quantum computing (DQC) network.

    Class DQCNetTopo is a child of class Topology. Distributed quantum computing nodes, BSM
    nodes, quantum channels, classical channels and the timeline for simulation
    are generated from a JSON configuration file by this class.

    Attributes:
        bsm_to_router_map (dict[str, list[str]]): mapping of BSM node name to the names of the
            nodes whose quantum channels end at that BSM node (two are expected).
        encoding_type: initialized to None; not otherwise used by this class.
        nodes (dict[str, list[Node]]): mapping of type of node to a list of same type node.
        qchannels (list[QuantumChannel]): list of quantum channel objects in network.
        cchannels (list[ClassicalChannel]): list of classical channel objects in network.
        tl (Timeline): the timeline used for simulation.
    """
    BSM_NODE = "BSMNode"
    MEET_IN_THE_MID = "meet_in_the_middle"
    MEMO_ARRAY_SIZE = "memo_size"             # communication memories
    CONTROLLER = "Controller"
    DQC_NODE = "DQCNode"
    DATA_MEMO_ARRAY_SIZE = "data_memo_size"   # data memories

    def __init__(self, conf_file_name: str):
        """Create the topology described by the JSON file ``conf_file_name``.

        ``bsm_to_router_map`` is created before ``super().__init__`` because the parent
        constructor presumably invokes ``_load``, which fills it.
        """
        self.bsm_to_router_map = {}
        self.encoding_type = None
        super().__init__(conf_file_name)

    def _load(self, filename: str):
        """Parse the JSON configuration and build every network component.

        The order of the steps matters: quantum connections are expanded into BSM
        nodes and channels first, and the BSM -> router map must exist before nodes
        (BSM nodes look it up) and router/BSM links are created.
        """
        with open(filename) as fh:
            config = json.load(fh)

        # Kept for infer_qubit_to_node(). The steps below append auto-generated
        # BSM nodes and channels to this same dict.
        self._raw_cfg = config

        self._get_templates(config)
        # quantum connections are only supported by sequential simulation so far
        self._add_qconnections(config)
        self._add_timeline(config)
        self._map_bsm_routers(config)
        self._add_nodes(config)
        self._add_bsm_node_to_router()
        self._add_qchannels(config)
        self._add_cchannels(config)
        self._add_cconnections(config)
        self._generate_forwarding_table(config)

    def _add_timeline(self, config: dict):
        """Create the simulation timeline; runs forever unless a stop time is configured."""
        stop_time = config.get(Topo.STOP_TIME, float('inf'))
        self.tl = Timeline(stop_time)

    def _map_bsm_routers(self, config: dict):
        """Record, for each quantum channel destination (a BSM node), the channel sources."""
        for qc in config.get(Topo.ALL_Q_CHANNEL, []):
            self.bsm_to_router_map.setdefault(qc[Topo.DST], []).append(qc[Topo.SRC])

    def _add_nodes(self, config: dict):
        """Instantiate every configured node and register it under its type in ``self.nodes``.

        Raises:
            ValueError: if a node has a type other than BSMNode or DQCNode.
        """
        for node in config[Topo.ALL_NODE]:
            seed = node[Topo.SEED]
            node_type = node[Topo.TYPE]
            name = node[Topo.NAME]
            template_name = node.get(Topo.TEMPLATE, None)
            template = self.templates.get(template_name, {})

            if node_type == self.BSM_NODE:
                others = self.bsm_to_router_map[name]
                node_obj = BSMNode(name, self.tl, others, component_templates=template)
            elif node_type == self.DQC_NODE:
                data_size = node.get(self.DATA_MEMO_ARRAY_SIZE, 0)
                comm_size = node.get(self.MEMO_ARRAY_SIZE, 0)
                node_obj = DQCNode(name, self.tl, memo_size=comm_size, data_memo_size=data_size,
                                   component_templates=template)
            else:
                raise ValueError(f"Unknown type of node '{node_type}'")

            node_obj.set_seed(seed)
            self.nodes[node_type].append(node_obj)

    def _add_bsm_node_to_router(self):
        """Tell each of a BSM node's two attached nodes which BSM node leads to the other one."""
        for bsm, (r0_str, r1_str) in self.bsm_to_router_map.items():
            r0 = self.tl.get_entity_by_name(r0_str)
            r1 = self.tl.get_entity_by_name(r1_str)
            if r0 is not None:
                r0.add_bsm_node(bsm, r1_str)
            if r1 is not None:
                r1.add_bsm_node(bsm, r0_str)

    @classmethod
    def _classical_delay(cls, cc: dict):
        """Return a classical link's delay: its explicit delay, else distance / SPEED_OF_LIGHT.

        The distance defaults to 1000 when the entry gives neither value.
        """
        return cc.get(cls.DELAY, cc.get(cls.DISTANCE, 1000) / SPEED_OF_LIGHT)

    def _add_qconnections(self, config: dict):
        """Generate bsm_info, qc_info, and cc_info for the q_connections.

        A ``meet_in_the_middle`` connection between node1 and node2 is expanded (in
        place, inside ``config``) into an auto-named BSM node, one quantum channel from
        each endpoint to that BSM node, and classical channels in both directions
        between each endpoint and the BSM node. Generated channels use half of the
        connection distance (floor division); their delay is half (floor division)
        of the mean delay of all classical channels/connections already linking
        node1 and node2 in either direction.

        Raises:
            ValueError: if no classical channel or connection links the two endpoints.
            NotImplementedError: if the connection type is not ``meet_in_the_middle``.
        """
        for q_connect in config.get(Topo.ALL_Q_CONNECT, []):
            node1 = q_connect[Topo.CONNECT_NODE_1]
            node2 = q_connect[Topo.CONNECT_NODE_2]
            attenuation = q_connect[Topo.ATTENUATION]
            distance = q_connect[Topo.DISTANCE] // 2
            channel_type = q_connect[Topo.TYPE]
            endpoints = {(node1, node2), (node2, node1)}  # either direction counts

            # gather the delays of every classical link between the two endpoints
            cc_delay = []
            for cc in config.get(self.ALL_C_CHANNEL, []):   # classical channel
                if (cc[self.SRC], cc[self.DST]) in endpoints:
                    cc_delay.append(self._classical_delay(cc))
            for cc in config.get(self.ALL_C_CONNECT, []):   # classical connection
                if (cc[self.CONNECT_NODE_1], cc[self.CONNECT_NODE_2]) in endpoints:
                    cc_delay.append(self._classical_delay(cc))
            if not cc_delay:
                # explicit exception rather than `assert`, so the check survives `python -O`
                # (otherwise np.mean([]) would silently produce a NaN delay)
                raise ValueError(f"No classical channel or connection between '{node1}' and "
                                 f"'{node2}' for quantum connection {q_connect}")
            cc_delay = np.mean(cc_delay) // 2

            if channel_type != self.MEET_IN_THE_MID:
                raise NotImplementedError("Unknown type of quantum connection")

            bsm_name = f"BSM.{node1}.{node2}.auto"  # the intermediate BSM node
            config[self.ALL_NODE].append({self.NAME: bsm_name,
                                          self.TYPE: self.BSM_NODE,
                                          self.SEED: q_connect.get(Topo.SEED, 0),
                                          self.TEMPLATE: q_connect.get(Topo.TEMPLATE, None)})

            qchannels = config.setdefault(self.ALL_Q_CHANNEL, [])
            cchannels = config.setdefault(self.ALL_C_CHANNEL, [])
            for src in (node1, node2):
                qchannels.append({self.NAME: f"QC.{src}.{bsm_name}",  # the quantum channel
                                  self.SRC: src,
                                  self.DST: bsm_name,
                                  self.DISTANCE: distance,
                                  self.ATTENUATION: attenuation})
                # classical channels: endpoint -> BSM, then BSM -> endpoint
                for cc_src, cc_dst in ((src, bsm_name), (bsm_name, src)):
                    cchannels.append({self.NAME: f"CC.{cc_src}.{cc_dst}",
                                      self.SRC: cc_src,
                                      self.DST: cc_dst,
                                      self.DISTANCE: distance,
                                      self.DELAY: cc_delay})

    def _generate_forwarding_table(self, config: dict):
        """For static routing: install shortest-path next hops on every DQC node.

        The graph has one vertex per configured DQC node and, per BSM node, one edge
        between the two nodes whose quantum channels reach it, weighted by the sum of
        those channels' distances. Unreachable destinations get no rule.
        """
        graph = Graph()
        for node in config[Topo.ALL_NODE]:
            if node[Topo.TYPE] == self.DQC_NODE:
                graph.add_node(node[Topo.NAME])

        # bsm name -> [router_a, router_b, summed distance]
        costs = {}
        for qc in self.qchannels:
            router, bsm = qc.sender.name, qc.receiver
            if bsm not in costs:
                costs[bsm] = [router, qc.distance]
            else:
                costs[bsm] = [router] + costs[bsm]
                costs[bsm][-1] += qc.distance

        graph.add_weighted_edges_from(costs.values())
        for src in self.nodes[self.DQC_NODE]:
            for dst_name in graph.nodes:
                if src.name == dst_name:
                    continue
                try:
                    # always search from the lexicographically smaller name, so the
                    # src->dst path is the reverse of the dst->src path
                    if dst_name > src.name:
                        path = dijkstra_path(graph, src.name, dst_name)
                    else:
                        path = dijkstra_path(graph, dst_name, src.name)[::-1]
                except exception.NetworkXNoPath:
                    continue
                next_hop = path[1]
                # routing protocol locates at the bottom of the stack
                # NOTE(review): assumes protocol_stack[0] is always the routing protocol — confirm
                routing_protocol = src.network_manager.protocol_stack[0]
                routing_protocol.add_forwarding_rule(dst_name, next_hop)

    def infer_qubit_to_node(self, total_wires: int) -> dict[int, str]:
        """Auto-infer the {wire_index: node_name} map.

        Wires are assigned consecutively to DQC nodes in configuration order; each
        DQC node receives ``n_data`` wires (default 1). Other entries in the node list
        (e.g. BSM nodes, including those auto-generated from quantum connections)
        own no wires.

        Args:
            total_wires (int): The total number of wires (qubits) in the system.

        Returns:
            dict[int, str]: A mapping from wire indices to node names.

        Raises:
            ValueError: if the configured data qubits do not add up to ``total_wires``.
        """
        mapping: dict[int, str] = {}
        next_wire = 0
        for nd in self._raw_cfg[Topo.ALL_NODE]:
            if nd.get(Topo.TYPE) != self.DQC_NODE:
                continue
            name = nd[Topo.NAME]
            for _ in range(nd.get("n_data", 1)):
                if next_wire >= total_wires:
                    raise ValueError(f"Mapping overflow: more data qubits than {total_wires}")
                mapping[next_wire] = name
                next_wire += 1
        # every wire of the circuit must be owned by exactly one node
        if next_wire != total_wires:
            raise ValueError(f"Configured for {next_wire} wires but circuit has {total_wires}")
        return mapping

    def infer_memory_owners(self, total_wires: int, ancilla_inds: list[int]) -> dict[str, dict[int, int]]:
        """Map every DQC node to the wires it owns and their slots in its memory array.

        Args:
            total_wires (int): The total number of wires (qubits) in the system.
            ancilla_inds (list[int]): The list of indices for the ancilla qubits.
                Currently unused; accepted for interface compatibility.

        Returns:
            dict[str, dict[int, int]]: node_name -> {wire_index: slot_index_in_memory_array}.
            Every DQC node has an entry (empty if it owns no wire); within a node, slots
            are numbered 0, 1, 2, ... in increasing wire order.

        Raises:
            ValueError: propagated from ``infer_qubit_to_node``.
        """
        qubit_to_node = self.infer_qubit_to_node(total_wires)

        # self.nodes is keyed by node *type*, so collect the DQC node names explicitly
        data_owners: dict[str, dict[int, int]] = {
            node.name: {} for node in self.nodes.get(self.DQC_NODE, [])
        }
        for wire, owner in qubit_to_node.items():
            slots = data_owners.setdefault(owner, {})
            slots[wire] = len(slots)

        return data_owners

    def get_timeline(self) -> Timeline:
        """Return the simulation timeline."""
        return self.tl

    def get_nodes(self) -> dict[str, list[Node]]:
        """Return the mapping of node type to the list of nodes of that type."""
        return self.nodes
