"""Tests for GPS track processing and alignment."""

import tempfile
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional

import gpxpy
import gpxpy.gpx
import numpy as np
import pandas as pd
import pytest

from mtbsync.io.gps import (
    align_distance_curves,
    compute_distance_track,
    estimate_tref_for_tnew,
    haversine_distance,
    parse_gpx,
    resample_track,
    save_gps_alignment_metadata,
    save_gps_pairs,
)


class TestHaversineDistance:
    """Great-circle distance checks for ``haversine_distance``."""

    def test_same_point(self):
        """Identical coordinates are (effectively) zero metres apart."""
        lat, lon = 40.7128, -74.0060
        assert haversine_distance(lat, lon, lat, lon) == pytest.approx(0.0, abs=0.1)

    def test_known_distance(self):
        """New York to Los Angeles should come out near 3,944 km."""
        new_york = (40.7128, -74.0060)
        los_angeles = (34.0522, -118.2437)

        metres = haversine_distance(*new_york, *los_angeles)

        # Loose band around the commonly quoted ~3,944 km figure.
        assert 3_900_000 < metres < 4_000_000

    def test_short_distance(self):
        """A 0.0009 degree step due north is roughly 100 m."""
        origin = (40.7128, -74.0060)
        slightly_north = (40.7137, -74.0060)

        metres = haversine_distance(*origin, *slightly_north)

        assert 90 < metres < 110


class TestParseGpx:
    """Tests for parse_gpx function."""

    def create_synthetic_gpx(
        self, n_points: int = 10, start_time: Optional[datetime] = None
    ) -> str:
        """Build a GPX document describing a straight northbound track.

        Args:
            n_points: Number of track points; consecutive points are one
                second and 0.001 degrees of latitude apart.
            start_time: Timestamp of the first point. Defaults to
                2024-01-01 12:00:00 UTC when omitted.

        Returns:
            The GPX document serialised as an XML string.
        """
        gpx = gpxpy.gpx.GPX()
        track = gpxpy.gpx.GPXTrack()
        gpx.tracks.append(track)
        segment = gpxpy.gpx.GPXTrackSegment()
        track.segments.append(segment)

        if start_time is None:
            start_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

        # Create a straight north track
        for i in range(n_points):
            lat = 40.7128 + i * 0.001  # Move north
            lon = -74.0060
            ele = 100.0 + i * 5.0
            time = start_time + timedelta(seconds=i)

            point = gpxpy.gpx.GPXTrackPoint(lat, lon, elevation=ele, time=time)
            segment.points.append(point)

        return gpx.to_xml()

    def _write_gpx(self, directory: str, n_points: int) -> Path:
        """Write a synthetic GPX file into *directory* and return its path."""
        gpx_path = Path(directory) / "test.gpx"
        # Write explicitly as UTF-8 so the bytes match the encoding declared
        # in the XML header, regardless of the platform's locale default.
        gpx_path.write_text(self.create_synthetic_gpx(n_points=n_points), encoding="utf-8")
        return gpx_path

    def test_parse_valid_gpx(self):
        """Test parsing a valid GPX file."""
        with tempfile.TemporaryDirectory() as tmpdir:
            gpx_path = self._write_gpx(tmpdir, n_points=10)

            df = parse_gpx(str(gpx_path))

            assert isinstance(df, pd.DataFrame)
            assert len(df) == 10
            assert "t_utc" in df.columns
            assert "lat" in df.columns
            assert "lon" in df.columns
            assert "ele" in df.columns
            assert "speed_mps" in df.columns

            # Check speed is derived (first point is 0)
            assert df["speed_mps"].iloc[0] == 0.0
            assert df["speed_mps"].iloc[1] > 0.0

    def test_parse_nonexistent_file(self):
        """Test error handling for nonexistent file."""
        # Point into a fresh, empty temp dir so the result cannot depend on
        # what happens to exist in the current working directory.
        with tempfile.TemporaryDirectory() as tmpdir:
            missing = Path(tmpdir) / "nonexistent.gpx"
            with pytest.raises(RuntimeError, match="Failed to parse GPX file"):
                parse_gpx(str(missing))

    def test_sorted_by_time(self):
        """Test that points are sorted by time."""
        with tempfile.TemporaryDirectory() as tmpdir:
            gpx_path = self._write_gpx(tmpdir, n_points=5)

            df = parse_gpx(str(gpx_path))

            # Check strictly increasing timestamps, pairwise.
            times = df["t_utc"].tolist()
            assert all(earlier < later for earlier, later in zip(times, times[1:]))


