from .spin_glass_solver import spin_glass_solver
import networkx as nx


def get_best_stitch(adjacency, terminals_a, terminals_b, is_cyclic):
    """Pick the cheapest way to join one terminal of each of two paths.

    Args:
        adjacency: Nested mapping indexed as ``adjacency[u][v]["weight"]``.
        terminals_a: Indexable holding the two end nodes of the first path.
        terminals_b: Indexable holding the two end nodes of the second path.
        is_cyclic: When truthy, the weight of the edge between the two
            *remaining* terminals is added to each candidate's cost.

    Returns:
        ``(best_weight, best_edge)`` where ``best_edge`` is a pair of indices
        ``(a, b)`` into ``terminals_a`` / ``terminals_b``. ``best_edge`` is
        ``None`` (and ``best_weight`` is ``inf``) if no candidate compares
        strictly below infinity.
    """
    lowest = float("inf")
    choice = None
    for idx_a in (0, 1):
        end_a = terminals_a[idx_a]
        for idx_b in (0, 1):
            end_b = terminals_b[idx_b]
            cost = adjacency[end_a][end_b]["weight"]
            if is_cyclic:
                # Also pay for the link between the two opposite terminals.
                other_a = terminals_a[1 - idx_a]
                other_b = terminals_b[1 - idx_b]
                cost += adjacency[other_a][other_b]["weight"]
            if not (cost < lowest):
                continue
            # Strict comparison above: on ties the earliest pairing is kept.
            lowest, choice = cost, (idx_a, idx_b)

    return lowest, choice


def _edge_weight(G, u, v):
    """Return the ``"weight"`` of edge ``(u, v)`` in ``G``, or 1.0 if unset."""
    return G[u][v].get("weight", 1.0)


def _attach_singlet(G, singlet, bulk):
    """Find the cheapest position at which to add one node to a path.

    The node may be prepended, appended, or spliced between two consecutive
    nodes of ``bulk`` (replacing the edge that joined them).

    Returns:
        ``(path, delta)`` where ``delta`` is the weight added to the path
        (new edges minus any removed edge).
    """
    best_weight = _edge_weight(G, singlet, bulk[0])
    best_path = [singlet] + bulk

    weight = _edge_weight(G, singlet, bulk[-1])
    if weight < best_weight:
        best_weight = weight
        best_path = bulk + [singlet]

    for i in range(len(bulk) - 1):
        weight = (
            _edge_weight(G, singlet, bulk[i])
            + _edge_weight(G, singlet, bulk[i + 1])
            - _edge_weight(G, bulk[i], bulk[i + 1])
        )
        if weight < best_weight:
            best_weight = weight
            # Build a new list; list.insert() mutates in place and returns None.
            best_path = bulk[:i + 1] + [singlet] + bulk[i + 1:]

    return best_path, best_weight


def _stitch_paths(G, path_a, path_b):
    """Merge two multi-node paths into one at the lowest added weight.

    Each path is tried, in both orientations, as the piece that is joined
    end-to-end with the other or spliced between two consecutive nodes of the
    other. The best candidate over *all* of these is kept.

    Returns:
        ``(path, delta)`` where ``delta`` is the weight added on top of the
        two paths' own weights.
    """
    # Work on copies so the caller's lists are not reversed in place.
    inner, outer = list(path_a), list(path_b)

    # Seed with plain concatenation so a result always exists.
    best_weight = _edge_weight(G, inner[-1], outer[0])
    best_path = inner + outer

    for _ in range(2):
        for _ in range(2):
            head, tail = inner[0], inner[-1]

            weight = _edge_weight(G, tail, outer[0])
            if weight < best_weight:
                best_weight = weight
                best_path = inner + outer

            weight = _edge_weight(G, outer[-1], head)
            if weight < best_weight:
                best_weight = weight
                best_path = outer + inner

            # Splice `inner` between outer[i] and outer[i + 1].
            for i in range(len(outer) - 1):
                weight = (
                    _edge_weight(G, head, outer[i])
                    + _edge_weight(G, tail, outer[i + 1])
                    - _edge_weight(G, outer[i], outer[i + 1])
                )
                if weight < best_weight:
                    best_weight = weight
                    best_path = outer[:i + 1] + inner + outer[i + 1:]

            # Second pass uses the opposite orientation; after two passes the
            # list is back in its original order.
            inner.reverse()
        # Swap roles so the other path also gets to be the spliced piece.
        inner, outer = outer, inner

    return best_path, best_weight


