"""Time-warp fitting and gating utilities for MTB sync.

This module provides robust affine time-warp fitting using RANSAC + IRLS refinement.
Designed to handle noisy temporal correspondences from visual/GPS alignment.

Example:
    >>> pairs = [(0.0, 0.1), (1.0, 1.05), (2.0, 2.15), (3.0, 3.1)]
    >>> result = fit_timewarp(pairs, rng_seed=42)
    >>> result['ok']
    True
    >>> warp = result['warp']
    >>> candidates = [(1.0, 1.0), (1.05, 1.0), (5.0, 5.0)]
    >>> filtered = gate_by_warp(candidates, warp, window_sec=0.2)
    >>> len(filtered)
    1
"""

from __future__ import annotations

import random
from typing import Any, Dict, List, Sequence, Tuple


class TimeWarp:
    """Affine time-warp model mapping t_new -> t_ref.

    Model: t_ref = a * t_new + b

    Attributes:
        model: Model type (currently only "affine" supported; the value is
            stored as given and not validated here).
        a: Scale factor (slope).
        b: Offset (intercept).
    """

    def __init__(self, model: str, a: float, b: float) -> None:
        """Initialise TimeWarp model."""
        self.model = model
        self.a = a
        self.b = b

    def f(self, t_new: float) -> float:
        """Map t_new -> t_ref_hat using the affine model a * t_new + b."""
        return self.a * t_new + self.b

    def __repr__(self) -> str:
        """Return a constructor-style representation for logs and debugging."""
        return f"{type(self).__name__}(model={self.model!r}, a={self.a!r}, b={self.b!r})"