class TestComputeDistanceTrack:
    """Cumulative-distance behaviour of ``compute_distance_track``."""

    def test_straight_north_track(self):
        """Four points 0.001 degrees apart due north accumulate ~333 m."""
        track = pd.DataFrame(
            {
                "lat": [40.0, 40.001, 40.002, 40.003],
                "lon": [-74.0] * 4,
            }
        )

        out = compute_distance_track(track)

        assert "dist_m" in out.columns
        dist = out["dist_m"]
        assert dist.iloc[0] == 0.0

        # Roughly 111 m per 0.001 degree of latitude; three steps -> ~333 m.
        assert 300 < dist.iloc[-1] < 360

    def test_zero_movement(self):
        """A stationary track never accumulates any distance."""
        track = pd.DataFrame({"lat": [40.0] * 3, "lon": [-74.0] * 3})

        out = compute_distance_track(track)

        assert (out["dist_m"] == 0.0).all()


class TestResampleTrack:
    """Uniform-grid resampling behaviour of ``resample_track``."""

    def test_resample_uniform(self):
        """Irregular timestamps are resampled onto an evenly spaced grid."""
        track = pd.DataFrame(
            {
                "t_utc": [0.0, 1.0, 3.0, 6.0, 10.0],
                "lat": [40.0, 40.001, 40.002, 40.003, 40.004],
                "lon": [-74.0, -74.0, -74.0, -74.0, -74.0],
                "ele": [100.0, 105.0, 110.0, 115.0, 120.0],
                "dist_m": [0.0, 111.0, 222.0, 333.0, 444.0],
                "speed_mps": [0.0, 5.0, 5.0, 5.0, 5.0],
            }
        )

        out = resample_track(track, hz=2.0)

        assert "t_rel" in out.columns
        t_rel = out["t_rel"].to_numpy()
        assert t_rel[0] == 0.0

        # A 2 Hz grid means 0.5 s between consecutive samples.
        assert t_rel[1] - t_rel[0] == pytest.approx(0.5, abs=0.01)

        # Every step forward in time must be strictly positive.
        assert np.all(np.diff(t_rel) > 0)

    def test_resample_preserves_range(self):
        """Linear interpolation never leaves the original value range."""
        track = pd.DataFrame(
            {
                "t_utc": [0.0, 10.0],
                "lat": [40.0, 40.01],
                "lon": [-74.0, -74.0],
                "ele": [100.0, 200.0],
                "dist_m": [0.0, 1000.0],
                "speed_mps": [0.0, 10.0],
            }
        )

        out = resample_track(track, hz=1.0)

        for column in ("lat", "dist_m"):
            assert out[column].min() >= track[column].min()
            assert out[column].max() <= track[column].max()

    def test_too_few_points_raises_error(self):
        """A single-point track cannot be resampled."""
        single_point = pd.DataFrame(
            {
                "t_utc": [0.0],
                "lat": [40.0],
                "lon": [-74.0],
                "dist_m": [0.0],
                "speed_mps": [0.0],
                "ele": [100.0],
            }
        )

        with pytest.raises(ValueError, match="at least 2 points"):
            resample_track(single_point)


class TestAlignDistanceCurves:
    """Tests for align_distance_curves function."""

    def test_synthetic_offset_recovery(self):
        """Test recovery of a known +3.2 s time offset."""
        t_rel = np.linspace(0, 30, 100)
        dt = np.gradient(t_rel)
        # Varying speed gives the speed profile distinctive structure to align on.
        speed = 5.0 + 2.0 * np.sin(2 * np.pi * t_rel / 10)

        ref_df = pd.DataFrame(
            {
                "t_rel": t_rel,
                "speed_mps": speed,
                "dist_m": np.cumsum(speed * dt),
            }
        )

        # New track on the same relative time grid, but new(t) = ref(t + offset);
        # samples shifted past the end of the reference are zero-filled.
        known_offset = 3.2
        new_speed = np.interp(t_rel + known_offset, t_rel, speed, left=0, right=0)
        new_df = pd.DataFrame(
            {
                "t_rel": t_rel,
                "speed_mps": new_speed,
                "dist_m": np.cumsum(new_speed * dt),
            }
        )

        result = align_distance_curves(ref_df, new_df, max_offset_sec=10.0)

        assert "offset_sec" in result
        assert "corr_peak" in result
        assert "method" in result

        # Should recover offset within ±1.0s (cross-correlation is discrete)
        assert abs(result["offset_sec"] - known_offset) < 1.0

        # Should have strong correlation
        assert result["corr_peak"] > 0.5

    @staticmethod
    def _random_track(rng: np.random.Generator) -> pd.DataFrame:
        """Build a 50-sample track with random speeds and distance increments."""
        return pd.DataFrame(
            {
                "t_rel": np.linspace(0, 10, 50),
                "speed_mps": rng.random(50) * 5,
                "dist_m": np.cumsum(rng.random(50) * 5),
            }
        )

    def test_no_correlation_returns_zero(self):
        """Independent random speed profiles should not correlate strongly."""
        # Seeded generator keeps the test deterministic; the unseeded global
        # RNG made the outcome vary from run to run.
        rng = np.random.default_rng(12345)
        ref_df = self._random_track(rng)
        new_df = self._random_track(rng)

        result = align_distance_curves(ref_df, new_df)

        # Correlation should be weak for random signals
        assert result["corr_peak"] < 0.8  # Not strong

    def test_insufficient_points(self):
        """Test handling of tracks with too few points."""
        ref_df = pd.DataFrame({"t_rel": [0.0, 1.0], "speed_mps": [5.0, 5.0], "dist_m": [0.0, 5.0]})
        new_df = pd.DataFrame({"t_rel": [0.0, 1.0], "speed_mps": [5.0, 5.0], "dist_m": [0.0, 5.0]})

        result = align_distance_curves(ref_df, new_df)

        # Should return defaults
        assert result["offset_sec"] == 0.0
        assert result["corr_peak"] == 0.0


