"""Tests for ORB descriptor extraction."""

import json
import tempfile
from pathlib import Path

import numpy as np
import pytest

from mtbsync.features.descriptors import compute_orb_descriptors, save_reference_index


def _make_checkerboard(height, width, square_size, channels=3):
    """Return a uint8 black/white checkerboard image.

    A square is white (255) when the sum of its row and column square
    indices is even, so the top-left square is white. Squares on the right
    or bottom edge are truncated when the image size is not a multiple of
    ``square_size``.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        square_size: Side length of each square in pixels.
        channels: Number of colour channels, or ``None`` for a 2-D
            (grayscale) image.

    Returns:
        uint8 array of shape ``(height, width, channels)``, or
        ``(height, width)`` when ``channels`` is ``None``.
    """
    rows = np.arange(height)[:, None] // square_size
    cols = np.arange(width)[None, :] // square_size
    board = np.where((rows + cols) % 2 == 0, 255, 0).astype(np.uint8)
    if channels is None:
        return board
    # np.repeat returns a fresh C-contiguous array, which OpenCV accepts directly.
    return np.repeat(board[:, :, None], channels, axis=2)


class TestComputeOrbDescriptors:
    """Tests for compute_orb_descriptors function."""

    def test_empty_frames_list(self):
        """An empty frame list is rejected with ValueError."""
        with pytest.raises(ValueError, match="Frames list is empty"):
            compute_orb_descriptors([])

    def test_invalid_frame(self):
        """A None entry is rejected, and the error message names its index."""
        frames = [np.zeros((100, 100, 3), dtype=np.uint8), None]
        with pytest.raises(ValueError, match="Frame 1 is invalid"):
            compute_orb_descriptors(frames)

    def test_output_shapes(self):
        """Keypoint and descriptor arrays have the documented shapes/dtypes."""
        frame = _make_checkerboard(480, 640, 40)

        kpts_list, desc_list = compute_orb_descriptors([frame], n_features=500)

        assert len(kpts_list) == 1
        assert len(desc_list) == 1

        kpts, desc = kpts_list[0], desc_list[0]

        # A checkerboard is corner-rich. Require detections explicitly so the
        # shape checks below can never pass vacuously on an empty result.
        assert len(kpts) > 0
        assert kpts.ndim == 2 and kpts.shape[1] == 5  # [x, y, size, angle, response]
        assert desc.ndim == 2 and desc.shape[1] == 32  # ORB descriptor size
        assert len(kpts) == len(desc)  # one descriptor row per keypoint

        assert kpts.dtype == np.float32
        assert desc.dtype == np.uint8

    def test_blank_vs_textured(self):
        """A textured image yields more keypoints than a blank one."""
        blank = np.zeros((480, 640, 3), dtype=np.uint8)
        textured = _make_checkerboard(480, 640, 40)

        kpts_list, _ = compute_orb_descriptors([blank, textured], n_features=1000)

        blank_kpts = len(kpts_list[0])
        textured_kpts = len(kpts_list[1])

        assert textured_kpts > blank_kpts
        assert textured_kpts > 50  # many checkerboard corners should be found

    def test_empty_keypoints_handling(self):
        """Frames with no keypoints return empty arrays with the right shape."""
        # A completely blank image has no corners to detect.
        blank = np.zeros((100, 100, 3), dtype=np.uint8)

        kpts_list, desc_list = compute_orb_descriptors([blank], n_features=500)

        kpts, desc = kpts_list[0], desc_list[0]

        # Empty along axis 0, but the column count and dtype are preserved.
        assert kpts.shape == (0, 5)
        assert desc.shape == (0, 32)
        assert kpts.dtype == np.float32
        assert desc.dtype == np.uint8

    @pytest.mark.parametrize("use_clahe", [True, False])
    def test_clahe_option(self, use_clahe):
        """Both CLAHE settings produce well-formed keypoint/descriptor arrays."""
        # Low-contrast image: mid-grey background with two faint patches.
        frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
        frame[100:150, 100:150] = 140
        frame[200:250, 200:250] = 115

        kpts_list, desc_list = compute_orb_descriptors(
            [frame], n_features=500, use_clahe=use_clahe
        )

        assert len(kpts_list) == 1
        assert len(desc_list) == 1

        kpts, desc = kpts_list[0], desc_list[0]
        # The keypoint count may legitimately be zero here, so check structure
        # (which also holds for the (0, 5) / (0, 32) empty case).
        assert kpts.shape[1] == 5
        assert desc.shape[1] == 32
        assert len(kpts) == len(desc)
        assert kpts.dtype == np.float32
        assert desc.dtype == np.uint8

    def test_grayscale_input(self):
        """Single-channel (2-D) images are accepted."""
        gray = _make_checkerboard(480, 640, 40, channels=None)

        kpts_list, desc_list = compute_orb_descriptors([gray], n_features=500)

        assert len(kpts_list) == 1
        assert len(desc_list) == 1
        assert len(kpts_list[0]) > 0  # the checkerboard has corners to find
        assert len(kpts_list[0]) == len(desc_list[0])

    def test_multiple_frames(self):
        """Several frames are processed independently, one result per frame."""
        frames = [_make_checkerboard(240, 320, size) for size in (20, 30, 40)]

        kpts_list, desc_list = compute_orb_descriptors(frames, n_features=500)

        assert len(kpts_list) == 3
        assert len(desc_list) == 3

        for kpts, desc in zip(kpts_list, desc_list):
            assert len(kpts) > 0
            assert len(desc) > 0
            assert len(kpts) == len(desc)


