"""GPS track parsing and alignment using GPX files."""

import json
from datetime import datetime, timezone
from math import atan2, cos, radians, sin, sqrt
from typing import Dict, Optional

import gpxpy
import numpy as np
import pandas as pd
from scipy import signal


def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """
    Return the great-circle distance between two points via the haversine formula.

    The Earth is modelled as a sphere of radius 6 371 000 m.

    Args:
        lat1: Latitude of the first point, in degrees
        lon1: Longitude of the first point, in degrees
        lat2: Latitude of the second point, in degrees
        lon2: Longitude of the second point, in degrees

    Returns:
        Distance in meters
    """
    earth_radius_m = 6371000

    phi1 = radians(lat1)
    phi2 = radians(lat2)
    half_dphi = (phi2 - phi1) / 2
    half_dlambda = (radians(lon2) - radians(lon1)) / 2

    # Haversine of the central angle between the two points.
    h = sin(half_dphi) ** 2 + cos(phi1) * cos(phi2) * sin(half_dlambda) ** 2
    central_angle = 2 * atan2(sqrt(h), sqrt(1 - h))

    return earth_radius_m * central_angle


def parse_gpx(gpx_path: str) -> pd.DataFrame:
    """
    Parse GPX file and extract track points with timestamps.

    Points from every track and segment are pooled; points without a timestamp
    are skipped. The result is sorted by time, and points sharing a timestamp
    keep their order of appearance in the file.

    Args:
        gpx_path: Path to GPX file (read as UTF-8 text)

    Returns:
        DataFrame with columns: t_utc (seconds since epoch), lat, lon, ele, speed_mps.
        ``ele`` is NaN where the GPX point carries no elevation.
        ``speed_mps`` is always derived from consecutive positions (any speed stored
        in the GPX is not read). The first point, and any point whose timestamp does
        not advance past its predecessor, gets speed 0.

    Raises:
        RuntimeError: If GPX file cannot be parsed or has no valid track points
    """
    try:
        # Explicit UTF-8 (the XML default) instead of the platform locale encoding.
        with open(gpx_path, "r", encoding="utf-8") as f:
            gpx = gpxpy.parse(f)
    except Exception as e:
        raise RuntimeError(f"Failed to parse GPX file: {gpx_path}") from e

    # Extract track points
    points = []
    for track in gpx.tracks:
        for segment in track.segments:
            for point in segment.points:
                if point.time is None:
                    continue
                point_time = point.time
                if point_time.tzinfo is None:
                    # GPX times are UTC; a naive datetime would otherwise be
                    # interpreted in the host's local zone by .timestamp().
                    point_time = point_time.replace(tzinfo=timezone.utc)
                points.append(
                    {
                        "t_utc": point_time.timestamp(),
                        "lat": point.latitude,
                        "lon": point.longitude,
                        "ele": point.elevation if point.elevation is not None else np.nan,
                    }
                )

    if not points:
        raise RuntimeError(f"No valid track points with timestamps found in {gpx_path}")

    df = pd.DataFrame(points)

    # Stable sort so points with identical timestamps keep file order.
    df = df.sort_values("t_utc", kind="stable").reset_index(drop=True)

    # Derive speed from positions using finite differences between neighbours.
    t = df["t_utc"].to_numpy(dtype=float)
    lat = df["lat"].to_numpy(dtype=float)
    lon = df["lon"].to_numpy(dtype=float)

    speed_mps = np.zeros(len(df))
    if len(df) > 1:
        step_m = np.array(
            [
                haversine_distance(lat0, lon0, lat1, lon1)
                for lat0, lon0, lat1, lon1 in zip(lat[:-1], lon[:-1], lat[1:], lon[1:])
            ]
        )
        dt = np.diff(t)
        # Where time does not advance (duplicate timestamps) speed stays 0
        # instead of dividing by zero.
        np.divide(step_m, dt, out=speed_mps[1:], where=dt > 0)

    df["speed_mps"] = speed_mps

    return df


