import math
import networkx as nx
import numpy as np
from numba import njit, prange
# Assume CUDA support until the guarded definitions below fail to load.
IS_CUDA_AVAILABLE = True
try:
    from numba import cuda, types
    from numba.core.errors import NumbaPerformanceWarning
    import warnings

    # Suppress Numba performance warnings process-wide.
    warnings.simplefilter('ignore', category=NumbaPerformanceWarning)

    @cuda.jit(device=True)
    def cuda_probability_by_hamming_weight(q, J, h, z, theta, t, n_qubits):
        """Device helper: closed-form weight for Hamming-weight index ``q``.

        ``J`` is the effective coupling, ``h`` the field at time ``t``, ``z``
        the (weighted) degree and ``theta`` a per-qubit angle supplied by the
        caller.  Returns 0.0 whenever the exponent would overflow or the
        result is not finite.
        """
        # critical angle (math.asin matches the other math.* calls used on device)
        theta_c = math.asin(max(-1.0, min(1.0, abs(h) / (z * J))))

        p = (
            pow(2.0, abs(J / h) - 1.0)
            * (1.0 + math.sin(theta - theta_c) * math.cos(1.5 * math.pi * J * t + theta) / (1.0 + math.sqrt(t)))
            - 0.5
        )

        # Guard against overflow in the powers of two below.
        if (p * (n_qubits + 2.0)) >= 1024.0:
            return 0.0

        result = (pow(2.0, (n_qubits + 2.0) * p) - 1.0) * pow(2.0, -((n_qubits + 1.0) * p) - p * q) / (pow(2.0, p) - 1.0)

        if math.isnan(result) or math.isinf(result):
            return 0.0

        return result

    @cuda.jit
    def cuda_maxcut_hamming_cdf(delta_t, tot_t, h_mult, J_func, degrees, theta, hamming_prob):
        """Kernel: accumulate per-step probability differences into ``hamming_prob``.

        Launch layout as read from the indices below: grid x = time step,
        grid y = qubit, thread x = Hamming-weight index.  Every block adds
        into the same ``hamming_prob`` entries, so the update is atomic.
        """
        step = cuda.blockIdx.x
        qi = cuda.blockIdx.y
        J_eff = J_func[qi]
        z = degrees[qi]
        # Skip qubits whose coupling is effectively zero (avoids dividing by z * J).
        if abs(z * J_eff) <= (2 ** (-54)):
            return

        n_qubits = cuda.gridDim.y
        theta_eff = theta[qi]
        t = step * delta_t
        # NOTE(review): tm1 is negative for step 0; the helper presumably maps
        # the resulting non-finite value to 0.0 -- confirm this is intended.
        tm1 = (step - 1) * delta_t
        h_t = h_mult * (tot_t - t)

        qo = cuda.threadIdx.x
        if J_eff > 0.0:
            # Positive coupling mirrors the Hamming-weight index.
            qo = cuda.blockDim.x - (1 + qo)
        diff = cuda_probability_by_hamming_weight(qo, J_eff, h_t, z, theta_eff, t, n_qubits)
        diff -= cuda_probability_by_hamming_weight(qo, J_eff, h_t, z, theta_eff, tm1, n_qubits)
        # Many blocks target the same slot: a plain "+=" would lose updates.
        cuda.atomic.add(hamming_prob, qo, diff)

except Exception:
    # Best effort: any failure to set up the CUDA path falls back to CPU,
    # while KeyboardInterrupt / SystemExit still propagate.
    IS_CUDA_AVAILABLE = False


@njit
def probability_by_hamming_weight(J, h, z, theta, t, n_qubits):
    """Return normalized weights for Hamming weights 1 .. n_qubits - 1.

    J is the effective coupling, h the field strength, z the (weighted)
    degree, theta a per-qubit angle and t the time.  The result is a
    float64 array of length n_qubits - 1; it is all zeros when the
    exponent would be too large, and it is reversed when J is positive.
    """
    weights = np.zeros(n_qubits - 1, dtype=np.float64)

    # Critical angle: fall back to +/-1 when z * J is numerically zero.
    coupling = z * J
    if np.isclose(abs(coupling), 0.0):
        ratio = 1.0 if J > 0.0 else -1.0
    else:
        ratio = abs(h) / coupling
    theta_c = np.arcsin(max(-1.0, min(1.0, ratio)))

    envelope = 1.0 + np.sin(theta - theta_c) * np.cos(1.5 * np.pi * J * t + theta) / (1.0 + np.sqrt(t))
    p = pow(2.0, abs(J / h) - 1.0) * envelope - 0.5

    # Exponent too large to evaluate: report no bias at all.
    if (p * n_qubits) >= 1024:
        return weights

    # The normalizer also counts the two end terms (exponents 0 and n_qubits),
    # which are not stored in the returned array.
    norm = 1.0 + 1.0 / pow(2.0, p * n_qubits)
    for idx in range(n_qubits - 1):
        term = 1.0 / pow(2.0, p * (idx + 1))
        weights[idx] = term
        norm += term
    weights /= norm

    if J > 0.0:
        return weights[::-1]

    return weights


