"""Module defining the Clifford backend."""

import collections
from functools import cache, reduce
from importlib.util import find_spec, module_from_spec
from itertools import product
from typing import Union

import numpy as np

from qibo import gates
from qibo.backends.numpy import NumpyBackend
from qibo.config import raise_error


class CliffordBackend(NumpyBackend):
    """Backend for the simulation of Clifford circuits following
    `Aaronson & Gottesman (2004) <https://arxiv.org/abs/quant-ph/0406196>`_.

    Args:
        engine (str, optional): Engine used for the tableau updates. Supported
            values are ``"numpy"``, ``"numba"``, ``"cupy"`` and ``"stim"``. If
            ``None``, the engine name is resolved from the backend returned by
            ``qibo.backends._check_backend``. Defaults to ``None``.
    """

    def __init__(self, engine=None):
        super().__init__()

        if engine == "stim":
            import stim  # pylint: disable=C0415

            # ``stim`` is only used to execute circuits; the remaining tableau
            # helpers (e.g. sampling) keep running on the numpy engine.
            engine = "numpy"
            self.platform = "stim"
            self._stim = stim
        else:
            if engine is None:
                from qibo.backends import (  # pylint: disable=C0415
                    _check_backend,
                    _get_engine_name,
                )

                engine = _get_engine_name(_check_backend(engine))

            self.platform = engine

        # Load a fresh module object (not the one cached in ``sys.modules``) so
        # that the engine-specific overrides below stay local to this instance.
        spec = find_spec("qibo.backends._clifford_operations")
        self.engine = module_from_spec(spec)
        spec.loader.exec_module(self.engine)

        if engine == "numpy":
            pass
        elif engine == "numba":
            from numba import set_num_threads  # pylint: disable=C0415

            set_num_threads(1)

            from qibojit.backends import (  # pylint: disable=C0415
                clifford_operations_cpu,
            )

            for method in dir(clifford_operations_cpu):
                setattr(self.engine, method, getattr(clifford_operations_cpu, method))
        elif engine == "cupy":  # pragma: no cover
            from qibojit.backends import (  # pylint: disable=C0415
                clifford_operations_gpu,
            )

            for method in dir(clifford_operations_gpu):
                setattr(self.engine, method, getattr(clifford_operations_gpu, method))
        else:
            raise_error(
                NotImplementedError,
                f"Backend `{engine}` is not supported for Clifford Simulation.",
            )

        self.np = self.engine.np

        self.name = "clifford"

    def cast(self, x, dtype=None, copy: bool = False):
        """Cast an object as the array type of the current backend.

        Args:
            x: Object to cast to array.
            dtype (optional): data type of the array or tensor. If ``None``, defaults
                to the default data type of the current backend. Defaults to ``None``.
            copy (bool, optional): If ``True`` a copy of the object is created in memory.
                Defaults to ``False``.

        Returns:
            ndarray: ``x`` cast by the engine's ``cast`` function.
        """
        return self.engine.cast(x, dtype=dtype, copy=copy)

    def calculate_frequencies(self, samples):
        """Count how many times each distinct sample occurs.

        Args:
            samples (ndarray): Samples to count.

        Returns:
            collections.Counter: Map from each distinct sample to its number of occurrences.
        """
        res, counts = self.engine.np.unique(samples, return_counts=True)
        # The next two lines are necessary for the GPU backends: they turn the
        # engine scalars into plain Python values.
        res = [int(r) if not isinstance(r, str) else r for r in res]
        counts = [int(v) for v in counts]

        return collections.Counter(dict(zip(res, counts)))

    def zero_state(self, nqubits: int, i_phase=False):
        """Construct the zero state :math:`\\ket{00...00}`.

        Args:
            nqubits (int): number of qubits.
            i_phase (bool, optional): If ``True``, the symplectic matrix will
                have two phase columns as in Dehaene-De Moor format.
                If ``False``, the symplectic matrix will have one phase column
                as in Aaronson-Gottesman format. Defaults to ``False``.

        Returns:
            ndarray: Boolean symplectic matrix of shape :math:`(2n+1, 2n+1)`
            (or :math:`(2n+1, 2n+2)` if ``i_phase``) for the zero state.
        """
        identity = self.np.eye(nqubits)
        ncols = 2 * nqubits + 2 if i_phase else 2 * nqubits + 1

        symplectic_matrix = self.np.zeros((2 * nqubits + 1, ncols), dtype=bool)
        # Destabilizers are the X_i, stabilizers are the Z_i; the last row is scratch space.
        symplectic_matrix[:nqubits, :nqubits] = self.np.copy(identity)
        symplectic_matrix[nqubits:-1, nqubits : 2 * nqubits] = self.np.copy(identity)
        return symplectic_matrix

    def _clifford_pre_execution_reshape(self, state):
        """Reshape the symplectic matrix to the shape needed by the engine before circuit execution.

        Args:
            state (ndarray): Input state.

        Returns:
            ndarray: Reshaped state.
        """
        return self.engine._clifford_pre_execution_reshape(  # pylint: disable=protected-access
            state
        )

    def _clifford_post_execution_reshape(self, state, nqubits: int):
        """Reshape the symplectic matrix to the shape needed by the engine after circuit execution.

        Args:
            state (ndarray): Input state.
            nqubits (int): Number of qubits.

        Returns:
            ndarray: Reshaped state.
        """
        return self.engine._clifford_post_execution_reshape(  # pylint: disable=protected-access
            state, nqubits
        )

    def apply_gate_clifford(self, gate, symplectic_matrix, nqubits):
        """Apply a gate to a symplectic matrix.

        ``gates.Unitary`` instances are handled by :meth:`apply_unitary`; every
        other gate is dispatched to the engine function named after the gate
        class, forwarding ``theta``/``phi`` when present in ``gate.init_kwargs``.
        """
        if isinstance(gate, gates.Unitary):
            return self.apply_unitary(gate, symplectic_matrix, nqubits)

        operation = getattr(self.engine, gate.__class__.__name__)

        kwargs = {}
        for param_name in ["theta", "phi"]:
            if param_name in gate.init_kwargs:
                kwargs[param_name] = gate.init_kwargs[param_name]

        return operation(symplectic_matrix, *gate.init_args, nqubits, **kwargs)

    def apply_unitary(self, gate, symplectic_matrix, nqubits):
        """Apply a unitary gate to a symplectic matrix following
        `Dehaene & Moor (2003) <https://arxiv.org/abs/quant-ph/0304125>`_.

        The matrix is expected to have the two phase columns of the
        Dehaene-De Moor format (see :meth:`zero_state` with ``i_phase=True``);
        it is updated in place and also returned.

        Raises:
            ValueError: if the unitary of ``gate`` is not a Clifford operator.
        """
        qubit_indices = list(gate.qubits)
        m = len(qubit_indices)
        matrix = gate._parameters[0]  # pylint: disable=protected-access
        symplectic_m, phase_h_m = self._compute_symplectic_matrix(matrix, m)
        phase_d_m = self._get_phase_vector_dk(symplectic_m, m)

        symplectic_n = self._embed_clifford(symplectic_m, nqubits, qubit_indices)
        phase_h_n = self._embed_phase_vector(phase_h_m, nqubits, qubit_indices)
        phase_d_n = self._embed_phase_vector(phase_d_m, nqubits, qubit_indices)

        symplectic_gate = [symplectic_n, phase_h_n, phase_d_n]
        # The quadratic form depends only on the gate, so compute it once
        # instead of once per tableau row.
        lows = self._lower_quadratic_form(symplectic_n, phase_d_n, nqubits)
        for i in range(2 * nqubits):
            symplectic_pauli = [
                symplectic_matrix[i, : 2 * nqubits],
                symplectic_matrix[i, 2 * nqubits],
                symplectic_matrix[i, 2 * nqubits + 1],
            ]
            symplectic_pauli = self._conjugate_pauli(
                symplectic_gate, symplectic_pauli, nqubits, lows=lows
            )

            symplectic_matrix[i, : 2 * nqubits] = symplectic_pauli[0]
            symplectic_matrix[i, 2 * nqubits] = symplectic_pauli[1]
            symplectic_matrix[i, 2 * nqubits + 1] = symplectic_pauli[2]

        return symplectic_matrix

    def _pauli_string_to_matrix(self, pauli_str):
        """Convert Pauli string to matrix (tensor product).

        ``"Y"`` is mapped to ``1j * matrices.Y``, so that every factor is real.
        """
        from qibo import matrices  # pylint: disable=C0415

        paulis = {
            pauli: self.engine.cast(getattr(matrices, pauli), dtype=self.dtype)
            for pauli in ("I", "X", "Y", "Z")
        }
        paulis["Y"] = 1j * paulis["Y"]
        pauli_matrices = [paulis.get(p) for p in pauli_str]
        matrix = reduce(self.engine.np.kron, pauli_matrices)
        return matrix

    def _pauli_to_binary(self, pauli_str, nqubits):
        """Convert Pauli string to binary vector of length :math:`2 n`.

        The first ``nqubits`` entries flag X (or Y) components, the last
        ``nqubits`` entries flag Z (or Y) components.
        """
        pauli_symplectic = self.np.zeros(2 * nqubits, dtype=self.np.uint8)
        for q, term in enumerate(pauli_str):
            if term in ["X", "Y"]:
                pauli_symplectic[q] = 1
            if term in ["Z", "Y"]:
                pauli_symplectic[q + nqubits] = 1
        return pauli_symplectic

    @staticmethod
    @cache
    def _pauli_generators(m):
        """Return the :math:`2m` generator strings ``X_1, ..., X_m, Z_1, ..., Z_m``."""
        pauli_gens_x, pauli_gens_z = [], []
        for i in range(m):
            p = ["I"] * m
            p = (p, p.copy())
            p[0][i] = "X"
            p[1][i] = "Z"
            pauli_gens_x.append("".join(p[0]))
            pauli_gens_z.append("".join(p[1]))
        return pauli_gens_x + pauli_gens_z

    def _match_pauli(self, operator, candidates):
        """Find the Pauli candidate equal to ``operator`` up to a phase in :math:`\\{1, i, -1, -i\\}`.

        Args:
            operator (ndarray): Matrix to identify.
            candidates (list): ``(pauli_string, matrix)`` pairs, searched in order.

        Returns:
            tuple or None: ``(pauli_string, phase_code)`` of the first match, where
            ``phase_code`` is ``0`` for phases :math:`1, i` and ``1`` for
            :math:`-1, -i`; ``None`` if no candidate matches.
        """
        for candidate_str, candidate_matrix in candidates:
            for phase_val, phase_code in zip([1, 1j, -1, -1j], [0, 0, 1, 1]):
                if self.np.allclose(operator, phase_val * candidate_matrix, atol=1e-10):
                    return candidate_str, phase_code
        return None

    def _compute_symplectic_matrix(self, unitary, m):
        """Compute the symplectic matrix for Clifford unitary on :math:`m` qubits and the phase vector :math:`h` of length :math:`2m` for Clifford unitary :math:`U`.
        :math:`h[j] = 0` if :math:`U g_j U^\\dagger = i^r p_j` with :math:`r=0` or :math:`1` else :math:`1`.

        Raises:
            ValueError: if the conjugate of some generator is not a Pauli operator
                up to a phase, i.e. ``unitary`` is not Clifford.
        """
        pauli_gens = self._pauli_generators(m)

        symplectic = self.np.zeros((2 * m, 2 * m), dtype=self.np.uint8)
        phase_vector = self.np.zeros(2 * m, dtype=self.np.uint8)

        # The candidate matrices do not depend on the generator: build them
        # once, keeping the original ``product("IXYZ", ...)`` search order.
        candidates = []
        for chars in product("IXYZ", repeat=m):
            candidate_str = "".join(chars)
            candidates.append(
                (candidate_str, self._pauli_string_to_matrix(candidate_str))
            )
        unitary_dagger = unitary.conj().T

        for i, p_str in enumerate(pauli_gens):
            pauli = self._pauli_string_to_matrix(p_str)
            pauli_uconj = unitary @ pauli @ unitary_dagger

            match = self._match_pauli(pauli_uconj, candidates)
            if match is None:
                raise_error(
                    ValueError,
                    "Unitary is not a Clifford operator: the conjugate of the "
                    + f"generator ``{p_str}`` is not a Pauli operator.",
                )
            candidate_str, phase_code = match
            phase_vector[i] = phase_code
            symplectic[i, :] = self._pauli_to_binary(candidate_str, m)
        return symplectic % 2, phase_vector

    def _upper_identity_block(self, m):
        """Return the :math:`2m \\times 2m` ``uint8`` matrix with an identity in its upper-right block."""
        u_matrix = self.np.zeros((2 * m, 2 * m), dtype=self.np.uint8)
        u_matrix[0:m, m : 2 * m] = self.np.eye(m, dtype=self.np.uint8)
        return u_matrix

    def _get_phase_vector_dk(self, symplectic, m):
        """Compute phase vector :math:`d` of length :math:`2m` for Clifford unitary :math:`U`.
        :math:`d[j] = 0` if :math:`U g_j U^\\dagger = (-1)^r p_j` with :math:`r=0` or :math:`1` else :math:`1`.

        Computed as the diagonal of :math:`S U S^T` modulo 2, where :math:`U` is
        the matrix returned by :meth:`_upper_identity_block`.
        """
        u_matrix = self._upper_identity_block(m)
        d = self.np.diag(symplectic @ (u_matrix @ symplectic.T) % 2) % 2
        return d

    def _lower_quadratic_form(self, symplectic_matrix, phase_d, nqubits):
        """Lower-triangular part (diagonal included) of :math:`S U S^T \\oplus d d^T`.

        Only the parity of the entries is used by :meth:`_conjugate_pauli`; the
        entries themselves are not reduced modulo 2.
        """
        u_matrix = self._upper_identity_block(nqubits)
        return self.np.tril(
            symplectic_matrix @ (u_matrix @ symplectic_matrix.T)
            ^ self.np.outer(phase_d, phase_d)
        )

    def _conjugate_pauli(self, symplectic_gate, symplectic_pauli, nqubits, lows=None):
        """Compute the conjugate of a Pauli operator under a symplectic transformation.

        Args:
            symplectic_gate (list): ``[symplectic_matrix, phase_h, phase_d]`` of the gate.
            symplectic_pauli (list): ``[symplectic_vector, epsilon, delta]`` of the Pauli row.
            nqubits (int): Number of qubits.
            lows (ndarray, optional): Precomputed :meth:`_lower_quadratic_form` of
                the gate. If ``None`` it is computed here. Defaults to ``None``.

        Returns:
            tuple: ``(new_symplectic_vector, new_epsilon, new_delta)``.
        """
        symplectic_matrix, phase_h, phase_d = symplectic_gate
        symplectic_vector, epsilon, delta = symplectic_pauli

        new_symplectic_vector = (symplectic_matrix.T @ symplectic_vector) % 2

        pd_dot_sv = self.np.dot(phase_d, symplectic_vector) % 2
        new_delta = delta ^ pd_dot_sv

        if lows is None:
            lows = self._lower_quadratic_form(symplectic_matrix, phase_d, nqubits)

        ph_dot_sv = self.np.dot(phase_h, symplectic_vector) % 2
        sv_lows_sv = self.np.dot(symplectic_vector, lows @ symplectic_vector) % 2
        delta_pd_dot_sv = (delta * pd_dot_sv) % 2
        new_epsilon = epsilon ^ ph_dot_sv ^ sv_lows_sv ^ delta_pd_dot_sv
        return new_symplectic_vector, new_epsilon, new_delta

    def _convert_dehaene_to_aaronson(self, dehaene_tableau):
        """
        Convert Dehaene-De Moor tableau with two phase columns to Aaronson-Gottesman format.

        Dehaene-De Moor format:
        - Columns :math:`0` to :math:`n-1`: X components
        - Columns :math:`n` to :math:`2n-1`: Z components
        - Column :math:`2n`: real phase/sign (:math:`0=+1`, :math:`1=-1`)
        - Column :math:`2n+1`: :math:`i`-phase (powers of :math:`i`)

        Aaronson-Gottesman format:
        - Columns :math:`0` to :math:`n-1`: X components
        - Columns :math:`n` to :math:`2n-1`: Z components
        - Column :math:`2n`: real phase (:math:`0=+1`, :math:`1=-1`)

        Args:
            dehaene_tableau (ndarray): array of shape :math:`(2n+1, 2n+2)` in extended Dehaene-De Moor format.

        Returns:
            (ndarray): aaronson_tableau of shape :math:`(2n+1, 2n+1)`.
        """
        n_rows, n_cols = dehaene_tableau.shape
        n = (n_cols - 2) // 2

        X_part = dehaene_tableau[:, :n]
        Z_part = dehaene_tableau[:, n : 2 * n]
        real_phases = dehaene_tableau[:, -2]
        i_phases = dehaene_tableau[:, -1]

        # Each Y (X and Z both set) contributes one power of i; the last
        # (scratch) row is left untouched.
        y_count = self.np.sum(X_part[:-1] & Z_part[:-1], axis=-1)
        total_i_power = (i_phases[:-1] + y_count) % 4

        final_real_phases = real_phases.copy()
        # i^2 = -1: flip the sign of the rows with total power 2.
        indices = total_i_power == 2
        final_real_phases[: n_rows - 1][indices] = (
            final_real_phases[: n_rows - 1][indices] + 1
        ) % 2

        # Use the engine module so the result stays on the engine's device.
        aaronson_tableau = self.np.column_stack([X_part, Z_part, final_real_phases])
        return self.cast(aaronson_tableau, dtype=aaronson_tableau.dtype)

    def _embed_clifford(self, symplectic_m, n, qubit_indices):
        """Embed m-qubit symplectic :math:`S_U_m` into n-qubit system at qubit_indices.

        Qubits not in ``qubit_indices`` are acted on by the identity.
        """
        symplectic_n = self.np.eye(2 * n, dtype=self.np.uint8)

        x_indices = qubit_indices
        z_indices = [q + n for q in qubit_indices]
        full_indices = x_indices + z_indices

        symplectic_n[self.np.ix_(full_indices, full_indices)] = symplectic_m

        return symplectic_n % 2

    def _embed_phase_vector(self, phase_m, n, qubit_indices):
        """Embed m-qubit phase vector into n-qubit system.

        The X half of ``phase_m`` goes to ``qubit_indices`` and the Z half to
        ``qubit_indices + n``; every other entry is zero.
        """
        phase_n = self.np.zeros(2 * n, dtype=self.np.uint8)
        m = len(qubit_indices)

        qubit_indices = self.np.asarray(qubit_indices)
        phase_n[qubit_indices] = phase_m[:m]
        phase_n[qubit_indices + n] = phase_m[m:]

        return phase_n

    def apply_channel(self, channel, state, nqubits):
        """Stochastically apply one of the gates of a Pauli noise ``channel``.

        Gate ``k`` is applied with probability ``channel.coefficients[k]``; with
        the remaining probability the state is left unchanged.

        Returns:
            ndarray: The (possibly updated) state.
        """
        # Clamp at zero so round-off in the sum of the coefficients cannot
        # produce a tiny negative probability, which the sampler rejects.
        remainder = max(0.0, 1 - float(np.sum(channel.coefficients)))
        probabilities = tuple(channel.coefficients) + (remainder,)
        # Passing the population size (not ``range``) draws the same indices on
        # numpy and is the form accepted by array engines such as cupy.
        index = self.np.random.choice(
            len(probabilities), size=1, p=probabilities
        )[0]
        index = int(index)
        if index != len(channel.gates):
            gate = channel.gates[index]
            state = gate.apply_clifford(self, state, nqubits)
        return state

    def execute_circuit(  # pylint: disable=R1710
        self, circuit, initial_state=None, nshots: int = 1000
    ):
        """Execute a Clifford circuit.

        Args:
            circuit (:class:`qibo.models.circuit.Circuit`): Input circuit.
            initial_state (ndarray, optional): The ``symplectic_matrix`` of the initial state.
                If ``None``, defaults to the zero state. Defaults to ``None``.
            nshots (int, optional): Number of shots to perform if ``circuit`` has measurements.
                Defaults to :math:`10^{3}`.

        Returns:
            :class:`qibo.quantum_info.clifford.Clifford`: Object storing the final results.

        Raises:
            RuntimeError: if the circuit contains non-Clifford gates, or if the
                state does not fit in memory.
            NotImplementedError: if a collapsing measurement is combined with
                ``gates.Unitary``.
        """
        from qibo.quantum_info.clifford import Clifford  # pylint: disable=C0415

        if self.platform == "stim":
            circuit_stim = self._stim.Circuit()  # pylint: disable=E1101
            for gate in circuit.queue:
                circuit_stim.append(gate.__class__.__name__, list(gate.qubits))

            x_destab, z_destab, x_stab, z_stab, x_phases, z_phases = (
                self._stim.Tableau.from_circuit(  # pylint: disable=no-member
                    circuit_stim
                ).to_numpy()
            )
            symplectic_matrix = np.block([[x_destab, z_destab], [x_stab, z_stab]])
            symplectic_matrix = np.c_[symplectic_matrix, np.r_[x_phases, z_phases]]

            return Clifford(
                symplectic_matrix,
                measurements=circuit.measurements,
                nshots=nshots,
                _backend=self,
            )

        for gate in circuit.queue:
            if (
                not gate.clifford
                and not gate.__class__.__name__ == "M"
                and not isinstance(gate, gates.PauliNoiseChannel)
            ):
                raise_error(RuntimeError, "Circuit contains non-Clifford gates.")

        if circuit.repeated_execution and nshots != 1:
            return self.execute_circuit_repeated(circuit, nshots, initial_state)

        try:
            nqubits = circuit.nqubits
            # Circuits with generic unitaries are simulated in the Dehaene-De Moor
            # format, which needs the extra i-phase column.
            i_phase = any(isinstance(gate, gates.Unitary) for gate in circuit.queue)
            state = (
                self.zero_state(nqubits, i_phase)
                if initial_state is None
                else initial_state
            )
            if not i_phase:
                state = self._clifford_pre_execution_reshape(state)
            for gate in circuit.queue:
                if i_phase:
                    if isinstance(gate, gates.M):
                        if gate.collapse:
                            raise_error(
                                NotImplementedError,
                                "Collapsing measurements with `gates.Unitary` are not implemented in the `CliffordBackend`.",
                            )
                    elif not isinstance(gate, gates.Unitary):
                        # Route every gate through ``apply_unitary`` so all
                        # updates use the same two-phase-column format.
                        gate = gates.Unitary(gate.matrix(backend=self), *gate.qubits)
                # NOTE(review): the return value is ignored, so this relies on
                # the engine updating ``state`` in place.
                gate.apply_clifford(self, state, nqubits)
            if i_phase:
                state = self._convert_dehaene_to_aaronson(state)
            else:
                state = self._clifford_post_execution_reshape(state, nqubits)
            clifford = Clifford(
                state,
                measurements=circuit.measurements,
                nshots=nshots,
                _backend=self,
            )
            return clifford

        except self.oom_error:  # pragma: no cover
            raise_error(
                RuntimeError,
                f"State does not fit in {self.device} memory. "
                "Please switch the execution device to a "
                "different one using ``qibo.set_device``.",
            )

    def execute_circuit_repeated(self, circuit, nshots: int = 1000, initial_state=None):
        """Execute a Clifford circuit ``nshots`` times.

        This is used for all the simulations that involve repeated execution.
        For instance when collapsing measurement or noise channels are present.

        Args:
            circuit (:class:`qibo.models.circuit.Circuit`): input circuit.
            nshots (int, optional): Number of times to repeat the execution.
                Defaults to :math:`1000`.
            initial_state (ndarray, optional): Symplectic_matrix of the initial state.
                If ``None``, defaults to :meth:`qibo.backends.clifford.CliffordBackend.zero_state`.
                Defaults to ``None``.

        Returns:
            :class:`qibo.quantum_info.clifford.Clifford`: Object storing the final results.
            Its ``symplectic_matrix`` is set to ``None`` because no single final
            state exists.
        """
        from qibo.quantum_info.clifford import Clifford  # pylint: disable=C0415

        circuit_copy = circuit.copy()
        samples = []
        for _ in range(nshots):
            res = self.execute_circuit(circuit_copy, initial_state, nshots=1)
            for measurement in circuit_copy.measurements:
                measurement.result.reset()
            samples.append(res.samples())
        samples = self.np.vstack(samples)

        # NOTE(review): columns are selected by absolute qubit index; confirm this
        # matches the column order returned by ``Clifford.samples`` when only a
        # subset of the qubits is measured.
        for meas in circuit.measurements:
            meas.result.register_samples(samples[:, meas.target_qubits])

        result = Clifford(
            self.zero_state(circuit.nqubits),
            measurements=circuit.measurements,
            nshots=nshots,
            _backend=self,
        )
        result.symplectic_matrix, result._samples = None, None

        return result

    def sample_shots(
        self,
        state,
        qubits: Union[tuple, list],
        nqubits: int,
        nshots: int,
        collapse: bool = False,
    ):  # pylint: disable=W0221
        """Sample shots by measuring selected ``qubits`` in symplectic matrix of a ``state``.

        Args:
            state (ndarray): symplectic matrix from which to sample shots from.
            qubits: (tuple or list): qubits to measure.
            nqubits (int): total number of qubits of the state.
            nshots (int): number of shots to sample.
            collapse (bool, optional): If ``True`` the input state is going to be
                collapsed with the last shot. Defaults to ``False``.

        Returns:
            (ndarray): Samples shots.
        """
        if isinstance(qubits, list):
            qubits = tuple(qubits)

        if collapse:
            samples = [self.engine.M(state, qubits, nqubits) for _ in range(nshots - 1)]
            samples.append(self.engine.M(state, qubits, nqubits, collapse))
        else:
            samples = [self.engine.M(state, qubits, nqubits) for _ in range(nshots)]

        return self.engine.cast(samples, dtype=int)

    def symplectic_matrix_to_generators(
        self, symplectic_matrix, return_array: bool = False
    ):
        """Extract the stabilizers and destabilizers generators from symplectic matrix.

        Args:
            symplectic_matrix (ndarray): The input symplectic_matrix.
            return_array (bool, optional): If ``True`` returns the generators as ``ndarrays``.
                If ``False``, generators are returned as strings. Defaults to ``False``.

        Returns:
            (list, list): Extracted generators and their corresponding phases, respectively.
        """
        bits_to_gate = {"00": "I", "01": "X", "10": "Z", "11": "Y"}

        nqubits = (symplectic_matrix.shape[1] - 1) // 2
        phases = (-1) ** symplectic_matrix[:-1, -1].astype(np.int16)
        # Multiplying by 1 turns boolean entries into integers for the bit strings.
        tmp = 1 * symplectic_matrix[:-1, :-1]
        X, Z = tmp[:, :nqubits], tmp[:, nqubits:]

        if return_array:
            from qibo import matrices  # pylint: disable=C0415

        generators = []
        for x, z in zip(X, Z):
            paulis = [bits_to_gate[f"{zz}{xx}"] for xx, zz in zip(x, z)]
            if return_array:
                paulis = [self.cast(getattr(matrices, p)) for p in paulis]
                matrix = reduce(self.np.kron, paulis)
                generators.append(matrix)
            else:
                generators.append("".join(paulis))

        if return_array:
            generators = self.cast(generators)

        return generators, phases