def compute_distance_track(df: pd.DataFrame) -> pd.DataFrame:
    """
    Compute cumulative distance along track using Haversine formula.

    Rows are paired with their successor by position, not by index label, so
    DataFrames with a filtered or non-default index are handled correctly.

    Args:
        df: DataFrame with columns lat, lon (e.g. from parse_gpx)

    Returns:
        A copy of ``df`` with an added column ``dist_m``: cumulative 2D (horizontal)
        distance in meters, 0.0 at the first row. The input DataFrame is not modified.
    """
    df = df.copy()

    lat = df["lat"].to_numpy()
    lon = df["lon"].to_numpy()

    dist_m = np.zeros(len(df))
    if len(df) > 1:
        segment_m = [
            haversine_distance(lat0, lon0, lat1, lon1)
            for lat0, lon0, lat1, lon1 in zip(lat[:-1], lon[:-1], lat[1:], lon[1:])
        ]
        # Running total of segment lengths; row 0 stays at 0.0.
        dist_m[1:] = np.cumsum(segment_m)

    df["dist_m"] = dist_m

    return df


def resample_track(df: pd.DataFrame, hz: float = 10.0) -> pd.DataFrame:
    """
    Resample track to uniform time intervals starting at t_rel=0.

    The output grid is ``t_rel = k / hz`` for ``k = 0 .. int(duration * hz)``: samples
    are exactly ``1 / hz`` seconds apart and the grid stops at or before the end of
    the track. Tracks resampled with the same ``hz`` therefore share one sample
    period. This matters because align_distance_curves converts correlation lags to
    seconds using the reference track's spacing only.

    Args:
        df: DataFrame with columns t_utc, lat, lon, ele, dist_m, speed_mps. t_utc is
            assumed to be sorted ascending (np.interp needs increasing sample times).
        hz: Target sampling rate in Hz (must be > 0)

    Returns:
        DataFrame with columns: t_rel (seconds from start), lat, lon, ele, dist_m, speed_mps,
        all linearly interpolated onto the uniform grid. NaN elevations are bridged
        by interpolating between the valid ones; if the track has no elevation at
        all, ele is 0.0 everywhere.

    Raises:
        ValueError: If df has fewer than 2 rows or hz is not positive.
    """
    if len(df) < 2:
        raise ValueError("Track must have at least 2 points for resampling")
    if not hz > 0:
        raise ValueError(f"hz must be positive, got {hz}")

    # Create relative time starting at 0
    t_utc = df["t_utc"].to_numpy(dtype=float)
    t_rel_orig = t_utc - t_utc[0]

    # Uniform grid with an exact 1/hz step. np.linspace(0, duration, n) would
    # stretch the step to duration / (n - 1) whenever duration * hz is fractional.
    duration = t_rel_orig[-1]
    n_samples = int(duration * hz) + 1
    t_rel_uniform = np.arange(n_samples) / hz

    # Interpolate all fields
    lat_interp = np.interp(t_rel_uniform, t_rel_orig, df["lat"].to_numpy(dtype=float))
    lon_interp = np.interp(t_rel_uniform, t_rel_orig, df["lon"].to_numpy(dtype=float))
    dist_interp = np.interp(t_rel_uniform, t_rel_orig, df["dist_m"].to_numpy(dtype=float))
    speed_interp = np.interp(t_rel_uniform, t_rel_orig, df["speed_mps"].to_numpy(dtype=float))

    # Interpolate elevation only across points that have it, so gaps do not
    # produce spurious dips to 0 m.
    ele = df["ele"].to_numpy(dtype=float)
    has_ele = ~np.isnan(ele)
    if has_ele.any():
        ele_interp = np.interp(t_rel_uniform, t_rel_orig[has_ele], ele[has_ele])
    else:
        ele_interp = np.zeros(n_samples)

    resampled_df = pd.DataFrame(
        {
            "t_rel": t_rel_uniform,
            "lat": lat_interp,
            "lon": lon_interp,
            "ele": ele_interp,
            "dist_m": dist_interp,
            "speed_mps": speed_interp,
        }
    )

    return resampled_df


