"""Tests for coarse retrieval matching."""

import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
import pytest

from mtbsync.match.retrieval import (
    load_reference_index,
    match_frame_to_reference,
    retrieve_coarse_pairs,
    save_pairs_csv,
)


class TestLoadReferenceIndex:
    """Exercise load_reference_index with valid, missing and malformed files."""

    def test_load_valid_index(self):
        """A well-formed three-frame index loads with the expected lengths."""
        # Per-frame payloads: two frames with one keypoint/descriptor each,
        # followed by one frame that has none.
        expected_times = np.array([0.0, 1.0, 2.0])
        frame_keypoints = [
            np.array([[10.0, 20.0, 5.0, 45.0, 0.8]], dtype=np.float32),
            np.array([[15.0, 25.0, 5.5, 60.0, 0.7]], dtype=np.float32),
            np.zeros((0, 5), dtype=np.float32),
        ]
        frame_descriptors = [
            np.random.randint(0, 256, (1, 32), dtype=np.uint8),
            np.random.randint(0, 256, (1, 32), dtype=np.uint8),
            np.zeros((0, 32), dtype=np.uint8),
        ]

        with tempfile.TemporaryDirectory() as workdir:
            archive = str(Path(workdir) / "test_index.npz")

            # Per-frame arrays are wrapped in object arrays before saving.
            np.savez_compressed(
                archive,
                t_ref=expected_times,
                kpts=np.array(frame_keypoints, dtype=object),
                desc=np.array(frame_descriptors, dtype=object),
                meta="{}",
            )

            times, keypoints, descriptors = load_reference_index(archive)

            # One entry per frame in each of the three returned sequences.
            for loaded in (times, keypoints, descriptors):
                assert len(loaded) == 3
            assert np.array_equal(times, expected_times)

    def test_load_nonexistent_file(self):
        """A path that does not exist is reported as a RuntimeError."""
        expected_error = pytest.raises(
            RuntimeError, match="Failed to load reference index"
        )
        with expected_error:
            load_reference_index("nonexistent.npz")

    def test_load_invalid_format(self):
        """An archive lacking the index keys is rejected with a RuntimeError."""
        with tempfile.TemporaryDirectory() as workdir:
            bogus_archive = str(Path(workdir) / "invalid.npz")

            # The archive holds a single unrelated "data" entry only.
            np.savez_compressed(bogus_archive, data=np.array([1, 2, 3]))

            expected_error = pytest.raises(
                RuntimeError, match="Invalid index file format"
            )
            with expected_error:
                load_reference_index(bogus_archive)


class TestMatchFrameToReference:
    """Exercise match_frame_to_reference on empty, identical and noisy input."""

    def test_empty_new_descriptors(self):
        """An empty query descriptor array yields an empty result."""
        reference_set = [np.random.randint(0, 256, (10, 32), dtype=np.uint8)]
        reference_times = np.array([0.0])
        empty_query = np.zeros((0, 32), dtype=np.uint8)

        matches = match_frame_to_reference(
            empty_query, reference_set, reference_times
        )

        assert len(matches) == 0

    def test_empty_reference_descriptors(self):
        """A reference frame without descriptors yields an empty result."""
        query = np.random.randint(0, 256, (10, 32), dtype=np.uint8)
        reference_set = [np.zeros((0, 32), dtype=np.uint8)]
        reference_times = np.array([0.0])

        matches = match_frame_to_reference(query, reference_set, reference_times)

        assert len(matches) == 0

    def test_matching_returns_top_k(self):
        """With five identical reference frames, top_k=3 returns three entries."""
        template = np.random.randint(0, 256, (10, 32), dtype=np.uint8)

        # The query and every reference frame are copies of the same array.
        query = template.copy()
        reference_set = [template.copy() for _ in range(5)]
        reference_times = np.array([0.0, 1.0, 2.0, 3.0, 4.0])

        matches = match_frame_to_reference(
            query, reference_set, reference_times, top_k=3
        )

        assert len(matches) == 3

        # Every entry unpacks to (t_ref, score, n_matches) with numeric types
        # and strictly positive score and match count.
        for entry in matches:
            t_ref, score, n_matches = entry
            assert isinstance(t_ref, (float, np.floating))
            assert isinstance(score, (float, np.floating))
            assert isinstance(n_matches, (int, np.integer))
            assert score > 0
            assert n_matches > 0

    def test_similar_descriptors_match(self):
        """A reference differing in a single byte still produces matches."""
        template = np.random.randint(0, 256, (10, 32), dtype=np.uint8)
        query = template.copy()

        # Perturb exactly one byte of the reference copy.
        perturbed = template.copy()
        perturbed[0, 0] = (perturbed[0, 0] + 1) % 256

        matches = match_frame_to_reference(query, [perturbed], np.array([0.0]))

        assert len(matches) > 0


