"""Coarse visual retrieval for frame matching."""

import concurrent.futures
import json
import time
from typing import List, Optional, Tuple, Dict, Any
from pathlib import Path

import cv2
import numpy as np
import pandas as pd

from mtbsync.match.timewarp import fit_timewarp, gate_by_warp, TimeWarp


def load_reference_index(index_path: str) -> Tuple[np.ndarray, List[np.ndarray], List[np.ndarray]]:
    """
    Load reference index from .npz file.

    The archive must contain the members ``t_ref``, ``kpts`` and ``desc``, all
    with the same length. The archive is closed before returning; all returned
    arrays are fully read into memory.

    Args:
        index_path: Path to reference index .npz file

    Returns:
        Tuple of (timestamps, keypoints_list, descriptors_list)

    Raises:
        RuntimeError: If index file cannot be loaded or is invalid
    """
    try:
        # NOTE(security): allow_pickle=True is required to read object arrays
        # (kpts/desc are presumably stored as ragged per-frame object arrays),
        # but it allows arbitrary code execution when loading an untrusted file.
        # Only load index files produced by this tool.
        data = np.load(index_path, allow_pickle=True)
    except Exception as e:
        raise RuntimeError(f"Failed to load reference index: {index_path}") from e

    # Only .npz archives come back as an NpzFile (which exposes ``.files``);
    # a .npy or pickle file yields a bare array/object with no named members.
    if not hasattr(data, "files"):
        raise RuntimeError(f"Invalid index file format: {index_path}")

    # NpzFile keeps the underlying file handle open until closed.
    with data:
        members = set(data.files)
        if not {"t_ref", "kpts", "desc"} <= members:
            raise RuntimeError(f"Invalid index file format: {index_path}")

        # Indexing an NpzFile reads the member into memory, so these remain
        # valid after the archive is closed.
        t_ref = data["t_ref"]
        kpts_list = list(data["kpts"])
        desc_list = list(data["desc"])

    if len(t_ref) != len(kpts_list) or len(t_ref) != len(desc_list):
        raise RuntimeError("Inconsistent lengths in index file")

    return t_ref, kpts_list, desc_list


def match_frame_to_reference(
    new_desc: np.ndarray,
    ref_desc_list: List[np.ndarray],
    ref_timestamps: np.ndarray,
    top_k: int = 3,
    lowe_ratio: float = 0.75,
    t_ref_center: Optional[float] = None,
    time_window_sec: Optional[float] = None,
) -> List[Tuple[float, float, int]]:
    """
    Rank reference frames against one new frame by brute-force Hamming matching.

    Each eligible reference frame is matched with ``knnMatch(k=2)``; matches
    surviving Lowe's ratio test count towards that frame. A single-neighbour
    result (no second neighbour to compare against) is accepted as-is. Frames
    are ranked by number of surviving matches, then by score.

    Args:
        new_desc: Descriptors for new frame (Nx32 uint8)
        ref_desc_list: List of reference frame descriptors
        ref_timestamps: Reference frame timestamps
        top_k: Number of top reference frames to return
        lowe_ratio: Lowe ratio test threshold
        t_ref_center: Optional GPS-estimated reference time (center of search window)
        time_window_sec: Optional time window (±seconds) around t_ref_center;
            only applied when t_ref_center is also given

    Returns:
        List of (t_ref, score, n_matches) tuples for at most ``top_k`` frames,
        where score = 1 / (1 + mean Hamming distance of surviving matches).
        Empty list if no good matches found.

    Raises:
        ValueError: If reference descriptors and timestamps differ in length
    """
    if new_desc is None or len(new_desc) == 0:
        return []

    if len(ref_desc_list) != len(ref_timestamps):
        raise ValueError("Reference descriptors and timestamps length mismatch")

    # Hamming norm suits binary (ORB-style) descriptors.
    matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
    windowed = t_ref_center is not None and time_window_sec is not None

    # (t_ref, score, n_matches) per reference frame that produced any good match.
    ranked: List[Tuple[float, float, int]] = []

    for candidate_desc, candidate_t in zip(ref_desc_list, ref_timestamps):
        # GPS window: ignore reference frames too far from the estimate.
        if windowed and abs(candidate_t - t_ref_center) > time_window_sec:
            continue
        if candidate_desc is None or len(candidate_desc) == 0:
            continue

        try:
            neighbours = matcher.knnMatch(new_desc, candidate_desc, k=2)
        except cv2.error:
            # e.g. incompatible descriptor sets; treat as "no match".
            continue

        accepted = [
            pair[0]
            for pair in neighbours
            if len(pair) == 1
            or (len(pair) == 2 and pair[0].distance < lowe_ratio * pair[1].distance)
        ]
        if not accepted:
            continue

        mean_distance = np.mean([hit.distance for hit in accepted])
        ranked.append((candidate_t, 1.0 / (1.0 + mean_distance), len(accepted)))

    if not ranked:
        return []

    # Most matches first, ties broken by higher score; sort is stable.
    ranked.sort(key=lambda entry: (entry[2], entry[1]), reverse=True)
    return [(t, s, n) for t, s, n in ranked[:top_k]]