def align_distance_curves(
    ref_df: pd.DataFrame,
    new_df: pd.DataFrame,
    max_offset_sec: float = 60.0,
) -> Dict:
    """
    Align two GPS tracks using cross-correlation on speed profiles.

    Both speed signals are z-normalised and cross-correlated with
    ``scipy.signal.correlate``; the best lag within ``±max_offset_sec`` is returned.
    Lags are converted to seconds with the reference track's sample spacing, so both
    tracks are assumed to be sampled at the same rate (e.g. resample_track with the
    same ``hz``).

    Args:
        ref_df: Reference track DataFrame (with t_rel, speed_mps)
        new_df: New track DataFrame (with t_rel, speed_mps)
        max_offset_sec: Maximum time offset to search (seconds)

    Returns:
        Dictionary with:
        - offset_sec: Lag in seconds at the correlation peak, using scipy's lag
          convention: the reference speed at ``t + offset_sec`` best matches the new
          speed at ``t`` (i.e. t_ref ≈ t_new + offset_sec; a negative value means the
          new track reaches features later than the reference). 0.0 if no alignment
          could be computed.
        - corr_peak: Peak normalised correlation in [-1, 1] (1 = identical shape over
          the full length). 0.0 if no alignment could be computed.
        - method: "xcorr"
    """
    # NOTE(review): callers should confirm they interpret offset_sec as
    # t_ref ≈ t_new + offset_sec (scipy.signal.correlation_lags sign convention).
    no_alignment = {"offset_sec": 0.0, "corr_peak": 0.0, "method": "xcorr"}

    if len(ref_df) < 10 or len(new_df) < 10:
        return no_alignment

    # Extract speed signals
    ref_speed = ref_df["speed_mps"].to_numpy(dtype=float)
    new_speed = new_df["speed_mps"].to_numpy(dtype=float)

    ref_std = np.std(ref_speed)
    new_std = np.std(new_speed)
    # A constant profile has nothing to align on (the correlation would be all
    # zeros and argmax would pick an arbitrary edge lag); NaN stds fail this too.
    if not (ref_std > 1e-9 and new_std > 1e-9):
        return no_alignment

    # Normalize signals (zero mean, unit variance)
    ref_norm = (ref_speed - np.mean(ref_speed)) / ref_std
    new_norm = (new_speed - np.mean(new_speed)) / new_std

    # Each normalised signal has energy equal to its length, so dividing by
    # sqrt(N_ref * N_new) bounds |correlation| <= 1 (Cauchy-Schwarz).
    correlation = signal.correlate(ref_norm, new_norm, mode="full")
    correlation /= np.sqrt(len(ref_norm) * len(new_norm))
    lags = signal.correlation_lags(len(ref_norm), len(new_norm), mode="full")

    # Sample period from the reference grid (at least 10 rows guaranteed above).
    dt_ref = ref_df["t_rel"].iloc[1] - ref_df["t_rel"].iloc[0]
    time_lags = lags * dt_ref

    # Find peak within max_offset_sec range
    valid_mask = np.abs(time_lags) <= max_offset_sec
    if not np.any(valid_mask):
        return no_alignment

    valid_corr = correlation[valid_mask]
    valid_lags = time_lags[valid_mask]
    peak_idx = int(np.argmax(valid_corr))

    return {
        "offset_sec": float(valid_lags[peak_idx]),
        "corr_peak": float(valid_corr[peak_idx]),
        "method": "xcorr",
    }


def _estimate_tref_for_tnew_vectorised(ref_df: pd.DataFrame, new_df: pd.DataFrame, tnew_s: np.ndarray, offset_sec: float = 0.0) -> np.ndarray:
    """
    Vectorised equivalent of estimate_tref_for_tnew() using np.interp.
    Steps:
      1) Shift new times by offset
      2) dist_new = interp(new.t_rel, new.dist_m) for ALL t_new
      3) t_ref_est = interp(ref.dist_m, ref.t_rel) using dist_new
      4) Clamp to ref range endpoints where needed
    """
    if len(tnew_s) == 0:
        return np.array([], dtype=float)

    t_new_gps = tnew_s - offset_sec

    # 1. distance at new times (clamped to bounds via np.clip)
    new_t_rel = new_df["t_rel"].values
    new_dist = new_df["dist_m"].values
    t_new_gps_clamped = np.clip(t_new_gps, new_t_rel[0], new_t_rel[-1])
    dist_new = np.interp(t_new_gps_clamped, new_t_rel, new_dist)

    # 2. map distances to reference times
    ref_dist = ref_df["dist_m"].values
    ref_t_rel = ref_df["t_rel"].values

    # Clamp distances to ref bounds to avoid NaNs
    dist_new_clamped = np.clip(dist_new, ref_dist[0], ref_dist[-1])
    tref_est = np.interp(dist_new_clamped, ref_dist, ref_t_rel)
    return tref_est