class TestRetrieveCoarsePairs:
    """Exercise retrieve_coarse_pairs: output shape, ordering and error paths."""

    def test_basic_retrieval(self):
        """Identical descriptors everywhere give a non-empty, well-formed frame."""
        template = np.random.randint(0, 256, (10, 32), dtype=np.uint8)

        # Two new frames and three reference frames, all copies of one array.
        new_timestamps = np.array([0.0, 1.0])
        new_frames = [template.copy(), template.copy()]
        ref_timestamps = np.array([0.0, 1.0, 2.0])
        ref_frames = [template.copy(), template.copy(), template.copy()]

        pairs, timings = retrieve_coarse_pairs(
            new_timestamps, new_frames, ref_timestamps, ref_frames, top_k=2
        )

        # The call returns a (DataFrame, dict) pair.
        assert isinstance(pairs, pd.DataFrame)
        assert isinstance(timings, dict)

        for column in ("t_new", "t_ref", "score", "n_matches"):
            assert column in pairs.columns

        assert len(pairs) > 0

        # Every reported t_new must be one of the input timestamps.
        for t in pairs["t_new"].unique():
            assert t in new_timestamps

    def test_monotonic_t_new(self):
        """The t_new column of the output is non-decreasing."""
        template = np.random.randint(0, 256, (10, 32), dtype=np.uint8)

        new_timestamps = np.array([0.0, 1.0, 2.0])
        new_frames = [template.copy() for _ in range(3)]
        ref_timestamps = np.array([0.0, 1.0, 2.0, 3.0])
        ref_frames = [template.copy() for _ in range(4)]

        pairs, _timings = retrieve_coarse_pairs(
            new_timestamps, new_frames, ref_timestamps, ref_frames
        )

        # Compare each value with its successor.
        ordered = pairs["t_new"].values
        assert all(
            earlier <= later for earlier, later in zip(ordered[:-1], ordered[1:])
        )

    def test_skip_empty_descriptors(self):
        """A new frame with no descriptors contributes no rows to the output."""
        template = np.random.randint(0, 256, (10, 32), dtype=np.uint8)

        new_timestamps = np.array([0.0, 1.0])
        # The second new frame (t=1.0) carries zero descriptors.
        new_frames = [template.copy(), np.zeros((0, 32), dtype=np.uint8)]
        ref_timestamps = np.array([0.0, 1.0])
        ref_frames = [template.copy(), template.copy()]

        pairs, _timings = retrieve_coarse_pairs(
            new_timestamps, new_frames, ref_timestamps, ref_frames
        )

        # Only the first new frame (t=0.0) may appear.
        seen_new_times = pairs["t_new"].unique()
        assert len(seen_new_times) == 1
        assert seen_new_times[0] == 0.0

    def test_no_matches_raises_error(self):
        """If the only new frame has no descriptors, a RuntimeError is raised."""
        new_timestamps = np.array([0.0])
        new_frames = [np.zeros((0, 32), dtype=np.uint8)]
        ref_timestamps = np.array([0.0])
        ref_frames = [np.random.randint(0, 256, (5, 32), dtype=np.uint8)]

        expected_error = pytest.raises(RuntimeError, match="No valid matches found")
        with expected_error:
            retrieve_coarse_pairs(
                new_timestamps, new_frames, ref_timestamps, ref_frames
            )

    def test_length_mismatch_raises_error(self):
        """Two timestamps with one descriptor array raise a ValueError."""
        new_timestamps = np.array([0.0, 1.0])
        # Deliberately one entry short of new_timestamps.
        new_frames = [np.random.randint(0, 256, (5, 32), dtype=np.uint8)]
        ref_timestamps = np.array([0.0])
        ref_frames = [np.random.randint(0, 256, (5, 32), dtype=np.uint8)]

        expected_error = pytest.raises(ValueError, match="length mismatch")
        with expected_error:
            retrieve_coarse_pairs(
                new_timestamps, new_frames, ref_timestamps, ref_frames
            )


class TestSavePairsCsv:
    """Exercise save_pairs_csv through a write-then-read round trip."""

    def test_save_and_load_csv(self):
        """A saved pairs table reads back with the same columns and values."""
        pairs = pd.DataFrame(
            {
                "t_new": [0.0, 0.0, 1.0],
                "t_ref": [0.0, 1.0, 1.0],
                "score": [0.9, 0.8, 0.85],
                "n_matches": [10, 8, 9],
            }
        )

        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir) / "pairs.csv"

            save_pairs_csv(pairs, str(target))
            restored = pd.read_csv(target)

            assert len(restored) == 3
            for column in ("t_new", "t_ref", "score", "n_matches"):
                assert column in restored.columns

            # Float columns are compared approximately (score with an explicit
            # relative tolerance); the integer column must match exactly.
            float_checks = (
                ("t_new", {}),
                ("t_ref", {}),
                ("score", {"rtol": 1e-5}),
            )
            for column, tolerance in float_checks:
                assert np.allclose(restored[column], pairs[column], **tolerance)
            assert all(restored["n_matches"] == pairs["n_matches"])
