"""[CUDA-Q](https://github.com/NVIDIA/cuda-quantum) based quantum circuit backend."""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../src/platform/backends/circuits_cudaq.ipynb.

# %% auto 0
__all__ = ['ParametrizedCudaqKernel', 'CircuitsCudaqBackend']  # public API of this module; autogenerated from the source notebook (see header) - edit there, not here

# %% ../../../src/platform/backends/circuits_cudaq.ipynb 2
from ...imports import *
from .base_backend import BaseBackend
from ..circuits_instructions import CircuitInstructions

import cudaq

# %% ../../../src/platform/backends/circuits_cudaq.ipynb 4
@dataclass
class ParametrizedCudaqKernel:
    """A `cudaq.kernel` bundled with the gate angles it must be called with.

    Kernels produced by `CircuitsCudaqBackend` are invoked as
    ``kernel(input_state, params)``, where ``params[i]`` is the angle of gate ``i``.
    """

    # Compiled CUDA-Q kernel taking (input_state, thetas).
    kernel: cudaq.kernel

    # One angle per gate, in gate order (a single angle per gate is supported at the moment).
    params: list[float]

# %% ../../../src/platform/backends/circuits_cudaq.ipynb 6
class CircuitsCudaqBackend(BaseBackend):
    """Backend that converts genQC `CircuitInstructions` into CUDA-Q kernels.

    `@cudaq.kernel` functions cannot take strings, so every gate name is mapped to an
    integer code (`KERNEL_VOCABULARY`). A generic kernel then dispatches on these codes
    to place each gate.
    """

    BASIC_BACKEND_TYPE = type[cudaq.kernel]

    def __init__(self, target: str = "qpp-cpu") -> None:
        """Select the CUDA-Q simulation target.

        Args:
            target: Name passed to `cudaq.set_target`, e.g. ``"qpp-cpu"`` or ``"nvidia"``.

        Note:
            `cudaq.reset_target` / `cudaq.set_target` are module-level calls, so the
            selected target is presumably shared process-wide rather than per instance.
        """
        cudaq.reset_target()
        cudaq.set_target(target) # 'nvidia'

    def backend_to_genqc(self):
        """Convert a `cudaq.kernel` back to genQC. Not supported by this backend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Not implemeted cudaq to genQC.")

    # Integer code of every supported gate.
    # Must stay in sync with the `gate == <code>` dispatch in `place_gate_kernel`
    # (defined inside `_construct_kernel`).
    KERNEL_VOCABULARY = {"h":1,
                         "cx":2,
                         "z":3,
                         "x":4,
                         "y":5,
                         "ccx":6,
                         "swap":7,
                         "rx":8,
                         "ry":9,
                         "rz":10,
                         "cp":11,}

    # Expected (num_target_nodes, num_control_nodes) for every gate in `KERNEL_VOCABULARY`.
    # "cp" is stored with both qubits as targets; the kernel uses the first one as control.
    _GATE_ARITY = {"h":    (1, 0),
                   "z":    (1, 0),
                   "x":    (1, 0),
                   "y":    (1, 0),
                   "rx":   (1, 0),
                   "ry":   (1, 0),
                   "rz":   (1, 0),
                   "cx":   (1, 1),
                   "ccx":  (1, 2),
                   "swap": (2, 0),
                   "cp":   (2, 0),}

    def _construct_kernel(self,
                          gate_list: List[str],
                          target_1_nodes_list: List[int],
                          target_2_nodes_list: List[int],
                          control_1_nodes_list: List[int],
                          control_2_nodes_list: List[int]
                         ) -> cudaq.kernel:
        """Construct a `cudaq.kernel` that applies the given gates in order.

        All node lists are indexed by gate position.

        Args:
            gate_list: Gate names; each must be a key of `KERNEL_VOCABULARY`.
            target_1_nodes_list: First target qubit of each gate.
            target_2_nodes_list: Second target qubit of each gate (`swap`, `cp`).
            control_1_nodes_list: First control qubit of each gate (`cx`, `ccx`).
            control_2_nodes_list: Second control qubit of each gate (`ccx`).

        Returns:
            A kernel ``kernel(input_state: list[complex], thetas: list[float])``. It
            initialises the qubits from ``input_state`` and then applies every gate.
            ``thetas[i]`` is the angle of gate ``i`` and is ignored by non-parametric gates.

        Raises:
            KeyError: If a gate name is not in `KERNEL_VOCABULARY`.
        """

        num_gates = len(gate_list)
        gate_list = [self.KERNEL_VOCABULARY[g] for g in gate_list]

        # Note: `@cudaq.kernel` decorator has an overhead of ~20ms, regardless of the for-loop inside.
        # The gate/node lists are referenced from this enclosing scope by the kernels below
        # rather than passed as arguments, so their names must not change.

        @cudaq.kernel
        def place_gate_kernel(gate: int,
                              qvector: cudaq.qview,
                              target_1: int,
                              target_2: int,
                              control_1: int,
                              control_2: int,
                              theta: float):

            if   gate == 1: h(qvector[target_1])
            elif gate == 2: cx(qvector[control_1], qvector[target_1])
            elif gate == 3: z(qvector[target_1])
            elif gate == 4: x(qvector[target_1])
            elif gate == 5: y(qvector[target_1])
            elif gate == 6: x.ctrl(qvector[control_1], qvector[control_2], qvector[target_1])
            elif gate == 7: swap(qvector[target_1], qvector[target_2])

            elif gate == 8:  rx(theta, qvector[target_1])
            elif gate == 9:  ry(theta, qvector[target_1])
            elif gate == 10: rz(theta, qvector[target_1])

            elif gate == 11:
                # R1 applies the unitary transformation; i.e. it is a phase gate
                # R1(λ) = | 1     0    |
                #         | 0  exp(iλ) |
                r1.ctrl(theta, qvector[target_1], qvector[target_2])


        @cudaq.kernel
        def kernel(input_state: list[complex], thetas: list[float]):
            qvector = cudaq.qvector(input_state)
            for i in range(num_gates):
                place_gate_kernel(gate_list[i], qvector, target_1_nodes_list[i], target_2_nodes_list[i], control_1_nodes_list[i], control_2_nodes_list[i], thetas[i])

        return kernel

    def check_error_circuit(self,
                            gate: str,
                            num_target_nodes: int,
                            num_control_nodes: int) -> bool:
        """Check the number of connections of a gate. Used to detect error circuits.

        Args:
            gate: Lower-case gate name.
            num_target_nodes: Number of target qubits the instruction uses.
            num_control_nodes: Number of control qubits the instruction uses.

        Returns:
            ``True`` if the counts match what `gate` requires, ``False`` otherwise.

        Raises:
            NotImplementedError: If `gate` is not in `KERNEL_VOCABULARY`, or if it has no
                entry in the arity table (an internal inconsistency).
        """

        if gate not in self.KERNEL_VOCABULARY:
            raise NotImplementedError(f"Unknown gate {gate}, not in `self.KERNEL_VOCABULARY`.")

        if gate not in self._GATE_ARITY:
            raise NotImplementedError(f"Unknown gate {gate}, implemetation is faulty!")

        return (num_target_nodes, num_control_nodes) == self._GATE_ARITY[gate]

    @staticmethod
    def _instruction_theta(params) -> float:
        """Return the single angle stored in an instruction's `params`, or NaN if it has none.

        Raises:
            NotImplementedError: If the instruction carries more than one parameter.
        """
        if params is None:
            return float("nan")

        # float64 avoids the silent float32 rounding of `torch.tensor(list_of_floats)`.
        flat = torch.as_tensor(params, dtype=torch.float64).reshape(-1)

        if flat.numel() == 0:
            return float("nan")  # non-parametric gate; the kernel ignores this value
        if flat.numel() != 1:
            raise NotImplementedError(f"Only one parameter per gate is supported, got {flat.numel()}.")
        return float(flat[0])

    def genqc_to_backend(self,
                         instructions: CircuitInstructions,
                         **kwargs) -> "ParametrizedCudaqKernel | None":
        """Convert given genQC `CircuitInstructions` to a `ParametrizedCudaqKernel`.

        Args:
            instructions: Circuit to convert.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The compiled kernel plus one angle per gate (NaN for gates without a
            parameter). Returns ``None`` if any instruction's qubit counts do not match
            its gate (an error circuit).

        Raises:
            NotImplementedError: For a gate outside `KERNEL_VOCABULARY` or an instruction
                with more than one parameter.
        """

        num_gates = instructions.length

        # @cudaq.kernel can only take list[int] and no str directly
        # -> we have to map everything to list[int].
        # Unused slots keep the default value 9999 so an error is raised if we have a faulty encoding.

        gate_list = []
        thetas    = []  # built as a plain list so a 1-gate circuit still yields list[float]
        target_1_nodes_list  = [9999] * num_gates
        target_2_nodes_list  = [9999] * num_gates
        control_1_nodes_list = [9999] * num_gates
        control_2_nodes_list = [9999] * num_gates

        for i, instruction in enumerate(instructions.data):

            gate          = instruction.name.lower()
            control_nodes = instruction.control_nodes
            target_nodes  = instruction.target_nodes

            num_target_nodes  = len(target_nodes)
            num_control_nodes = len(control_nodes)

            if not self.check_error_circuit(gate, num_target_nodes, num_control_nodes):
                return None

            gate_list.append(gate)
            thetas.append(self._instruction_theta(instruction.params))

            if num_target_nodes > 0:
                target_1_nodes_list[i] = target_nodes[0]
                if num_target_nodes > 1:
                    target_2_nodes_list[i] = target_nodes[1]

            if num_control_nodes > 0:
                control_1_nodes_list[i] = control_nodes[0]
                if num_control_nodes > 1:
                    control_2_nodes_list[i] = control_nodes[1]

        #--------------------
        _kernel = self._construct_kernel(gate_list, target_1_nodes_list, target_2_nodes_list, control_1_nodes_list, control_2_nodes_list)

        return ParametrizedCudaqKernel(kernel=_kernel, params=thetas)

    def get_unitary(self, parametrizedCudaqKernel: ParametrizedCudaqKernel, num_qubits: int) -> np.ndarray:
        """Return the unitary matrix of a `cudaq.kernel` by simulation.

        Column ``j`` is the state the kernel produces from the ``j``-th computational
        basis vector. This needs ``2**num_qubits`` simulations. The mapping of basis
        index to qubits follows cudaq's state-vector ordering (not checked here).
        Currently relies on simulation; this could change in future releases of cudaq.

        Args:
            parametrizedCudaqKernel: Kernel and its angles.
            num_qubits: Number of qubits the kernel acts on.

        Returns:
            A complex128 array of shape ``(2**num_qubits, 2**num_qubits)``.
        """

        kernel, thetas = parametrizedCudaqKernel.kernel, parametrizedCudaqKernel.params

        N = 2**num_qubits
        U = np.zeros((N, N), dtype=np.complex128)

        for j in range(N):
            state_j    = np.zeros((N), dtype=np.complex128)
            state_j[j] = 1

            # `np.asarray` copies only when needed. `np.array(..., copy=False)` raises
            # ValueError under NumPy >= 2 whenever a copy cannot be avoided.
            U[:, j] = np.asarray(cudaq.get_state(kernel, state_j, thetas))

        return U

    def draw(self, parametrizedCudaqKernel: ParametrizedCudaqKernel, num_qubits: int, return_str: bool = False, **kwargs) -> "str | None":
        """Draw the given `cudaq.kernel` using `cudaq.draw`.

        Args:
            parametrizedCudaqKernel: Kernel and its angles.
            num_qubits: Number of qubits the kernel acts on.
            return_str: If ``True``, return the drawing instead of printing it.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The drawing as a string if `return_str`, otherwise ``None`` (it is printed).
        """

        kernel, thetas = parametrizedCudaqKernel.kernel, parametrizedCudaqKernel.params

        # The kernel takes an explicit input state; use basis state 0 (all qubits |0>).
        c    = [0] * (2**num_qubits)
        c[0] = 1

        s = cudaq.draw(kernel, c, thetas)
        if return_str:
            return s
        print(s)