class TestSaveReferenceIndex:
    """Tests for save_reference_index function."""

    # Metadata keys that save_reference_index is expected to add on its own.
    AUTO_META_KEYS = ("version", "created_utc", "opencv_version")

    def test_inconsistent_lengths(self, tmp_path):
        """Mismatched t_ref / kpts / desc lengths raise ValueError."""
        t_ref = np.array([0.0, 1.0])  # two timestamps...
        kpts_list = [np.zeros((10, 5), dtype=np.float32)]  # ...but only one frame
        desc_list = [np.zeros((10, 32), dtype=np.uint8)]
        meta = {"test": "value"}

        # Write into a temp dir so a regression (no raise) cannot litter the CWD.
        out_path = tmp_path / "dummy.npz"
        with pytest.raises(ValueError, match="Inconsistent lengths"):
            save_reference_index(str(out_path), t_ref, kpts_list, desc_list, meta)

    def test_round_trip_save_load(self, tmp_path):
        """Saved timestamps, keypoints, descriptors and metadata reload intact."""
        out_path = tmp_path / "test_index.npz"
        # Seeded generator keeps the descriptor payload reproducible across runs.
        rng = np.random.default_rng(0)

        t_ref = np.array([0.0, 0.5, 1.0], dtype=np.float64)

        kpts_list = [
            np.array(
                [[10.0, 20.0, 5.0, 45.0, 0.8], [30.0, 40.0, 6.0, 90.0, 0.9]],
                dtype=np.float32,
            ),
            np.array([[15.0, 25.0, 5.5, 60.0, 0.7]], dtype=np.float32),
            np.zeros((0, 5), dtype=np.float32),  # frame with no keypoints
        ]

        desc_list = [
            rng.integers(0, 256, (2, 32), dtype=np.uint8),
            rng.integers(0, 256, (1, 32), dtype=np.uint8),
            np.zeros((0, 32), dtype=np.uint8),  # frame with no descriptors
        ]

        meta = {
            "source_video": "test.mp4",
            "fps": 3.0,
            "image_max_dim": 960,
            "orb_n_features": 1500,
            "use_clahe": True,
        }

        save_reference_index(str(out_path), t_ref, kpts_list, desc_list, meta)
        assert out_path.exists()

        # Close the NpzFile deterministically; a leaked handle breaks temp-dir
        # cleanup on Windows. Member arrays are fully read into memory here.
        with np.load(str(out_path), allow_pickle=True) as data:
            loaded_t_ref = data["t_ref"]
            loaded_kpts = data["kpts"]
            loaded_desc = data["desc"]
            loaded_meta = json.loads(str(data["meta"]))

        assert np.array_equal(loaded_t_ref, t_ref)

        assert len(loaded_kpts) == len(kpts_list)
        for i, (orig, loaded) in enumerate(zip(kpts_list, loaded_kpts)):
            assert np.array_equal(orig, loaded), f"Keypoints mismatch at index {i}"

        assert len(loaded_desc) == len(desc_list)
        for i, (orig, loaded) in enumerate(zip(desc_list, loaded_desc)):
            assert np.array_equal(orig, loaded), f"Descriptors mismatch at index {i}"

        # Every user-supplied field survives the round trip unchanged.
        for key, value in meta.items():
            assert loaded_meta[key] == value, key

        for key in self.AUTO_META_KEYS:
            assert key in loaded_meta, key

    def test_metadata_auto_fields(self, tmp_path):
        """Required metadata fields are added alongside user fields."""
        out_path = tmp_path / "test_index.npz"

        t_ref = np.array([0.0], dtype=np.float64)
        kpts_list = [np.zeros((0, 5), dtype=np.float32)]
        desc_list = [np.zeros((0, 32), dtype=np.uint8)]
        meta = {"custom_field": "value"}

        save_reference_index(str(out_path), t_ref, kpts_list, desc_list, meta)

        with np.load(str(out_path), allow_pickle=True) as data:
            loaded_meta = json.loads(str(data["meta"]))

        for key in self.AUTO_META_KEYS:
            assert key in loaded_meta, key
        assert loaded_meta.get("custom_field") == "value"

    def test_empty_keypoints_save_load(self, tmp_path):
        """An index where every frame has zero keypoints round-trips shapes."""
        out_path = tmp_path / "test_empty.npz"

        t_ref = np.array([0.0, 1.0], dtype=np.float64)
        kpts_list = [np.zeros((0, 5), dtype=np.float32) for _ in range(2)]
        desc_list = [np.zeros((0, 32), dtype=np.uint8) for _ in range(2)]
        meta = {"test": "empty"}

        save_reference_index(str(out_path), t_ref, kpts_list, desc_list, meta)

        with np.load(str(out_path), allow_pickle=True) as data:
            loaded_kpts = data["kpts"]
            loaded_desc = data["desc"]

        assert len(loaded_kpts) == 2
        assert len(loaded_desc) == 2

        for kpts in loaded_kpts:
            assert kpts.shape == (0, 5)

        for desc in loaded_desc:
            assert desc.shape == (0, 32)