def fit_timewarp(
    pairs: Sequence[Tuple[float, float]],
    model: str = "affine",
    inlier_thresh: float = 0.08,
    ransac_iters: int = 500,
    max_ppm: float = 1000,
    min_inlier_frac: float = 0.25,
    rng_seed: int | None = 0,
    max_pairs: int = 20_000,
) -> Dict[str, Any]:
    """Fit robust affine time-warp using RANSAC + IRLS refinement.

    Args:
        pairs: Sequence of (t_new, t_ref) time correspondences. Note the order.
            Any iterable of 2-item pairs is accepted; it is materialised once.
        model: Model type (currently only "affine" supported).
        inlier_thresh: RANSAC inlier threshold in seconds (L1 residual). Also
            used as the Huber delta for the IRLS refinement.
        ransac_iters: Number of RANSAC iterations. Draws whose two indices
            coincide, or whose t_new values are nearly equal, are skipped but
            still count toward this budget.
        max_ppm: Maximum allowed parts-per-million drift from 1:1 speed.
        min_inlier_frac: Minimum required fraction of inliers (RANSAC).
        rng_seed: Seed for a private RNG, giving deterministic results without
            touching the global ``random`` state. None = draw from the global
            ``random`` module as-is (no reseeding).
        max_pairs: Runtime cap. When more pairs are given, a random subset of
            this size (at least 2) is used for fitting and residual statistics.

    Returns:
        Dict with keys:
            - warp: TimeWarp instance
            - model: Model type string
            - ok: True if quality gates passed (ppm and inlier_frac)
            - params: {"a", "b"}
            - ppm: abs(a-1)*1e6
            - residuals: {"mae", "p90"} of absolute residuals over the pairs
              used for fitting
            - inlier_frac: Fraction of inliers from RANSAC (0.0 when RANSAC
              found none and a plain least-squares fallback was used)
            - n_pairs: Number of input pairs (before any subsampling)

    Raises:
        ValueError: If ``model`` is not "affine".
    """
    # Validate first so an unsupported model is reported regardless of input size.
    if model != "affine":
        raise ValueError(f"Unsupported model: {model}")

    # Materialise once: avoids `not pairs` (which raises for array-like inputs
    # such as numpy arrays) and lets one-shot iterables work.
    pairs_list = list(pairs) if pairs is not None else []
    n_input = len(pairs_list)

    # Graceful handling for insufficient data
    if n_input < 2:
        return {
            "warp": TimeWarp(model="affine", a=1.0, b=0.0),
            "model": "affine",
            "ok": False,
            "params": {"a": 1.0, "b": 0.0},
            "ppm": 0.0,
            "residuals": {"mae": 0.0, "p90": 0.0},
            "inlier_frac": 0.0,
            "n_pairs": n_input,
        }

    # Private RNG: random.Random(seed) yields the same stream as
    # random.seed(seed), so seeded results are unchanged, but the caller's
    # global random state is no longer clobbered on every call.
    rng: Any = random.Random(rng_seed) if rng_seed is not None else random

    # Cap runtime: sample at most `cap` pairs deterministically
    cap = max(2, max_pairs)
    if n_input > cap:
        pairs_list = rng.sample(pairs_list, cap)
        # Re-seed so RANSAC draws do not depend on the subsampling step
        if rng_seed is not None:
            rng.seed(rng_seed)

    t_new = [p[0] for p in pairs_list]
    t_ref = [p[1] for p in pairs_list]
    n = len(pairs_list)  # >= 2 here

    best_a, best_b, best_inliers = _ransac_affine(
        t_new, t_ref, rng, ransac_iters, inlier_thresh
    )

    if not best_inliers:
        # RANSAC found no usable model: ordinary least squares on all data
        a, b = _least_squares_affine(t_new, t_ref)
        inlier_frac = 0.0
    elif len(best_inliers) >= 2:
        a, b = _irls_affine_huber(
            [t_new[k] for k in best_inliers],
            [t_ref[k] for k in best_inliers],
            init_a=best_a,
            init_b=best_b,
            huber_delta=inlier_thresh,
            max_iters=10,
        )
        inlier_frac = len(best_inliers) / n
    else:
        # A single inlier cannot be refined; keep the RANSAC model as-is
        a, b = best_a, best_b
        inlier_frac = len(best_inliers) / n

    # Residual stats on all fitted pairs (computed once, reused for mae/p90)
    residuals_all = [abs(tr - (a * tn + b)) for tn, tr in zip(t_new, t_ref)]
    mae = sum(residuals_all) / n
    p90 = _percentile(residuals_all, 90)

    ppm = abs(a - 1.0) * 1e6
    ok = (ppm <= max_ppm) and (inlier_frac >= min_inlier_frac)

    return {
        "warp": TimeWarp(model=model, a=a, b=b),
        "model": model,
        "ok": ok,
        "params": {"a": a, "b": b},
        "ppm": ppm,
        "residuals": {"mae": mae, "p90": p90},
        "inlier_frac": inlier_frac,
        "n_pairs": n_input,
    }


def _ransac_affine(
    t_new: List[float],
    t_ref: List[float],
    rng: Any,
    iters: int,
    thresh: float,
) -> Tuple[float, float, List[int]]:
    """Two-point RANSAC for t_ref = a * t_new + b.

    Returns (a, b, inlier_indices) for the first model that reached the
    largest inlier count, or (1.0, 0.0, []) if no valid sample was drawn.
    ``rng`` must provide ``randrange`` (a ``random.Random`` or the module).
    """
    n = len(t_new)
    best_a, best_b = 1.0, 0.0
    best_inliers: List[int] = []

    for _ in range(iters):
        # Two independent draws (kept as-is so seeded results stay reproducible)
        i, j = rng.randrange(n), rng.randrange(n)
        if i == j:
            continue
        denom = t_new[j] - t_new[i]
        if abs(denom) < 1e-9:
            continue

        a = (t_ref[j] - t_ref[i]) / denom
        b = t_ref[i] - a * t_new[i]

        inliers = [
            k
            for k, (x, y) in enumerate(zip(t_new, t_ref))
            if abs(y - (a * x + b)) <= thresh  # inclusive
        ]
        if len(inliers) > len(best_inliers):
            best_a, best_b, best_inliers = a, b, inliers

    return best_a, best_b, best_inliers


