"""Tests for local refinement with RANSAC."""

import tempfile
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import pytest

from mtbsync.match.local import (
    compute_affine_ransac,
    compute_homography_ransac,
    match_local_features,
    refine_pairs_locally,
    refine_single_candidate,
    save_refined_pairs,
    sigmoid,
)


class TestSigmoid:
    """Unit tests for the ``sigmoid`` squashing function."""

    def test_sigmoid_values(self):
        """Check the midpoint, both saturated tails, and an interior point."""
        midpoint = sigmoid(0.0)
        high_tail = sigmoid(10.0)
        low_tail = sigmoid(-10.0)
        interior = sigmoid(1.0)

        # sigmoid(0) sits at the centre of the output range.
        assert midpoint == pytest.approx(0.5, abs=0.01)
        # Large-magnitude inputs saturate towards the asymptotes.
        assert high_tail > 0.99
        assert low_tail < 0.01
        # A moderate input stays strictly inside the open interval (0, 1).
        assert 0.0 < interior < 1.0


class TestMatchLocalFeatures:
    """Tests for match_local_features function.

    Keypoint arrays use one row per keypoint with 5 columns, and descriptors
    are 32-byte uint8 rows, mirroring how the other tests in this module pack
    ORB output.
    """

    def test_empty_descriptors(self):
        """An empty query set must yield no correspondences."""
        # Seeded generator so any failure is reproducible run-to-run.
        rng = np.random.default_rng(0)
        kpts_new = np.zeros((0, 5), dtype=np.float32)
        desc_new = np.zeros((0, 32), dtype=np.uint8)
        kpts_ref = rng.random((10, 5), dtype=np.float32)
        desc_ref = rng.integers(0, 256, (10, 32), dtype=np.uint8)

        pts_new, pts_ref = match_local_features(kpts_new, desc_new, kpts_ref, desc_ref)

        assert len(pts_new) == 0
        assert len(pts_ref) == 0

    def test_identical_descriptors(self):
        """Matching a feature set against itself must produce paired matches."""
        rng = np.random.default_rng(1)
        kpts = rng.random((20, 5), dtype=np.float32)
        desc = rng.integers(0, 256, (20, 32), dtype=np.uint8)

        pts_new, pts_ref = match_local_features(kpts, desc, kpts, desc)

        # Every descriptor has an exact twin, so matches must be found and
        # the two point lists must stay in lock-step.
        assert len(pts_new) > 0
        assert len(pts_new) == len(pts_ref)


class TestComputeHomographyRansac:
    """Tests for compute_homography_ransac function."""

    def test_insufficient_points(self):
        """Fewer than 4 correspondences cannot define a homography."""
        # Seeded generator so any failure is reproducible run-to-run.
        rng = np.random.default_rng(0)
        pts_new = rng.random((3, 2), dtype=np.float32)
        pts_ref = rng.random((3, 2), dtype=np.float32)

        H, mask, error = compute_homography_ransac(pts_new, pts_ref)

        assert H is None
        assert np.sum(mask) == 0
        assert error == np.inf

    def test_perfect_homography(self):
        """Points generated from a known homography should be almost all inliers."""
        rng = np.random.default_rng(42)
        n_points = 50

        # Known homography: 15-degree rotation plus a (50, 30) translation.
        angle = np.radians(15)
        H_true = np.array(
            [
                [np.cos(angle), -np.sin(angle), 50],
                [np.sin(angle), np.cos(angle), 30],
                [0, 0, 1],
            ],
            dtype=np.float32,
        )

        pts_new = rng.random((n_points, 2), dtype=np.float32) * 500

        # Map through H in homogeneous coordinates, then dehomogenise.
        pts_new_homog = np.hstack([pts_new, np.ones((n_points, 1))])
        pts_ref_homog = (H_true @ pts_new_homog.T).T
        pts_ref = pts_ref_homog[:, :2] / pts_ref_homog[:, 2:3]

        H, mask, error = compute_homography_ransac(pts_new, pts_ref, ransac_thresh=1.0)

        assert H is not None
        assert np.sum(mask) >= 40  # Most should be inliers
        assert error < 0.5  # Very low error for noise-free correspondences


