from decimal import Decimal

import pytest

import bloqade.analog.ir.scalar as scalar
import bloqade.analog.ir.control.waveform as waveform
from bloqade.analog import var, cast, start, piecewise_linear
from bloqade.analog.ir import (
    Field,
    Pulse,
    Uniform,
    Sequence,
    AnalogCircuit,
    AssignedRunTimeVector,
    rydberg,
    detuning,
)
from bloqade.analog.atom_arrangement import Chain
from bloqade.analog.compiler.analysis.common.assignment_scan import AssignmentScan
from bloqade.analog.compiler.rewrite.common.assign_variables import AssignBloqadeIR


def test_assignment():
    """Check vector assignment of the ``"amp"`` scale on a detuning drive.

    A two-atom chain program whose detuning is scaled by ``"amp"`` is visited
    by ``AssignBloqadeIR`` with a list of two ``Decimal`` values. The result
    must equal a hand-built ``AnalogCircuit`` whose detuning field is keyed
    by ``AssignedRunTimeVector("amp", <values>)``.
    """
    durations = [0.1, 0.5, 0.1]
    values = [1.0, 2.0, 3.0, 4.0]
    site_amplitudes = [Decimal("1.0")] * 2
    chain = Chain(2, lattice_spacing=4.5)

    # Program under test: build it with the fluent API, then run the pass.
    program = chain.rydberg.detuning.scale("amp").piecewise_linear(durations, values)
    assigned = AssignBloqadeIR({"amp": site_amplitudes}).visit(
        program.parse_circuit()
    )

    # Reference IR, assembled from the innermost node outwards.
    modulation = AssignedRunTimeVector("amp", site_amplitudes)
    expected_field = Field(drives={modulation: piecewise_linear(durations, values)})
    expected_pulse = Pulse({detuning: expected_field})
    expected = AnalogCircuit(chain, Sequence({rydberg: expected_pulse}))

    assert assigned == expected


def test_assignment_error():
    """Check that assigning ``"amp"`` a second time raises ``ValueError``.

    The first ``AssignBloqadeIR`` pass is expected to succeed. Visiting its
    output again with the same assignment must raise.
    """
    site_amplitudes = [Decimal("1.0")] * 2
    program = (
        Chain(2, lattice_spacing=4.5)
        .rydberg.detuning.scale("amp")
        .piecewise_linear([0.1, 0.5, 0.1], [1.0, 2.0, 3.0, 4.0])
    )

    # First pass runs outside the ``raises`` block so only the second one
    # is allowed to fail.
    already_assigned = AssignBloqadeIR({"amp": site_amplitudes}).visit(
        program.parse_circuit()
    )

    with pytest.raises(ValueError):
        AssignBloqadeIR({"amp": site_amplitudes}).visit(already_assigned)


def test_scan():
    """Check ``AssignmentScan`` followed by ``AssignBloqadeIR`` on a sequence.

    The sequence is a constant waveform of value ``"max"`` sliced to
    ``[0, t]``, whose end value is recorded as ``"detuning"`` and then used
    as the start of a linear ramp. Only ``max`` and ``t`` are supplied; the
    scan result is fed to the assignment pass. The expected tree uses
    ``AssignedVariable("detuning", 10)`` for the recorded variable.
    """
    stop_time = var("t")
    builder = start.rydberg.detuning.uniform.constant("max", 1.0)
    builder = builder.slice(0, stop_time).record("detuning")
    sequence = builder.linear("detuning", 0, 1.0 - stop_time).parse_sequence()

    # Scan first to complete the parameter set, then substitute into the IR.
    full_assignments = AssignmentScan({"max": 10, "t": 0.1}).scan(sequence)
    completed_circuit = AssignBloqadeIR(full_assignments).visit(sequence)

    # Reference IR built from named pieces, innermost first.
    t_value = scalar.AssignedVariable("t", 0.1)
    sliced_constant = waveform.Slice(
        waveform.Constant(scalar.AssignedVariable("max", 10), cast(1.0)),
        waveform.Interval(cast(0), t_value),
    )
    ramp = waveform.Linear(
        scalar.AssignedVariable("detuning", 10),
        0,
        1 - t_value,
    )
    expected_field = Field(drives={Uniform: waveform.Append([sliced_constant, ramp])})
    target_circuit = Sequence({rydberg: Pulse({detuning: expected_field})})

    # Both reprs are printed so they can be compared by eye on failure.
    print(repr(completed_circuit))
    print(repr(target_circuit))

    assert completed_circuit == target_circuit