def _match_single_frame_worker(args: Tuple) -> List[Dict[str, Any]]:
    """
    Worker function for threading: match a single new frame to reference frames.

    Args:
        args: Tuple of (new_idx, t_new, new_desc, ref_desc_list, ref_timestamps,
                        top_k, lowe_ratio, t_ref_center, gps_window_sec)

    Returns:
        List of match dictionaries with keys: t_new, t_ref, score, n_matches
    """
    (new_idx, t_new, new_desc, ref_desc_list, ref_timestamps,
     top_k, lowe_ratio, t_ref_center, gps_window_sec) = args

    # Skip if this frame has no descriptors
    if new_desc is None or len(new_desc) == 0:
        return []

    # Find top-K reference matches (GPS-constrained if provided)
    matches = match_frame_to_reference(
        new_desc,
        ref_desc_list,
        ref_timestamps,
        top_k=top_k,
        lowe_ratio=lowe_ratio,
        t_ref_center=t_ref_center,
        time_window_sec=gps_window_sec,
    )

    # Convert to pair dictionaries
    frame_pairs = []
    for t_ref, score, n_matches in matches:
        frame_pairs.append({
            "t_new": t_new,
            "t_ref": t_ref,
            "score": score,
            "n_matches": n_matches,
        })
    return frame_pairs


