import numpy as np
import time
from mtbsync.match.retrieval import retrieve_coarse_pairs


def test_parallel_timing_fields():
    """Test that retrieve_coarse_pairs returns timing information.

    Runs a small sequential retrieval over synthetic descriptors and checks:
    the timing dict contains every expected field, every value is
    non-negative, the reported total is consistent with the wall-clock time
    measured around the call, and a non-empty result is returned.
    """
    n_ref = 3
    n_new = 2

    # Seeded generator so the inputs (and therefore any failure) are
    # identical on every run.
    rng = np.random.default_rng(0)

    # Simple descriptors (ORB descriptors are Nx32 uint8). Every frame gets
    # the same descriptors, so retrieval should find matches.
    base_desc = rng.integers(0, 256, size=(10, 32), dtype=np.uint8)
    ref_desc_list = [base_desc.copy() for _ in range(n_ref)]
    new_desc_list = [base_desc.copy() for _ in range(n_new)]

    # Derived from the frame counts so they always stay in sync with the
    # descriptor lists: ref -> 0.0, 1.0, 2.0; new -> 0.5, 1.5.
    ref_timestamps = np.arange(n_ref, dtype=np.float64)
    new_timestamps = np.arange(n_new, dtype=np.float64) + 0.5

    # perf_counter is monotonic (unlike time.time), so it is the correct
    # clock for measuring an elapsed interval.
    t0 = time.perf_counter()
    df, timings = retrieve_coarse_pairs(
        new_timestamps,
        new_desc_list,
        ref_timestamps,
        ref_desc_list,
        top_k=2,
        warp_enable=False,  # Disable warp to keep test simple
        threads=1,  # Sequential for reproducibility
    )
    elapsed = time.perf_counter() - t0

    # Verify timing dict structure and that all timings are non-negative.
    for key in ("retrieval_sec", "warp_sec", "markers_sec", "total_sec"):
        assert key in timings, f"missing timing field: {key}"
        assert timings[key] >= 0.0, f"negative timing for {key}: {timings[key]}"

    # Verify the reported total is reasonable: time measured inside the call
    # should not exceed the wall-clock time measured around it.
    # NOTE(review): assumes total_sec is wall time for this single call.
    # The slack absorbs coarse clock resolution (time.time() ticks at ~16 ms
    # on some platforms).
    assert timings["total_sec"] <= elapsed + 0.1, (
        f"total_sec={timings['total_sec']} exceeds measured elapsed={elapsed}"
    )

    # Verify DataFrame was returned
    assert df is not None
    assert len(df) > 0


def test_parallel_vs_sequential():
    """Test that the threading parameter works and does not change results.

    Runs the same retrieval with threads=1 and threads=2 on identical,
    seeded inputs. Both runs must succeed, report every timing field, and
    produce result tables with the same columns and number of rows.
    """
    n_ref = 4
    n_new = 4

    # Seeded generator so both runs (and reruns of the test) see identical
    # inputs.
    rng = np.random.default_rng(0)

    # Every frame shares the same descriptors, so retrieval should find
    # matches.
    base_desc = rng.integers(0, 256, size=(15, 32), dtype=np.uint8)
    ref_desc_list = [base_desc.copy() for _ in range(n_ref)]
    new_desc_list = [base_desc.copy() for _ in range(n_new)]

    ref_timestamps = np.linspace(0, 3, n_ref)
    new_timestamps = np.linspace(0.5, 3.5, n_new)

    # Identical arguments for both runs; only the thread count differs.
    common_kwargs = dict(top_k=2, warp_enable=False)

    # Sequential run
    df_seq, timings_seq = retrieve_coarse_pairs(
        new_timestamps,
        new_desc_list,
        ref_timestamps,
        ref_desc_list,
        threads=1,
        **common_kwargs,
    )

    # Parallel run
    df_par, timings_par = retrieve_coarse_pairs(
        new_timestamps,
        new_desc_list,
        ref_timestamps,
        ref_desc_list,
        threads=2,
        **common_kwargs,
    )

    # Both should return valid results
    assert len(df_seq) > 0
    assert len(df_par) > 0

    # Parallelism must not change what is retrieved, only how fast. Row
    # order may legitimately differ between runs, so compare structure and
    # size rather than exact row order.
    assert list(df_seq.columns) == list(df_par.columns)
    assert len(df_seq) == len(df_par)

    # Both should carry complete, non-negative timing info.
    for timings in (timings_seq, timings_par):
        for key in ("retrieval_sec", "warp_sec", "markers_sec", "total_sec"):
            assert key in timings, f"missing timing field: {key}"
            assert timings[key] >= 0.0, f"negative timing for {key}: {timings[key]}"