@njit(parallel=True)
def maxcut_hamming_cdf(n_qubits, J_func, degrees, quality, hamming_prob):
    """Turn ``hamming_prob`` into a cumulative distribution, in place.

    n_qubits: number of qubits/nodes.
    J_func: per-qubit effective couplings.
    degrees: per-qubit (weighted) degrees.
    quality: log2 of the number of time steps.
    hamming_prob: in/out array indexed 0 .. n_qubits - 2 (assumed to have
        length n_qubits - 1).  Its incoming values act as a starting
        distribution; on return it holds running sums with the last entry
        forced to 2.0 as an upper sentinel.

    For n_qubits < 2 the array is zeroed and nothing else happens.
    """
    if n_qubits < 2:
        hamming_prob.fill(0.0)
        return

    n_steps = 1 << quality
    delta_t = 1.0 / n_steps
    tot_t = n_steps * delta_t
    h_mult = 32.0 / tot_t
    n_bias = n_qubits - 1

    theta = np.zeros(n_qubits)
    for q in prange(n_qubits):
        J = J_func[q]
        z = degrees[q]
        theta[q] = np.arcsin(
            max(
                -1.0,
                min(
                    1.0,
                    (1.0 if J > 0.0 else -1.0) if np.isclose(abs(z * J), 0.0) else (abs(h_mult) / (z * J)),
                ),
            )
        )

    # Accumulate through a whole-array reduction on a local array.  Updating
    # individual elements of a shared array from a prange body is a race:
    # different iterations write the same indices concurrently.
    accum = np.zeros(n_bias, dtype=np.float64)
    for qc in prange(n_qubits, n_steps * n_qubits):
        step = qc // n_qubits
        q = qc % n_qubits
        J_eff = J_func[q]
        if np.isclose(abs(J_eff), 0.0):
            continue
        z = degrees[q]
        theta_eff = theta[q]
        t = step * delta_t
        tm1 = (step - 1) * delta_t
        h_t = h_mult * (tot_t - t)
        bias = probability_by_hamming_weight(J_eff, h_t, z, theta_eff, t, n_qubits)
        last_bias = probability_by_hamming_weight(J_eff, h_t, z, theta_eff, tm1, n_qubits)
        accum += bias - last_bias

    # Fold the reduced contributions onto the caller-supplied starting values.
    for i in range(n_bias):
        hamming_prob[i] += accum[i]

    tot_prob = sum(hamming_prob)
    hamming_prob /= tot_prob

    # Running sum -> cumulative thresholds.
    tot_prob = 0.0
    for i in range(n_bias):
        tot_prob += hamming_prob[i]
        hamming_prob[i] = tot_prob
    # Sentinel above any uniform draw so a threshold search always terminates.
    hamming_prob[-1] = 2.0


# Written by Elara (OpenAI custom GPT)
@njit
def local_repulsion_choice(adjacency, degrees, weights, n, m):
    """
    Pick up to m nodes out of n with a repulsion bias.

    After a node is chosen, the weights of its still-available neighbors
    are halved, so adjacent nodes become less likely to be picked together.

    adjacency: 2D int array (n x max_deg), padded with -1
    degrees: per-node array; kept for interface compatibility but not read.
        The neighbor scan walks the adjacency row itself, so it stays within
        the row whatever values (or dtype) the caller passes here.
    weights: float64 array of shape (n,); copied, the caller's array is untouched
    n: number of nodes
    m: number of picks to attempt

    Returns a boolean mask of shape (n,) marking the chosen nodes.  Fewer
    than m entries are set if the available weight runs out.
    """

    weights = weights.copy()
    available = np.ones(n, dtype=np.bool_)
    mask = np.zeros(n, dtype=np.bool_)
    row_width = adjacency.shape[1]

    for _ in range(m):
        # Total weight of the nodes that can still be picked
        total_w = 0.0
        for i in range(n):
            if available[i]:
                total_w += weights[i]
        if total_w <= 0:
            break

        # Normalize & sample by walking the cumulative distribution
        r = np.random.rand()
        cum = 0.0
        node = -1
        fallback = -1
        for i in range(n):
            if available[i]:
                if weights[i] > 0.0:
                    fallback = i
                cum += weights[i] / total_w
                if r < cum:
                    node = i
                    break

        if node == -1:
            # Rounding can leave the final cumulative value just below r;
            # the draw then belongs to the last positively weighted node.
            if fallback == -1:
                continue
            node = fallback

        # Select node
        available[node] = False
        mask[node] = True

        # Repulsion: penalize neighbors (skip the -1 padding)
        for j in range(row_width):
            nbr = adjacency[node, j]
            if nbr >= 0 and available[nbr]:
                weights[nbr] *= 0.5  # tunable penalty factor

    return mask