def gate_by_warp(
    candidates: Sequence[Tuple[Any, ...]],
    warp: TimeWarp,
    window_sec: float,
) -> List[Tuple[Any, ...]]:
    """Filter candidates by proximity to time-warp prediction.

    Args:
        candidates: Sequence of (t_ref, t_new, ...) tuples. Only the first two
            items are read; any extra items are carried through untouched.
            NOTE: This is the OPPOSITE order from fit_timewarp() pairs.
        warp: TimeWarp model to predict t_ref from t_new (only ``warp.f`` is used).
        window_sec: Acceptance window in seconds (L1 distance, inclusive).

    Returns:
        Filtered list, in input order, of candidates whose
        |t_ref - f(t_new)| <= window_sec. Candidates with fewer than two
        items are silently skipped.
    """
    return [
        cand
        for cand in candidates
        if len(cand) >= 2 and abs(cand[0] - warp.f(cand[1])) <= window_sec
    ]


# ---- Helpers ----------------------------------------------------------------

def _least_squares_affine(t_new: List[float], t_ref: List[float]) -> Tuple[float, float]:
    """Unweighted least-squares fit of y = a x + b."""
    n = len(t_new)
    if n == 0:
        return 1.0, 0.0
    sx = sum(t_new)
    sy = sum(t_ref)
    sxx = sum(x * x for x in t_new)
    sxy = sum(x * y for x, y in zip(t_new, t_ref))
    det = n * sxx - sx * sx
    if abs(det) < 1e-12:
        return 1.0, 0.0
    a = (n * sxy - sx * sy) / det
    b = (sxx * sy - sx * sxy) / det
    return a, b


def _irls_affine_huber(
    t_new: List[float],
    t_ref: List[float],
    init_a: float,
    init_b: float,
    huber_delta: float,
    max_iters: int = 10,
) -> Tuple[float, float]:
    """IRLS refinement with Huber loss for affine model."""
    if huber_delta <= 1e-9:
        huber_delta = 1e-3

    a, b = init_a, init_b
    n = len(t_new)
    # Prebind locals to reduce list rebuilds
    rs = [0.0] * n
    ws = [0.0] * n

    for _ in range(max_iters):
        # residuals r = (a x + b) - y
        for i in range(n):
            rs[i] = (a * t_new[i] + b) - t_ref[i]

        # Huber weights
        for i in range(n):
            ar = abs(rs[i])
            ws[i] = 1.0 if ar <= huber_delta else huber_delta / (ar + 1e-12)

        sw = sum(ws)
        swx = sum(w * x for w, x in zip(ws, t_new))
        swy = sum(w * y for w, y in zip(ws, t_ref))
        swxx = sum(w * x * x for w, x in zip(ws, t_new))
        swxy = sum(w * x * y for w, x, y in zip(ws, t_new, t_ref))

        det = (sw * swxx - swx * swx)
        if abs(det) < 1e-12:
            break

        a_new = (sw * swxy - swx * swy) / det
        b_new = (swxx * swy - swx * swxy) / det

        if abs(a_new - a) < 1e-9 and abs(b_new - b) < 1e-9:
            a, b = a_new, b_new
            break
        a, b = a_new, b_new

    return a, b


def _percentile(values: List[float], p: float) -> float:
    """Return the p-th percentile (0..100) using linear interpolation."""
    if not values:
        return 0.0
    vals = sorted(values)
    n = len(vals)
    rank = (p / 100.0) * (n - 1)
    lo = int(rank)
    hi = min(lo + 1, n - 1)
    if lo == hi:
        return vals[lo]
    frac = rank - lo
    return vals[lo] * (1 - frac) + vals[hi] * frac