class TestComputeAffineRansac:
    """Tests for compute_affine_ransac function."""

    def test_insufficient_points(self):
        """Fewer than 3 correspondences cannot define an affine transform."""
        # Seeded generator so any failure is reproducible run-to-run.
        rng = np.random.default_rng(0)
        pts_new = rng.random((2, 2), dtype=np.float32)
        pts_ref = rng.random((2, 2), dtype=np.float32)

        A, mask, error = compute_affine_ransac(pts_new, pts_ref)

        assert A is None
        assert np.sum(mask) == 0
        assert error == np.inf

    def test_perfect_affine(self):
        """Points generated from a known affine map should be almost all inliers."""
        rng = np.random.default_rng(42)
        n_points = 40

        # Known affine: 20-degree rotation, 1.2x scale, (40, 25) translation.
        angle = np.radians(20)
        scale = 1.2
        A_true = np.array(
            [
                [scale * np.cos(angle), -scale * np.sin(angle), 40],
                [scale * np.sin(angle), scale * np.cos(angle), 25],
            ],
            dtype=np.float32,
        )

        pts_new = rng.random((n_points, 2), dtype=np.float32) * 400

        # A 2x3 affine applied to homogeneous points yields Cartesian points directly.
        pts_new_homog = np.hstack([pts_new, np.ones((n_points, 1))])
        pts_ref = (A_true @ pts_new_homog.T).T

        A, mask, error = compute_affine_ransac(pts_new, pts_ref, ransac_thresh=1.0)

        assert A is not None
        assert np.sum(mask) >= 30  # Most should be inliers
        assert error < 0.5  # Very low error for noise-free correspondences


class TestRefineSingleCandidate:
    """Tests for refine_single_candidate function."""

    @staticmethod
    def _checkerboard(height, width, square):
        """Return a uint8 checkerboard (0/255) whose top-left square is white."""
        row_cell = np.arange(height)[:, None] // square
        col_cell = np.arange(width)[None, :] // square
        return np.where((row_cell + col_cell) % 2 == 0, 255, 0).astype(np.uint8)

    @staticmethod
    def _pack_keypoints(keypoints):
        """Pack cv2.KeyPoint objects into float32 rows of (x, y, size, angle, response)."""
        return np.array(
            [[kp.pt[0], kp.pt[1], kp.size, kp.angle, kp.response] for kp in keypoints],
            dtype=np.float32,
        )

    def test_synthetic_homography_recovery(self):
        """A warped checkerboard should be re-registered with high confidence."""
        source = self._checkerboard(480, 640, 40)

        # Ground-truth warp: 10-degree rotation plus a (30, 20) translation.
        theta = np.radians(10)
        warp = np.array(
            [
                [np.cos(theta), -np.sin(theta), 30],
                [np.sin(theta), np.cos(theta), 20],
                [0, 0, 1],
            ],
            dtype=np.float32,
        )
        target = cv2.warpPerspective(source, warp, (640, 480))

        detector = cv2.ORB_create(nfeatures=500)
        src_kp, src_desc = detector.detectAndCompute(source, None)
        dst_kp, dst_desc = detector.detectAndCompute(target, None)

        confidence, n_inliers, reproj_error, model = refine_single_candidate(
            self._pack_keypoints(src_kp),
            src_desc,
            self._pack_keypoints(dst_kp),
            dst_desc,
            min_inliers=10,
        )

        # A valid model must come back with enough support and low error.
        assert model in ["H", "A"]
        assert n_inliers >= 10
        assert reproj_error < 5.0
        assert confidence > 0.5  # High confidence

    def test_no_features(self):
        """Blank inputs (no keypoints at all) must report the 'NONE' sentinel."""

        def blank_frame():
            # Zero keypoints with the usual 5-column / 32-byte layouts.
            return np.zeros((0, 5), dtype=np.float32), np.zeros((0, 32), dtype=np.uint8)

        kpts_new, desc_new = blank_frame()
        kpts_ref, desc_ref = blank_frame()

        confidence, n_inliers, reproj_error, model = refine_single_candidate(
            kpts_new, desc_new, kpts_ref, desc_ref
        )

        assert confidence == 0.0
        assert n_inliers == 0
        assert reproj_error == np.inf
        assert model == "NONE"