def estimate_tref_for_tnew(
    tnew_s: np.ndarray,
    ref_df: pd.DataFrame,
    new_df: pd.DataFrame,
    offset_sec: float,
) -> np.ndarray:
    """
    Estimate reference timestamps for each new timestamp using distance matching.

    For each t_new, finds the distance traveled in the new run, then finds the
    corresponding time in the reference run where the same distance was reached.
    Inputs with 512 or more timestamps are handed to the vectorised helper; shorter
    inputs use the scalar loop below.

    Args:
        tnew_s: Array of new video timestamps (seconds from start of new video).
            ``None`` is treated as empty.
        ref_df: Reference GPS track (with t_rel, dist_m)
        new_df: New GPS track (with t_rel, dist_m)
        offset_sec: Time offset (per the original docs, from align_distance_curves);
            each t_new is looked up in new_df at t_new - offset_sec

    Returns:
        Array of estimated reference timestamps (seconds from start of ref video)
    """
    if tnew_s is None:
        return np.array([], dtype=float)

    # Fast path for longer sequences
    if len(tnew_s) >= 512:
        return _estimate_tref_for_tnew_vectorised(ref_df, new_df, tnew_s, offset_sec=offset_sec)

    # Scalar path for small inputs; column arrays are extracted once, not per item.
    new_t = new_df["t_rel"].to_numpy()
    new_d = new_df["dist_m"].to_numpy()
    ref_t = ref_df["t_rel"].to_numpy()
    ref_d = ref_df["dist_m"].to_numpy()

    tref_est = np.zeros(len(tnew_s))

    for i, t_new in enumerate(tnew_s):
        # Adjust t_new by offset to get GPS time
        t_new_gps = t_new - offset_sec

        # Distance at t_new_gps in the new track, held at the ends outside its range
        if t_new_gps < new_t[0]:
            dist_new = new_d[0]
        elif t_new_gps > new_t[-1]:
            dist_new = new_d[-1]
        else:
            dist_new = np.interp(t_new_gps, new_t, new_d)

        # Time in the reference track where the same distance was reached
        if dist_new <= ref_d[0]:
            tref_est[i] = ref_t[0]
        elif dist_new >= ref_d[-1]:
            tref_est[i] = ref_t[-1]
        else:
            tref_est[i] = np.interp(dist_new, ref_d, ref_t)

    return tref_est


def save_gps_alignment_metadata(output_path: str, alignment_result: Dict, ref_df: pd.DataFrame, new_df: pd.DataFrame, resample_hz: float) -> None:
    """
    Write a JSON summary of a GPS alignment run.

    The JSON object contains, in order: offset_sec, corr_peak and method copied
    from ``alignment_result``; ``resample_hz``; each track's time span (last minus
    first t_rel, as ref_len_s / new_len_s); and each track's final dist_m
    (ref_dist_m / new_dist_m). It is written with a two-space indent.

    Args:
        output_path: Destination JSON file path
        alignment_result: Dict produced by align_distance_curves
        ref_df: Reference GPS track (with t_rel, dist_m)
        new_df: New GPS track (with t_rel, dist_m)
        resample_hz: Resampling rate that was used
    """
    summary = {field: alignment_result[field] for field in ("offset_sec", "corr_peak", "method")}
    summary["resample_hz"] = resample_hz
    summary["ref_len_s"] = float(ref_df["t_rel"].iloc[-1] - ref_df["t_rel"].iloc[0])
    summary["new_len_s"] = float(new_df["t_rel"].iloc[-1] - new_df["t_rel"].iloc[0])
    summary["ref_dist_m"] = float(ref_df["dist_m"].iloc[-1])
    summary["new_dist_m"] = float(new_df["dist_m"].iloc[-1])

    with open(output_path, "w") as out_file:
        json.dump(summary, out_file, indent=2)


def save_gps_pairs(output_path: str, tnew_s: np.ndarray, tref_est: np.ndarray) -> None:
    """
    Dump (t_new, t_ref_est) pairs to a CSV file for debugging.

    The CSV has a header row ``t_new,t_ref_est``, no index column, and floats
    formatted with six decimal places.

    Args:
        output_path: Destination CSV file path
        tnew_s: New video timestamps
        tref_est: Estimated reference timestamps, aligned element-wise with tnew_s
    """
    columns = {"t_new": tnew_s, "t_ref_est": tref_est}
    table = pd.DataFrame(columns)
    table.to_csv(output_path, float_format="%.6f", index=False)