def tsp_symmetric(G, quality=0, shots=None, correction_quality=2, is_cyclic=True, start_node=None):
    """Heuristically build a path visiting every node of ``G`` once.

    The node set is split in two, each half is solved recursively, and the two
    resulting paths are stitched together at the cheapest point found.

    Edge weights are read from the ``"weight"`` attribute and default to 1.0.
    Every pair of nodes that gets joined must have an edge in ``G``; a missing
    edge raises ``KeyError``.

    Args:
        G: Graph exposing ``nodes()``, ``edges(data=True)`` and ``G[u][v]``.
        quality: Forwarded to ``spin_glass_solver``.
        shots: Forwarded to ``spin_glass_solver``.
        correction_quality: Forwarded to ``spin_glass_solver``.
        is_cyclic: Accepted for interface compatibility. NOTE(review): this
            function does not read it, and the returned weight never includes
            an edge closing the tour - confirm the intended behaviour.
        start_node: If given (and ``G`` has more than two nodes), this node is
            split off on its own at the top level instead of calling the
            solver. NOTE(review): the node may still end up anywhere in the
            returned path - confirm whether it is meant to be pinned first.

    Returns:
        ``(path, weight)``: the list of nodes in visiting order and the sum of
        the weights of the edges between consecutive nodes.

    Raises:
        ValueError: If ``start_node`` is given but is not a node of ``G``.
    """
    nodes = list(G.nodes())
    n_nodes = len(nodes)

    if n_nodes == 0:
        return ([], 0)
    if n_nodes == 1:
        return ([nodes[0]], 0)
    if n_nodes == 2:
        return ([nodes[0], nodes[1]], _edge_weight(G, nodes[0], nodes[1]))

    a = []
    b = []
    if start_node is not None:
        a = [start_node]
        b = nodes
        b.remove(start_node)
    else:
        # Retry until the solver yields a split with two non-empty sides.
        while (len(a) == 0) or (len(b) == 0):
            _, _, bits, _ = spin_glass_solver(
                G, quality=quality, shots=shots, correction_quality=correction_quality
            )
            a = list(bits[0])
            b = list(bits[1])

    # Induced sub-graphs of each side; sets keep the per-edge lookups O(1).
    in_a = set(a)
    in_b = set(b)
    G_a = nx.Graph()
    G_b = nx.Graph()
    G_a.add_nodes_from(a)
    G_b.add_nodes_from(b)
    for u, v, data in G.edges(data=True):
        if (u in in_a) and (v in in_a):
            G_a.add_edge(u, v, weight=data.get("weight", 1.0))
            continue

        if (u in in_b) and (v in in_b):
            G_b.add_edge(u, v, weight=data.get("weight", 1.0))

    # Sub-problems are open paths; pass the solver settings through unchanged.
    path_a, weight_a = tsp_symmetric(
        G_a, quality=quality, shots=shots, correction_quality=correction_quality, is_cyclic=False
    )
    path_b, weight_b = tsp_symmetric(
        G_b, quality=quality, shots=shots, correction_quality=correction_quality, is_cyclic=False
    )
    sol_weight = weight_a + weight_b

    is_single_a = len(path_a) == 1
    is_single_b = len(path_b) == 1

    if is_single_a and is_single_b:
        return (path_a + path_b, sol_weight + _edge_weight(G, path_a[0], path_b[0]))

    if is_single_a:
        best_path, best_weight = _attach_singlet(G, path_a[0], path_b)
    elif is_single_b:
        best_path, best_weight = _attach_singlet(G, path_b[0], path_a)
    else:
        best_path, best_weight = _stitch_paths(G, path_a, path_b)

    return (best_path, sol_weight + best_weight)