class TestRefinePairsLocally:
    """Tests for refine_pairs_locally function."""

    @staticmethod
    def _make_checkerboard(height, width, square):
        """Return a uint8 checkerboard (0/255) whose top-left square is white."""
        frame = np.zeros((height, width), dtype=np.uint8)
        for i in range(0, height, square):
            for j in range(0, width, square):
                if (i // square + j // square) % 2 == 0:
                    frame[i : i + square, j : j + square] = 255
        return frame

    @staticmethod
    def _keypoints_to_array(keypoints):
        """Pack cv2.KeyPoint objects into float32 rows of (x, y, size, angle, response)."""
        return np.array(
            [[kp.pt[0], kp.pt[1], kp.size, kp.angle, kp.response] for kp in keypoints],
            dtype=np.float32,
        )

    @classmethod
    def _detect_pair(cls, frame_a, frame_b, nfeatures):
        """Run one ORB detector on both frames; return (kpts_a, desc_a, kpts_b, desc_b)."""
        orb = cv2.ORB_create(nfeatures=nfeatures)
        kp_a, desc_a = orb.detectAndCompute(frame_a, None)
        kp_b, desc_b = orb.detectAndCompute(frame_b, None)
        return cls._keypoints_to_array(kp_a), desc_a, cls._keypoints_to_array(kp_b), desc_b

    @staticmethod
    def _empty_features():
        """Return (keypoints, descriptors) arrays with zero rows."""
        return np.zeros((0, 5), dtype=np.float32), np.zeros((0, 32), dtype=np.uint8)

    def test_basic_refinement(self):
        """Test basic refinement with synthetic data."""
        # Second frame is the first rotated by 5 degrees about the centre.
        frame1 = self._make_checkerboard(240, 320, 30)
        M = cv2.getRotationMatrix2D((160, 120), 5, 1.0)
        frame2 = cv2.warpAffine(frame1, M, (320, 240))

        kpts1, desc1, kpts2, desc2 = self._detect_pair(frame1, frame2, nfeatures=300)

        # Coarse candidate pairs to be refined.
        pairs_df = pd.DataFrame(
            {
                "t_new": [0.0, 0.0],
                "t_ref": [0.0, 0.333],
                "score": [0.8, 0.7],
                "n_matches": [50, 40],
            }
        )

        new_timestamps = np.array([0.0])
        new_kpts_list = [kpts1]
        new_desc_list = [desc1]

        ref_timestamps = np.array([0.0, 0.333, 0.666])
        ref_kpts_list = [kpts2, kpts2, kpts2]
        ref_desc_list = [desc2, desc2, desc2]

        # Refine with lenient parameters for synthetic data
        refined_df = refine_pairs_locally(
            pairs_df,
            new_timestamps,
            new_kpts_list,
            new_desc_list,
            ref_timestamps,
            ref_kpts_list,
            ref_desc_list,
            fps=3.0,
            min_inliers=5,  # Lower threshold for synthetic checkerboard
            ransac_thresh=5.0,  # More permissive for rotation artifacts
        )

        assert isinstance(refined_df, pd.DataFrame)
        assert len(refined_df) > 0

        expected_columns = {"t_new", "t_ref", "confidence", "inliers", "reproj_error", "model"}
        missing = expected_columns - set(refined_df.columns)
        assert not missing, f"missing output columns: {sorted(missing)}"

    def test_monotonic_t_new(self):
        """Test that output is sorted by monotonic t_new."""
        # Second frame is the first rotated by 3 degrees about the centre.
        frame1 = self._make_checkerboard(200, 200, 25)
        M = cv2.getRotationMatrix2D((100, 100), 3, 1.0)
        frame2 = cv2.warpAffine(frame1, M, (200, 200))

        kpts1, desc1, kpts2, desc2 = self._detect_pair(frame1, frame2, nfeatures=200)

        # Create coarse pairs (intentionally out of order)
        pairs_df = pd.DataFrame(
            {
                "t_new": [1.0, 1.0, 0.0, 0.0, 2.0],
                "t_ref": [0.0, 1.0, 0.0, 1.0, 1.0],
                "score": [0.8, 0.7, 0.9, 0.6, 0.75],
                "n_matches": [30, 25, 35, 20, 28],
            }
        )

        new_timestamps = np.array([0.0, 1.0, 2.0])
        new_kpts_list = [kpts1.copy() for _ in range(3)]
        new_desc_list = [desc1.copy() for _ in range(3)]

        ref_timestamps = np.array([0.0, 1.0, 2.0])
        ref_kpts_list = [kpts2.copy() for _ in range(3)]
        ref_desc_list = [desc2.copy() for _ in range(3)]

        refined_df = refine_pairs_locally(
            pairs_df,
            new_timestamps,
            new_kpts_list,
            new_desc_list,
            ref_timestamps,
            ref_kpts_list,
            ref_desc_list,
            fps=1.0,
            min_inliers=5,
            ransac_thresh=5.0,  # More permissive threshold
        )

        t_new_values = refined_df["t_new"].to_numpy()
        # Guard against a vacuous pass: an empty result is trivially "sorted".
        assert len(t_new_values) > 0
        assert np.all(np.diff(t_new_values) >= 0)

    def test_empty_pairs_raises_error(self):
        """Test that empty pairs DataFrame raises error."""
        pairs_df = pd.DataFrame()
        kpts, desc = self._empty_features()

        with pytest.raises(ValueError, match="Input pairs_df is empty"):
            refine_pairs_locally(
                pairs_df,
                np.array([0.0]),
                [kpts],
                [desc],
                np.array([0.0]),
                [kpts.copy()],
                [desc.copy()],
            )

    def test_no_valid_refinements_raises_error(self):
        """Test that no valid refinements raises error."""
        # A single pair whose frames have no features, so refinement must fail.
        pairs_df = pd.DataFrame(
            {
                "t_new": [0.0],
                "t_ref": [0.0],
                "score": [0.5],
                "n_matches": [0],
            }
        )

        new_kpts, new_desc = self._empty_features()
        ref_kpts, ref_desc = self._empty_features()

        with pytest.raises(RuntimeError, match="No valid refined pairs found"):
            refine_pairs_locally(
                pairs_df,
                np.array([0.0]),
                [new_kpts],
                [new_desc],
                np.array([0.0]),
                [ref_kpts],
                [ref_desc],
            )


class TestSaveRefinedPairs:
    """Tests for save_refined_pairs function."""

    def test_save_and_load_csv(self):
        """Round-trip a refined-pairs frame through CSV and compare every column."""
        with tempfile.TemporaryDirectory() as tmpdir:
            csv_path = Path(tmpdir) / "refined_pairs.csv"

            df = pd.DataFrame(
                {
                    "t_new": [0.0, 1.0, 2.0],
                    "t_ref": [0.0, 1.0, 2.0],
                    "confidence": [0.95, 0.87, 0.92],
                    "inliers": [45, 38, 42],
                    "reproj_error": [0.8, 1.2, 0.9],
                    "model": ["H", "H", "A"],
                }
            )

            save_refined_pairs(df, str(csv_path))

            loaded_df = pd.read_csv(csv_path)

            assert len(loaded_df) == 3
            missing = set(df.columns) - set(loaded_df.columns)
            assert not missing, f"missing columns after round-trip: {sorted(missing)}"

            # Compare every column that was written, not just a few. Selecting
            # by df.columns tolerates any extra columns the writer may add.
            # The relative tolerance allows for float formatting on write.
            pd.testing.assert_frame_equal(
                loaded_df[list(df.columns)],
                df,
                check_dtype=False,
                rtol=1e-5,
            )