@njit(parallel=True)
def local_repulsion_choice_sample(shots, thresholds, adjacency, degrees, weights, n):
    """Draw ``shots`` boolean node masks, returned as a (shots x n) array.

    Each shot first draws a Hamming weight from the cumulative
    ``thresholds`` and then picks that many nodes via
    ``local_repulsion_choice``.
    """
    masks = np.zeros((shots, n), dtype=np.bool_)

    for shot in prange(shots):
        # Hamming weight: first threshold that is not below the uniform draw.
        draw = np.random.random()
        idx = 0
        while thresholds[idx] < draw:
            idx += 1
        hamming_weight = idx + 1

        # Which nodes carry that weight.
        masks[shot, :] = local_repulsion_choice(adjacency, degrees, weights, n, hamming_weight)

    return masks


def mask_array_to_python_ints(masks):
    """Pack each boolean mask into a Python int, element 0 being the lowest bit."""
    return [
        sum(1 << position for position, flag in enumerate(mask) if flag)
        for mask in masks
    ]


def evaluate_cut_edges(samples, edge_keys, edge_values):
    """Return the sample with the largest cut value and that value.

    samples: iterable of ints; bit ``i`` is the side assigned to node ``i``.
    edge_keys: flat sequence ``[u0, v0, u1, v1, ...]`` of node indices.
    edge_values: weight of each edge, in the same order as the key pairs.

    Returns ``(best_solution, best_value)`` with ``best_value`` as a float.
    Ties keep the first sample seen; with no samples the result is
    ``(None, -inf)``.
    """
    # Unpack the flat key list once instead of once per sample.
    edges = [
        (edge_keys[2 * i], edge_keys[2 * i + 1], weight)
        for i, weight in enumerate(edge_values)
    ]

    best_value = float("-inf")
    best_solution = None

    for state in samples:
        cut_value = 0
        for u, v, weight in edges:
            # An edge is cut when its endpoints carry different bits.
            if ((state >> u) ^ (state >> v)) & 1:
                cut_value += weight

        if cut_value > best_value:
            best_value = cut_value
            best_solution = state

    return best_solution, float(best_value)


# By Gemini (Google Search AI)
def int_to_bitstring(integer, length):
    """Return the binary digits of ``integer``, lowest bit first, zero-padded to ``length``."""
    digits = bin(integer)[2:]  # drop the "0b" prefix
    padded = digits.zfill(length)
    return padded[::-1]


