import json
import numpy as np

from .topology import Topology as Topo
from ..kernel.timeline import Timeline
from ..constants import SPEED_OF_LIGHT

from .qlan.orchestrator import QlanOrchestratorNode
from .qlan.client import QlanClientNode


class QlanStarTopo(Topo):
    """Class for generating a QLAN topology with orchestrator and client nodes.

    Class QlanStarTopo is a child of class Topology. From a JSON configuration
    file it builds the simulation timeline, orchestrator nodes, client nodes,
    quantum channels, classical channels, and the protocols of every node's
    resource manager.

    Client nodes are created before orchestrator nodes because every
    orchestrator receives the list of client memories
    (``remote_memories_array``) when it is constructed.

    Attributes:
        nodes (dict[str, list[Node]]): mapping of type of node to a list of same type node.
        qchannels (list[QuantumChannel]): list of quantum channel objects in network.
        cchannels (list[ClassicalChannel]): list of classical channel objects in network.
        tl (Timeline): the timeline used for simulation.
        orchestrator_nodes (list[QlanOrchestratorNode]): list of orchestrator nodes in the network.
        client_nodes (list[QlanClientNode]): list of client nodes in the network.
        remote_memories_array (list[Memory]): first memory of each client node, in creation order;
            handed to every orchestrator node.
        n_local_memories (int): number of local memories in each orchestrator node.
        n_clients (int): number of client nodes as configured (not cross-checked against the node list here).
        meas_bases (str): measurement bases passed to each orchestrator node via ``update_bases``.
        memo_fidelity_orch (float): fidelity of the memories in the orchestrator nodes.
        memo_freq_orch (int): frequency of the memories in the orchestrator nodes.
        memo_efficiency_orch (float): efficiency of the memories in the orchestrator nodes.
        memo_coherence_orch (float): coherence of the memories in the orchestrator nodes.
        memo_wavelength_orch (int): wavelength of the memories in the orchestrator nodes.
        memo_fidelity_client (float): fidelity of the memories in the client nodes.
        memo_freq_client (int): frequency of the memories in the client nodes.
        memo_efficiency_client (float): efficiency of the memories in the client nodes.
        memo_coherence_client (float): coherence of the memories in the client nodes.
        memo_wavelength_client (int): wavelength of the memories in the client nodes.
    """
    MEET_IN_THE_MID = "meet_in_the_middle"
    ORCHESTRATOR = "QlanOrchestratorNode"
    CLIENT = "QlanClientNode"
    LOCAL_MEMORIES = "local_memories"
    CLIENT_NUMBER = "client_number"
    MEM_FIDELITY_ORCH = "memo_fidelity_orch"
    MEM_FREQUENCY_ORCH = "memo_frequency_orch"
    MEM_EFFICIENCY_ORCH = "memo_efficiency_orch"
    MEM_COHERENCE_ORCH = "memo_coherence_orch"
    MEM_WAVELENGTH_ORCH = "memo_wavelength_orch"
    MEM_FIDELITY_CLIENT = "memo_fidelity_client"
    MEM_FREQUENCY_CLIENT = "memo_frequency_client"
    MEM_EFFICIENCY_CLIENT = "memo_efficiency_client"
    MEM_COHERENCE_CLIENT = "memo_coherence_client"
    MEM_WAVELENGTH_CLIENT = "memo_wavelength_client"
    MEASUREMENT_BASES = "measurement_bases"
    MEM_SIZE = "memo_size"

    def __init__(self, conf_file_name: str):
        """Create the topology from the JSON file at ``conf_file_name``.

        The QLAN-specific containers are initialised before delegating to the
        base constructor. NOTE(review): presumably ``Topo.__init__`` calls
        ``self._load`` (which fills these lists); confirm in the base class.
        """
        self.orchestrator_nodes = []
        self.client_nodes = []
        self.remote_memories_array = []
        super().__init__(conf_file_name)

    def _load(self, filename: str):
        """Parse the JSON config at ``filename`` and build the whole topology.

        Order matters: parameters are parsed before nodes (nodes read the
        ``memo_*`` attributes), and the timeline is created before nodes
        (nodes are constructed with ``self.tl``).
        """
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(filename, encoding="utf-8") as fh:
            config = json.load(fh)

        self._get_templates(config)
        self._add_parameters(config)

        # quantum connections are only supported by sequential simulation so far
        self._add_qconnections(config)

        self._add_timeline(config)
        self._add_nodes(config)
        self._add_qchannels(config)
        self._add_cchannels(config)
        self._add_cconnections(config)
        self._add_protocols()

    def _add_timeline(self, config: dict):
        """Create the simulation timeline; stop time defaults to infinity."""
        stop_time = config.get(Topo.STOP_TIME, float('inf'))
        self.tl = Timeline(stop_time)

    def _add_parameters(self, config: dict):
        """Read QLAN-wide parameters from ``config``, falling back to defaults.

        Defaults: 1 local memory, 1 client, bases ``'zz'``; for both
        orchestrator and client memories: fidelity 0.9, frequency 2000,
        efficiency 1, coherence -1, wavelength 500.
        """
        self.n_local_memories = config.get(self.LOCAL_MEMORIES, 1)
        self.n_clients = config.get(self.CLIENT_NUMBER, 1)
        self.meas_bases = config.get(self.MEASUREMENT_BASES, 'zz')

        self.memo_fidelity_orch = config.get(self.MEM_FIDELITY_ORCH, 0.9)
        self.memo_freq_orch = config.get(self.MEM_FREQUENCY_ORCH, 2000)
        self.memo_efficiency_orch = config.get(self.MEM_EFFICIENCY_ORCH, 1)
        self.memo_coherence_orch = config.get(self.MEM_COHERENCE_ORCH, -1)
        self.memo_wavelength_orch = config.get(self.MEM_WAVELENGTH_ORCH, 500)

        self.memo_fidelity_client = config.get(self.MEM_FIDELITY_CLIENT, 0.9)
        self.memo_freq_client = config.get(self.MEM_FREQUENCY_CLIENT, 2000)
        self.memo_efficiency_client = config.get(self.MEM_EFFICIENCY_CLIENT, 1)
        self.memo_coherence_client = config.get(self.MEM_COHERENCE_CLIENT, -1)
        self.memo_wavelength_client = config.get(self.MEM_WAVELENGTH_CLIENT, 500)

    def _add_nodes(self, config: dict):
        """Create all nodes: clients first, so their memories exist for the orchestrators."""
        self._add_client_nodes(config)
        self._add_orchestrator_nodes(config)

    def _iter_nodes_of_type(self, config: dict, wanted_type: str):
        """Yield ``(name, seed)`` for every configured node whose type is ``wanted_type``.

        NOTE(review): per-node templates (``Topo.TEMPLATE``) are not applied to
        QLAN nodes; all memory parameters come from ``_add_parameters``.

        Raises:
            ValueError: if a node entry has a type other than CLIENT or ORCHESTRATOR.
        """
        for node in config[Topo.ALL_NODE]:
            node_type = node[Topo.TYPE]
            if node_type not in (self.CLIENT, self.ORCHESTRATOR):
                raise ValueError(f"Unknown type of node '{node_type}'")
            if node_type == wanted_type:
                yield node[Topo.NAME], node[Topo.SEED]

    def _add_client_nodes(self, config: dict):
        """Create every client node and record its first memory as a remote memory."""
        for name, seed in self._iter_nodes_of_type(config, self.CLIENT):
            # The literal 1 is passed positionally to QlanClientNode; presumably
            # the number of memories per client — confirm against QlanClientNode.
            node_obj = QlanClientNode(name,
                                      self.tl,
                                      1,
                                      self.memo_fidelity_client,
                                      self.memo_freq_client,
                                      self.memo_efficiency_client,
                                      self.memo_coherence_client,
                                      self.memo_wavelength_client)
            node_obj.set_seed(seed)
            # Collected so it can be handed to every orchestrator node.
            node_memo = node_obj.get_components_by_type("Memory")[0]
            self.remote_memories_array.append(node_memo)
            self.client_nodes.append(node_obj)
            self.nodes[self.CLIENT].append(node_obj)

    def _add_orchestrator_nodes(self, config: dict):
        """Create every orchestrator node, wired to the client memories collected so far."""
        for name, seed in self._iter_nodes_of_type(config, self.ORCHESTRATOR):
            node_obj = QlanOrchestratorNode(name,
                                            self.tl,
                                            self.n_local_memories,
                                            self.remote_memories_array,
                                            self.memo_fidelity_orch,
                                            self.memo_freq_orch,
                                            self.memo_efficiency_orch,
                                            self.memo_coherence_orch,
                                            self.memo_wavelength_orch)
            node_obj.set_seed(seed)
            node_obj.update_bases(self.meas_bases)
            self.orchestrator_nodes.append(node_obj)
            self.nodes[self.ORCHESTRATOR].append(node_obj)

    def _add_protocols(self):
        """Ask each node's resource manager to create its protocol (orchestrators first)."""
        for orch in self.orchestrator_nodes:
            orch.resource_manager.create_protocol()
        for client in self.client_nodes:
            client.resource_manager.create_protocol()

    def _cc_delay(self, cc: dict):
        """Delay of one classical channel/connection entry.

        Uses the explicit ``delay`` if present, otherwise
        ``distance / SPEED_OF_LIGHT`` with a default distance of 1000.
        """
        return cc.get(self.DELAY, cc.get(self.DISTANCE, 1000) / SPEED_OF_LIGHT)

    def _add_qconnections(self, config: dict):
        """Validate the quantum connections declared in ``config``.

        For each quantum connection, the classical channels and classical
        connections linking the same two nodes (in either direction) are
        collected and their delays are averaged.

        NOTE(review): the derived values (half distance, attenuation, channel
        type, half mean classical delay) are not yet used to build any channel
        or BSM node; only the validation below has an observable effect.

        Raises:
            ValueError: if a quantum connection has no classical channel or
                classical connection between the same pair of nodes.
        """
        for q_connect in config.get(Topo.ALL_Q_CONNECT, []):
            node1 = q_connect[Topo.CONNECT_NODE_1]
            node2 = q_connect[Topo.CONNECT_NODE_2]
            attenuation = q_connect[Topo.ATTENUATION]
            distance = q_connect[Topo.DISTANCE] // 2
            channel_type = q_connect[Topo.TYPE]

            def links_pair(a, b):
                return (a == node1 and b == node2) or (a == node2 and b == node1)

            cc_delay = [self._cc_delay(cc)
                        for cc in config.get(self.ALL_C_CHANNEL, [])
                        if links_pair(cc[self.SRC], cc[self.DST])]
            cc_delay.extend(self._cc_delay(cc)
                            for cc in config.get(self.ALL_C_CONNECT, [])
                            if links_pair(cc[self.CONNECT_NODE_1], cc[self.CONNECT_NODE_2]))

            # An explicit exception (not `assert`) so the check survives `python -O`.
            if not cc_delay:
                raise ValueError(
                    f"No classical channel or connection between '{node1}' and "
                    f"'{node2}' for quantum connection {q_connect}")
            cc_delay = np.mean(cc_delay) // 2