def retrieve_coarse_pairs(
    new_timestamps: np.ndarray,
    new_desc_list: List[np.ndarray],
    ref_timestamps: np.ndarray,
    ref_desc_list: List[np.ndarray],
    top_k: int = 3,
    lowe_ratio: float = 0.75,
    t_ref_est: Optional[np.ndarray] = None,
    gps_window_sec: Optional[float] = None,
    warp_enable: bool = True,
    warp_window_sec: float = 1.0,
    warp_inlier_thresh: float = 0.08,
    warp_ransac_iters: int = 500,
    warp_max_ppm: int = 1000,
    warp_min_inlier_frac: float = 0.25,
    warp_rng_seed: int | None = 0,
    threads: int = 1,
) -> Tuple[pd.DataFrame, Dict[str, float]]:
    """
    Perform coarse retrieval matching for all new frames.

    For each new frame, find top-K reference frames using descriptor matching.
    Optionally restrict search to GPS-estimated time windows.

    Args:
        new_timestamps: Timestamps for new frames
        new_desc_list: Descriptors for new frames
        ref_timestamps: Timestamps for reference frames
        ref_desc_list: Descriptors for reference frames
        top_k: Number of top matches to keep per new frame
        lowe_ratio: Lowe ratio test threshold
        t_ref_est: Optional GPS-estimated reference timestamps for each new frame
        gps_window_sec: Optional time window (±seconds) for GPS-constrained search
        warp_enable: Enable time-warp gating
        warp_window_sec: Time-warp acceptance window in seconds
        warp_inlier_thresh: RANSAC inlier threshold in seconds
        warp_ransac_iters: Number of RANSAC iterations
        warp_max_ppm: Maximum allowed parts-per-million drift
        warp_min_inlier_frac: Minimum required fraction of inliers
        warp_rng_seed: Random seed for deterministic results (None = do not reseed)
        threads: Number of threads for parallel retrieval (1 = sequential)

    Returns:
        Tuple of (DataFrame with columns: t_new, t_ref, score, n_matches, timing dict)

    Raises:
        ValueError: If inputs are invalid
    """
    if len(new_timestamps) != len(new_desc_list):
        raise ValueError("New timestamps and descriptors length mismatch")

    if len(ref_timestamps) != len(ref_desc_list):
        raise ValueError("Reference timestamps and descriptors length mismatch")

    if t_ref_est is not None and len(t_ref_est) != len(new_timestamps):
        raise ValueError("GPS estimates length must match new timestamps")

    # Timing instrumentation
    timings: Dict[str, float] = {}
    t_total_start = time.time()

    # GPS phase timing (if applicable - note: GPS runs before retrieval in CLI)
    # We'll capture retrieval timing here
    t_retrieval_start = time.time()

    pairs = []

    # Prepare worker arguments
    worker_args = []
    for new_idx, (t_new, new_desc) in enumerate(zip(new_timestamps, new_desc_list)):
        t_ref_center = t_ref_est[new_idx] if t_ref_est is not None else None
        worker_args.append((
            new_idx, t_new, new_desc, ref_desc_list, ref_timestamps,
            top_k, lowe_ratio, t_ref_center, gps_window_sec
        ))

    # Multi-threaded or sequential retrieval
    if threads > 1:
        with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
            future_to_idx = {executor.submit(_match_single_frame_worker, args): args[0]
                             for args in worker_args}
            for future in concurrent.futures.as_completed(future_to_idx):
                frame_pairs = future.result()
                pairs.extend(frame_pairs)
    else:
        # Sequential path
        for args in worker_args:
            frame_pairs = _match_single_frame_worker(args)
            pairs.extend(frame_pairs)

    timings["retrieval_sec"] = time.time() - t_retrieval_start

    if not pairs:
        raise RuntimeError("No valid matches found between videos")

    # Create DataFrame
    df = pd.DataFrame(pairs)

    # --- Time-Warp gating ----------------------------------------------------
    # Order: GPS gating (if any) runs first, then time-warp gating.
    t_warp_start = time.time()
    if warp_enable and len(df) >= 8:
        try:
            # Convert DataFrame to candidate tuples (t_ref, t_new, score, n_matches)
            candidates = [(row["t_ref"], row["t_new"], row["score"], row["n_matches"])
                         for _, row in df.iterrows()]

            # Sample up to 20k pairs (t_ref, t_new) → swap for fit_timewarp format
            pairs_for_fit = [(t_new, t_ref) for t_ref, t_new, *rest in (candidates[:20000])]
            fit = fit_timewarp(
                pairs_for_fit,
                model="affine",
                inlier_thresh=warp_inlier_thresh,
                ransac_iters=warp_ransac_iters,
                max_ppm=warp_max_ppm,
                min_inlier_frac=warp_min_inlier_frac,
                rng_seed=warp_rng_seed,
            )
            if fit.get("ok"):
                warp = fit["warp"]
                before = len(candidates)
                candidates = gate_by_warp(candidates, warp, warp_window_sec)
                after = len(candidates)
                reduction = 100.0 * (1 - after / max(before, 1))
                print(
                    f"[Time-Warp] a={warp.a:.8f}, b={warp.b:.3f}, "
                    f"ppm={fit['ppm']:.0f}, inliers={fit['inlier_frac']:.2%}, "
                    f"window=±{warp_window_sec:.2f}s → {after}/{before} kept ({reduction:.1f}% pruned)"
                )

                # Save artefact near output path if possible
                out_dir = Path(".")
                meta = {
                    "model": fit["model"],
                    "params": fit["params"],
                    "ppm": fit["ppm"],
                    "inlier_frac": fit["inlier_frac"],
                    "residuals": fit["residuals"],
                    "ok": fit["ok"],
                    "window_sec": warp_window_sec,
                    "n_pairs": fit["n_pairs"],
                }
                try:
                    with open(out_dir / "timewarp.json", "w", encoding="utf-8") as f:
                        json.dump(meta, f, indent=2)
                except Exception as e:
                    print(f"[Time-Warp] Warning: could not write timewarp.json ({e})")

                # Convert filtered candidates back to DataFrame
                df = pd.DataFrame([
                    {"t_new": t_new, "t_ref": t_ref, "score": score, "n_matches": n_matches}
                    for t_ref, t_new, score, n_matches in candidates
                ])
            else:
                print("[Time-Warp] Fit quality insufficient; skipped gating.")
        except Exception as e:
            print(f"[Time-Warp] Error during fitting: {e}")
    timings["warp_sec"] = time.time() - t_warp_start
    # -------------------------------------------------------------------------

    # --- Marker Transfer Auto-Export (Optional) ---
    t_markers_start = time.time()
    try:
        from mtbsync.match.marker_transfer import transfer_markers, load_timewarp_json
        ref_markers_path = Path(".") / "ref_markers.csv"
        if ref_markers_path.exists():
            tw_json = Path(".") / "timewarp.json"
            out_markers = Path(".") / "new_markers_auto.csv"
            transfer_markers(ref_markers_path, tw_json, out_markers)
            print(f"[Marker-Transfer] Auto-exported markers → {out_markers}")
    except Exception as e:
        print(f"[Marker-Transfer] Skipped (auto-export error: {e})")
    timings["markers_sec"] = time.time() - t_markers_start
    # ------------------------------------------------

    # Sort by t_new for readability
    df = df.sort_values("t_new").reset_index(drop=True)

    # Record total time and frame count (for telemetry FPS calculation)
    timings["total_sec"] = time.time() - t_total_start
    timings["frames_processed"] = len(new_timestamps)

    return df, timings


def save_pairs_csv(df: pd.DataFrame, output_path: str) -> None:
    """
    Write the coarse pairs table to ``output_path`` as CSV.

    The index is omitted and float columns are written with six decimal places.

    Args:
        df: DataFrame with columns t_new, t_ref, score, n_matches
        output_path: Destination CSV file path
    """
    csv_options = {"index": False, "float_format": "%.6f"}
    df.to_csv(output_path, **csv_options)