def maxcut_tfim(
    G,
    quality=None,
    shots=None,
):
    """Heuristically search for a large cut of graph ``G``.

    G: NetworkX-style graph; edge weights are read from the "weight"
        attribute and default to 1.0.
    quality: log2 of the number of time steps (default 10).
    shots: number of samples to draw (default ``n_qubits << quality``).

    Returns ``(bit_string, cut_value, (left_nodes, right_nodes))`` where
    character ``i`` of ``bit_string`` is the side of the ``i``-th node of
    ``G.nodes()``.  Graphs with 0, 1 or 2 nodes are answered directly.
    """
    # Number of qubits/nodes
    nodes = list(G.nodes())
    n_qubits = len(nodes)

    if n_qubits == 0:
        return "", 0, ([], [])

    if n_qubits == 1:
        return "0", 0, ([nodes[0]], [])

    if n_qubits == 2:
        # Use None as the "no edge" marker: an existing edge without
        # attributes has an empty attribute dict and must still count.
        ed = G.get_edge_data(nodes[0], nodes[1], default=None)
        if ed is None:
            return "01", 0, ([nodes[0]], [nodes[1]])

        weight = ed.get("weight", 1.0)
        if weight < 0.0:
            return "00", 0, (nodes, [])

        return "01", weight, ([nodes[0]], [nodes[1]])

    # One CUDA thread per Hamming-weight slot.
    group_size = n_qubits - 1
    # CUDA devices cap a block at 1024 threads; larger graphs use the CPU path.
    max_threads_per_block = 1024

    if quality is None:
        quality = 10

    if shots is None:
        # Number of measurement shots
        shots = n_qubits << quality

    n_steps = 1 << quality
    grid_size = n_steps * n_qubits
    grid_dims = (n_steps, n_qubits)

    # Constant-time node -> position lookup (list.index is linear per call).
    node_index = {node: i for i, node in enumerate(nodes)}

    J_eff = np.array(
        [
            -sum(edge_attributes.get("weight", 1.0) for edge_attributes in G.adj[n].values())
            for n in nodes
        ],
        dtype=np.float64,
    )
    degrees = np.array(
        [
            sum(abs(edge_attributes.get("weight", 1.0)) for edge_attributes in G.adj[n].values())
            for n in nodes
        ],
        dtype=np.float64,
    )

    # Starting distribution over Hamming weights, halving at each step.
    # NOTE(review): some interior entries are never assigned and stay 0, and
    # the odd-size case is not symmetric -- confirm this shape is intended.
    n_bias = n_qubits - 1
    thresholds = np.zeros(n_bias, dtype=np.float64)
    tot_prob = 0
    p = 1.0
    if n_qubits & 1:
        q = n_qubits // 2
        thresholds[q - 1] = p
        tot_prob = p
        p /= 2
    for q in range(1, n_qubits // 2):
        thresholds[q - 1] = p
        thresholds[n_bias - q] = p
        tot_prob += 2 * p
        p /= 2
    thresholds /= tot_prob

    use_cuda = (
        IS_CUDA_AVAILABLE
        and cuda.is_available()
        and grid_size >= 128
        and group_size <= max_threads_per_block
    )
    if use_cuda:
        delta_t = 1.0 / n_steps
        tot_t = n_steps * delta_t
        h_mult = 32.0 / tot_t

        theta = np.zeros(n_qubits)
        for q in range(n_qubits):
            J = J_eff[q]
            z = degrees[q]
            theta[q] = np.arcsin(
                max(
                    -1.0,
                    min(
                        1.0,
                        (1.0 if J > 0.0 else -1.0) if np.isclose(abs(z * J), 0.0) else (abs(h_mult) / (z * J)),
                    ),
                )
            )

        cuda_maxcut_hamming_cdf[grid_dims, group_size](delta_t, tot_t, h_mult, J_eff, degrees, theta, thresholds)

        tot_prob = sum(thresholds)
        thresholds /= tot_prob

        # Running sum -> cumulative thresholds, capped by a sentinel.
        tot_prob = 0.0
        for i in range(n_bias):
            tot_prob += thresholds[i]
            thresholds[i] = tot_prob
        thresholds[-1] = 2.0
    else:
        maxcut_hamming_cdf(n_qubits, J_eff, degrees, quality, thresholds)

    # Padded neighbor table plus the integer neighbor count of each node.
    G_dict = nx.to_dict_of_lists(G)
    max_degree = max(len(x) for x in G_dict.values())
    adjacency = np.full((n_qubits, max_degree), -1, dtype=np.int32)
    neighbor_counts = np.zeros(n_qubits, dtype=np.int32)
    for i, node in enumerate(nodes):
        adj = G_dict.get(node, [])
        neighbor_counts[i] = len(adj)
        for j, nbr in enumerate(adj):
            adjacency[i, j] = node_index[nbr]

    J_max = max(J_eff)
    weights = 1.0 / (1.0 + (J_max - J_eff))
    # We only need unique instances.  The sampler indexes adjacency columns by
    # count, so it gets neighbor counts rather than the weighted degrees.
    masks = local_repulsion_choice_sample(shots, thresholds, adjacency, neighbor_counts, weights, n_qubits)
    samples = list(set(mask_array_to_python_ints(masks)))

    edge_keys = []
    edge_values = []
    for u, v, data in G.edges(data=True):
        edge_keys.append(node_index[u])
        edge_keys.append(node_index[v])
        edge_values.append(data.get("weight", 1.0))

    best_solution, best_value = evaluate_cut_edges(samples, edge_keys, edge_values)

    bit_string = int_to_bitstring(best_solution, n_qubits)
    l, r = [], []
    for node, bit in zip(nodes, bit_string):
        if bit == "1":
            r.append(node)
        else:
            l.append(node)

    return bit_string, best_value, (l, r)
