from dataclasses import dataclass

from pytket._tket.circuit import Circuit

from tket.circuit import (
    Tk2Circuit,
    render_circuit_dot,
)
from tket.ops import TketOp


@dataclass
class CustomCost:
    """Two-field cost value used to exercise non-integer circuit cost functions.

    Two costs are combined with ``+``, which adds their fields pairwise and
    returns a new instance.
    """

    # Running total of counted gates.
    gate_count: int
    # Running total of counted Hadamard gates.
    h_count: int

    def __add__(self, other: "CustomCost") -> "CustomCost":
        """Return a new cost holding the field-wise sum of both operands."""
        total_gates = self.gate_count + other.gate_count
        total_h = self.h_count + other.h_count
        return CustomCost(gate_count=total_gates, h_count=total_h)


def test_cost():
    """Check `circuit_cost` with an integer cost and with a custom cost type."""
    # The circuit holds five gates: three CX and two H.
    circ = Tk2Circuit(Circuit(4).CX(0, 1).H(1).CX(1, 2).CX(0, 3).H(0))

    # Integer-valued cost: each CX contributes 1, every other op contributes 0.
    assert circ.circuit_cost(lambda op: int(op == TketOp.CX)) == 3

    # Custom cost: every op counts as one gate, and H ops are tallied separately.
    # The comparison result is converted to `int` so that `h_count` matches its
    # declared field type, mirroring the integer cost function above.
    assert circ.circuit_cost(
        lambda op: CustomCost(1, int(op == TketOp.H))
    ) == CustomCost(5, 2)


def test_hash():
    """Check which gate reorderings are expected to change a circuit's hash.

    All three circuits apply the same three CX gates on four qubits, listed in
    different orders.
    """
    reference = Tk2Circuit(Circuit(4).CX(0, 1).CX(1, 2).CX(0, 3))
    # CX(1, 2) is moved ahead of CX(0, 1); both gates act on qubit 1.
    dependent_swap = Tk2Circuit(Circuit(4).CX(1, 2).CX(0, 1).CX(0, 3))
    # CX(1, 2) and CX(0, 3) trade places; they act on disjoint qubits.
    disjoint_swap = Tk2Circuit(Circuit(4).CX(0, 1).CX(0, 3).CX(1, 2))

    reference_hash = hash(reference)
    dependent_hash = hash(dependent_swap)
    disjoint_hash = hash(disjoint_swap)

    assert reference_hash != dependent_hash
    assert reference_hash == disjoint_hash


def test_conversion():
    """Round-trip a CX-only pytket circuit through `Tk2Circuit` and back."""
    source = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3)
    source_dot = render_circuit_dot(source)

    # Forward conversion: pytket circuit -> Tk2Circuit.
    converted = Tk2Circuit(source)
    converted_dot = render_circuit_dot(converted)

    assert type(converted) is Tk2Circuit
    # Both representations must render to the same dot string.
    assert source_dot == converted_dot

    # Backward conversion: Tk2Circuit -> pytket circuit.
    restored = converted.to_tket1()

    assert restored == source
    assert type(restored) is Circuit


def test_conversion_qsystem():
    """Round-trip a circuit with ZZPhase and PhasedX gates through `Tk2Circuit`.

    Besides the round trip, the rendered dot string is inspected to confirm the
    gates appear under their `tket.qsystem` names rather than as opaque tk1 ops.
    """
    source = Circuit(2).ZZPhase(0.75, 0, 1).PhasedX(0.25, 0.33, 1)
    source_dot = render_circuit_dot(source)

    converted = Tk2Circuit(source)
    converted_dot = render_circuit_dot(converted)

    assert type(converted) is Tk2Circuit
    assert source_dot == converted_dot

    # Check that we didn't use the opaque tk1 op fallback.
    # TODO: There's no way to traverse the circuit nodes?
    assert "TKET1.tk1op" not in source_dot
    for expected_op in ("tket.qsystem.PhasedX", "tket.qsystem.ZZPhase"):
        assert expected_op in source_dot

    # Converting back must reproduce the original pytket circuit.
    restored = converted.to_tket1()

    assert restored == source
    assert type(restored) is Circuit