class TestEstimateTrefForTnew:
    """Distance-matched time mapping behaviour of ``estimate_tref_for_tnew``."""

    @staticmethod
    def _linear_track(duration_s, distance_m, n_samples):
        """Constant-speed track: time and distance both ramp linearly from 0."""
        return pd.DataFrame(
            {
                "t_rel": np.linspace(0, duration_s, n_samples),
                "dist_m": np.linspace(0, distance_m, n_samples),
            }
        )

    def test_linear_distance_match(self):
        """Identical 10 m/s tracks with no offset map each time onto itself."""
        reference = self._linear_track(10, 100, 50)
        candidate = self._linear_track(10, 100, 50)
        query_times = np.array([0.0, 2.5, 5.0, 7.5, 10.0])

        estimates = estimate_tref_for_tnew(query_times, reference, candidate, offset_sec=0.0)

        assert len(estimates) == len(query_times)
        assert np.allclose(estimates, query_times, atol=0.5)

    def test_with_offset(self):
        """A negative offset shifts the lookup forward in the new track."""
        reference = self._linear_track(20, 200, 100)
        candidate = self._linear_track(20, 200, 100)
        query_times = np.array([5.0, 10.0, 15.0])
        shift = -3.0  # new is ahead by 3 s, corrected with a negative offset

        estimates = estimate_tref_for_tnew(query_times, reference, candidate, offset_sec=shift)

        # e.g. t_new=5 -> new-track time 5-(-3)=8 s -> 80 m -> reference t=8 s.
        assert np.allclose(estimates, query_times - shift, atol=0.5)

    def test_extrapolation_bounds(self):
        """Times outside the track are clamped to its start and end."""
        reference = self._linear_track(10.0, 100.0, 2)
        candidate = self._linear_track(10.0, 100.0, 2)
        query_times = np.array([-5.0, 5.0, 15.0])

        below, middle, above = estimate_tref_for_tnew(
            query_times, reference, candidate, offset_sec=0.0
        )

        assert below == pytest.approx(0.0, abs=0.1)
        assert 3.0 < middle < 7.0
        assert above == pytest.approx(10.0, abs=0.1)


class TestSaveGpsMetadata:
    """JSON round-trip checks for ``save_gps_alignment_metadata``."""

    def test_save_and_load_metadata(self):
        """The written file exposes the alignment result and track lengths."""
        import json

        alignment = {"offset_sec": 3.2, "corr_peak": 0.85, "method": "xcorr"}
        ref_track = pd.DataFrame(
            {"t_rel": np.linspace(0, 30, 100), "dist_m": np.linspace(0, 300, 100)}
        )
        new_track = pd.DataFrame(
            {"t_rel": np.linspace(0, 28, 95), "dist_m": np.linspace(0, 280, 95)}
        )

        with tempfile.TemporaryDirectory() as workdir:
            out_path = Path(workdir) / "gps_alignment.json"

            save_gps_alignment_metadata(
                str(out_path), alignment, ref_track, new_track, resample_hz=10.0
            )

            assert out_path.exists()
            payload = json.loads(out_path.read_text())

            for key in ("offset_sec", "corr_peak", "resample_hz", "ref_len_s", "new_len_s"):
                assert key in payload

            assert payload["offset_sec"] == 3.2
            assert payload["corr_peak"] == 0.85
            assert payload["resample_hz"] == 10.0


class TestSaveGpsPairs:
    """CSV round-trip checks for ``save_gps_pairs``."""

    def test_save_and_load_pairs(self):
        """Pairs written to CSV read back with the same columns and values."""
        t_new = np.array([0.0, 1.0, 2.0, 3.0])
        t_ref = np.array([3.2, 4.2, 5.2, 6.2])

        with tempfile.TemporaryDirectory() as workdir:
            csv_path = Path(workdir) / "pairs_gps.csv"

            save_gps_pairs(str(csv_path), t_new, t_ref)
            assert csv_path.exists()

            table = pd.read_csv(csv_path)

            assert len(table) == 4
            for column in ("t_new", "t_ref_est"):
                assert column in table.columns

            assert np.allclose(table["t_new"].values, t_new)
            assert np.allclose(table["t_ref_est"].values, t_ref)
