# Copyright (C) 2024 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.

"""
Module for generating random OpenQASM 3 programs

"""
from typing import Callable, Optional

import numpy as np

from qbraid._version import __version__
from qbraid.exceptions import QbraidError


def create_gateset_qasm(max_operands: int) -> np.ndarray:
    """Build the standard OpenQASM 3 gate set as a structured NumPy array.

    Single-qubit operations (including ``reset``) are always present. Two-qubit
    gates are appended when ``max_operands >= 2`` and three-qubit gates when
    ``max_operands >= 3``; the order of records is single, then two-, then
    three-qubit gates.

    Args:
        max_operands (int): Largest number of qubits a gate in the set may act on.

    Returns:
        np.ndarray: Structured array with one record per gate and the fields
        ``gate`` (gate name), ``num_qubits`` (operand count) and
        ``num_params`` (number of angle parameters).
    """
    # Each entry is (name, number of qubit operands, number of angle parameters).
    single_qubit: list[tuple[str, int, int]] = [
        ("id", 1, 0),
        ("x", 1, 0),
        ("y", 1, 0),
        ("z", 1, 0),
        ("h", 1, 0),
        ("s", 1, 0),
        ("t", 1, 0),
        ("sdg", 1, 0),
        ("tdg", 1, 0),
        ("sx", 1, 0),
        ("rx", 1, 1),
        ("ry", 1, 1),
        ("rz", 1, 1),
        ("p", 1, 1),
        ("u1", 1, 1),
        ("u2", 1, 2),
        ("u3", 1, 3),
        ("reset", 1, 0),
    ]
    two_qubit: list[tuple[str, int, int]] = [
        ("cx", 2, 0),
        ("cy", 2, 0),
        ("cz", 2, 0),
        ("ch", 2, 0),
        ("cp", 2, 1),
        ("crx", 2, 1),
        ("cry", 2, 1),
        ("crz", 2, 1),
        ("swap", 2, 0),
        ("cu", 2, 4),
    ]
    three_qubit: list[tuple[str, int, int]] = [("ccx", 3, 0), ("cswap", 3, 0)]

    selected = list(single_qubit)
    # Wider tiers are unlocked independently, each by its own operand threshold.
    for arity, tier in ((2, two_qubit), (3, three_qubit)):
        if max_operands >= arity:
            selected += tier

    record_dtype = [("gate", object), ("num_qubits", np.int64), ("num_params", np.int64)]
    return np.array(selected, dtype=record_dtype)


def _sample_layer(
    rng: np.random.Generator,
    gateset: np.ndarray,
    filler_gateset: np.ndarray,
    num_qubits: int,
) -> np.ndarray:
    """Sample gate specs for one layer whose operands cover exactly ``num_qubits`` wires.

    Gates are drawn from ``gateset`` and truncated to the longest prefix that fits;
    any leftover wires are padded with gates drawn from ``filler_gateset``.
    """
    # Every gate uses at least one wire, so num_qubits draws always reach the budget.
    specs = rng.choice(gateset, size=num_qubits)
    cumulative = np.cumsum(specs["num_qubits"], dtype=np.int64)

    # Keep the longest prefix of sampled gates whose operands fit on the available wires.
    n_fit = int(np.searchsorted(cumulative, num_qubits, side="right"))
    specs = specs[:n_fit]
    # n_fit can be 0 when the first sampled gate is wider than the register.
    used = int(cumulative[n_fit - 1]) if n_fit else 0
    slack = num_qubits - used
    if slack:
        specs = np.hstack((specs, rng.choice(filler_gateset, size=slack)))
    return specs


def _layer_to_qasm(
    rng: np.random.Generator, specs: np.ndarray, qubits: np.ndarray
) -> list[str]:
    """Render one layer of gate specs as OpenQASM 3 statements.

    Operands are taken consecutively from ``qubits`` (the current wire ordering),
    and angle parameters are drawn uniformly from ``[0, 2*pi)`` using ``rng``.
    """
    qubit_offsets = np.concatenate(([0], np.cumsum(specs["num_qubits"], dtype=np.int64)))
    param_offsets = np.concatenate(([0], np.cumsum(specs["num_params"], dtype=np.int64)))
    angles = rng.uniform(0, 2 * np.pi, size=int(param_offsets[-1]))

    statements = []
    # Access fields by name so gate sets carrying extra fields are still accepted.
    for i, (gate, n_params) in enumerate(zip(specs["gate"], specs["num_params"])):
        operands = ",".join(f"q[{q}]" for q in qubits[qubit_offsets[i] : qubit_offsets[i + 1]])
        if n_params:
            params = ",".join(str(a) for a in angles[param_offsets[i] : param_offsets[i + 1]])
            statements.append(f"{gate}({params}) {operands};\n")
        else:
            statements.append(f"{gate} {operands};\n")
    return statements


