# This code is a Qiskit project.
#
# (C) Copyright IBM 2025.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""GroupMeasIntoBoxes"""

from __future__ import annotations

import itertools
from collections import defaultdict
from typing import Literal

from qiskit.circuit import Annotation, Bit, Qubit
from qiskit.dagcircuit import DAGCircuit
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler.exceptions import TranspilerError

from ...aliases import DAGOpNode
from ...annotations import BasisTransform, Twirl
from .utils import make_and_insert_box, validate_op_is_supported

# The values accepted by the ``annotations`` argument of ``GroupMeasIntoBoxes``. The
# constructor rejects anything else with a ``ValueError``, and this list is echoed verbatim
# in the error messages raised by that class.
SUPPORTED_ANNOTATIONS = ["twirl", "basis_transform", "all"]
"""The supported values of ``annotations``."""


class GroupMeasIntoBoxes(TransformationPass):
    """Collect the measurements in a circuit inside boxes.

    This pass collects all of the measurements in the input circuit in boxes, together with the
    single-qubit gates that precede them. To assign the measurements to these boxes, it
    uses a greedy collection strategy that tries to collect measurements in the earliest possible
    box that they can fit.

    Args:
        annotations: The annotations placed on the measurement boxes. The supported values are:

            * ``'twirl'`` for a :class:`~.Twirl` annotation.
            * ``'basis_transform'`` for a :class:`~.BasisTransform` annotation with mode
              ``measure``.
            * ``'all'`` for both :class:`~.Twirl` and :class:`~.BasisTransform` annotations.
        prefix_ref: A prefix to all the :class:`.BasisTransform.ref` generated by this class. Every
            ``ref`` is generated by appending a counter to ``prefix_ref``. In order to avoid
            collisions, the counter is shared across all the instances of this class and it is
            incremented every time that a new ``ref`` is created.

    Raises:
        ValueError: If ``annotations`` is not one of the supported values.

    .. note::
        Barriers, boxes, and two-qubit standard gates that are present in the input circuit act
        as delimiters. This means that when one of these delimiters stands between a group of
        single-qubit gates and a measurement, those single-qubit gates are not added to the box.
    """

    # Class-level (not per-instance) counter: every instance draws from the same sequence, so two
    # instances that use the same ``prefix_ref`` still generate distinct refs.
    _REF_COUNTER = itertools.count()

    def __init__(
        self,
        annotations: Literal["twirl", "basis_transform", "all"] = "twirl",
        prefix_ref: str = "basis",
    ):
        """Validate and store the configuration of the pass.

        Args:
            annotations: Which annotations to place on the boxes created by this pass.
            prefix_ref: The prefix of the ``ref`` of every generated :class:`~.BasisTransform`.

        Raises:
            ValueError: If ``annotations`` is not in ``SUPPORTED_ANNOTATIONS``.
        """
        super().__init__()

        if annotations not in SUPPORTED_ANNOTATIONS:
            raise ValueError(
                f"{annotations} is not a valid input for field 'annotations'. "
                f"The supported values are '{SUPPORTED_ANNOTATIONS}.'"
            )

        self.annotations = annotations
        self.prefix_ref = prefix_ref

    def _make_annotations(self) -> list[Annotation]:
        """A helper function to make annotations for the boxes created by this pass.

        Every call that produces a :class:`~.BasisTransform` consumes one value of the shared
        counter, so each box gets its own ``ref``.

        Returns:
            A fresh list of annotations matching ``self.annotations``.

        Raises:
            TranspilerError: If ``self.annotations`` was changed to an unsupported value after
                construction.
        """
        if self.annotations == "twirl":
            return [Twirl()]
        if self.annotations == "basis_transform":
            return [BasisTransform(ref=f"{self.prefix_ref}{next(self._REF_COUNTER)}")]
        if self.annotations == "all":
            return [Twirl(), BasisTransform(ref=f"{self.prefix_ref}{next(self._REF_COUNTER)}")]

        # Only reachable if the attribute was reassigned after ``__init__`` validated it.
        raise TranspilerError(
            f"{self.annotations} is not a valid input for field 'annotations'. "
            f"The supported values are '{SUPPORTED_ANNOTATIONS}.'"
        )

    def run(self, dag: DAGCircuit) -> DAGCircuit:
        """Collect the operations in the dag inside left-dressed boxes.

        The collection strategy undertakes the following steps:
            *   Loop through the DAG's op nodes in topological order.
            *   Group together single-qubit gate nodes and measurement nodes that need to be
                placed in the same box.
            *   Whenever a node can be placed in more than one group, place it in the earliest
                possible group--where "earliest" is with reference to topological ordering.
            *   When looping is complete, replace each group with a box.

        Args:
            dag: The DAG to process.

        Returns:
            The ``dag`` object that was passed in, after every group has been handed to
            ``make_and_insert_box`` (which is expected to modify ``dag`` in place).

        Raises:
            TranspilerError: If a node is neither a barrier, a box, a measurement, a two-qubit
                standard gate, nor a single-qubit standard gate, or if two measurements collected
                in the same group write to the same classical bit. ``validate_op_is_supported``
                may reject nodes as well.
        """
        # A map to temporarily store single-qubit gate nodes before inserting them into a group
        cached_gates_1q: dict[Qubit, list[DAGOpNode]] = defaultdict(list)

        # A list of groups that need to be placed in the same box, expressed as a dict for fast
        # access. Every node in each group either contains a single-qubit gate or a measurement
        # --when constructing this dictionary, we explicitly leave out nodes that contain different
        # ops.
        groups: dict[int, list[DAGOpNode]] = defaultdict(list)

        # A map from bits (qubits and clbits) to the index of the earliest group that is able to
        # collect operations on those bits
        group_indices: dict[Bit, int] = defaultdict(int)

        for node in dag.topological_op_nodes():
            validate_op_is_supported(node)

            name = node.op.name
            bits = node.qargs + node.cargs

            # The index of the earliest group able to collect ops on all the bits in this node.
            # ``default=0`` covers nodes that act on no bits at all (e.g. an empty barrier), for
            # which a bare ``max`` would raise an unrelated ``ValueError``.
            group_idx: int = max((group_indices[bit] for bit in bits), default=0)

            if name in ("barrier", "box") or (
                node.is_standard_gate() and node.op.num_qubits == 2
            ):
                # Delimiter: drop the cached single-qubit gate nodes without placing them in a
                # group, and align every bit of the delimiter on the same group index so that
                # nothing after the delimiter is collected into a group that precedes it.
                for bit in bits:
                    cached_gates_1q.pop(bit, None)
                    group_indices[bit] = group_idx
            elif name == "measure":
                # Flush the cached one-qubit gates and measurement into a group
                target_group = groups[group_idx]
                for qubit in node.qargs:
                    target_group.extend(cached_gates_1q.pop(qubit, ()))
                target_group.append(node)

                # Later operations on the measured qubits must go to a later group.
                # NOTE(review): the clbits of the measurement are not advanced here. A later
                # measurement into the same clbit whose own trackers are still at this index
                # joins this group and is rejected by the duplicate-clbit check below. Confirm
                # whether a measurement can also end up in a group that precedes an earlier
                # write to its clbit, and whether that needs handling.
                for qubit in node.qargs:
                    group_indices[qubit] = group_idx + 1
            elif node.is_standard_gate() and node.op.num_qubits == 1:
                # Cache the node until a measurement (collect it) or a delimiter (drop it from
                # the cache, leaving it unboxed) shows up on the same qubit.
                cached_gates_1q[node.qargs[0]].append(node)
            else:
                raise TranspilerError(f"'{name}' operation is not supported.")

        for nodes in groups.values():
            # Single-qubit gates carry no cargs here, so this lists the clbits written by the
            # measurements of the group.
            cargs = [carg for node in nodes for carg in node.cargs]
            if len(cargs) != len(set(cargs)):
                raise TranspilerError(
                    "Cannot twirl more than one measurement on the same classical bit."
                )

            make_and_insert_box(dag, nodes, self._make_annotations())

        return dag