# pylint: disable-next=too-many-arguments,too-many-locals
def qasm3_random_from_gates(
    create_gateset: Callable[[int], np.ndarray],
    num_qubits: Optional[int] = None,
    depth: Optional[int] = None,
    max_operands: Optional[int] = None,
    seed: Optional[int] = None,
    measure: bool = False,
) -> str:
    """Generate random OpenQASM 3 program.

    Every layer assigns each qubit to exactly one operation. Gates are sampled
    from ``create_gateset(min(max_operands, num_qubits))`` and any leftover
    wires are filled with gates from ``create_gateset(1)``. The qubit ordering
    is randomly permuted between layers.

    Args:
        create_gateset (Callable): Function mapping a maximum operand count to a
            structured array with ``gate``, ``num_qubits`` and ``num_params`` fields.
        num_qubits (int): Number of quantum wires. Drawn from [1, 3] if None.
        depth (int): Layers of operations (i.e., critical path length).
            Drawn from [1, 3] if None.
        max_operands (int): Maximum size of gate for each operation. Drawn from
            [1, 3] if None; capped at ``num_qubits`` when sampling.
        seed (int): Seed for random number generator. Drawn at random if None.
            Unspecified options above are drawn from the seeded generator, so a
            given seed always reproduces the same program.
        measure (bool): Whether to include measurement gates.

    Raises:
        ValueError: When invalid random circuit options are given.
        QbraidError: When failed to create random OpenQASM 3 program.

    Returns:
        str: OpenQASM 3 program.
    """

    def validate_option(value: Optional[int], name: str) -> Optional[int]:
        """Return ``value`` as a plain int (None if unset); reject non-positive values."""
        if value is None:
            return None
        # bool subclasses int, but True/False are not meaningful sizes.
        if isinstance(value, bool) or not isinstance(value, (int, np.integer)) or value <= 0:
            raise ValueError(f"Invalid random circuit option. '{name}' must be a positive integer.")
        return int(value)

    # Validate before the try so callers receive ValueError rather than QbraidError.
    num_qubits = validate_option(num_qubits, "num_qubits")
    depth = validate_option(depth, "depth")
    max_operands = validate_option(max_operands, "max_operands")

    try:
        if seed is None:
            seed = np.random.randint(0, np.iinfo(np.int32).max)
        # Global seeding is kept for backward compatibility; all sampling below uses rng.
        np.random.seed(seed)
        rng = np.random.default_rng(seed)

        # Draw defaults from the seeded generator so the seed alone determines the program.
        if num_qubits is None:
            num_qubits = int(rng.integers(1, 4))
        if depth is None:
            depth = int(rng.integers(1, 4))
        if max_operands is None:
            max_operands = int(rng.integers(1, 4))

        qasm_code_header = f"""
// Generated by qBraid v{__version__}
OPENQASM 3.0;
include "stdgates.inc";
/*
    seed = {seed}
    num_qubits = {num_qubits}
    depth = {depth}
    max_operands = {max_operands}
*/
"""
        lines = [qasm_code_header, f"qubit[{num_qubits}] q;\n"]
        if measure:
            lines.append(f"bit[{num_qubits}] c;\n")

        gateset = create_gateset(min(max_operands, num_qubits))
        # Separate single-qubit set used only for padding; it must not replace gateset.
        filler_gateset = create_gateset(1)
        qubits = np.arange(num_qubits)

        for _ in range(depth):
            specs = _sample_layer(rng, gateset, filler_gateset, num_qubits)
            lines.extend(_layer_to_qasm(rng, specs, qubits))
            qubits = rng.permutation(qubits)

        if measure:
            lines.extend(f"c[{i}] = measure q[{i}];\n" for i in range(num_qubits))

        return "".join(lines)

    except Exception as err:
        raise QbraidError("Failed to create random OpenQASM 3 program") from err


def qasm3_random(
    num_qubits: Optional[int] = None,
    depth: Optional[int] = None,
    max_operands: Optional[int] = None,
    seed: Optional[int] = None,
    measure: bool = False,
) -> str:
    """Generate a random OpenQASM 3 program from the built-in standard gate set.

    This is a convenience wrapper that delegates to ``qasm3_random_from_gates``
    with ``create_gateset_qasm`` as the gate-set factory.

    Args:
        num_qubits (int): Number of quantum wires; chosen at random if None.
        depth (int): Number of operation layers (critical path length); chosen
            at random if None.
        max_operands (int): Largest gate size allowed per operation; chosen at
            random if None.
        seed (int): Seed for the random number generator; chosen at random if None.
        measure (bool): If True, append a measurement of every qubit.

    Raises:
        ValueError: When invalid random circuit options are given.
        QbraidError: When failed to create random OpenQASM 3 program.

    Returns:
        str: OpenQASM 3 program.
    """
    options = {
        "num_qubits": num_qubits,
        "depth": depth,
        "max_operands": max_operands,
        "seed": seed,
        "measure": measure,
    }
    return qasm3_random_from_gates(create_gateset_qasm, **options)
