"""Local refinement of coarse pairs using RANSAC geometry estimation."""

import numpy as np
import pandas as pd
import cv2
from typing import List, Tuple, Optional, Literal


def sigmoid(x: float) -> float:
    """Compute the logistic sigmoid ``1 / (1 + exp(-x))`` for confidence calculation.

    Evaluated as ``exp(-logaddexp(0, -x))``, which is mathematically identical
    but never overflows: the naive form computes ``exp(-x)`` directly and emits
    an overflow RuntimeWarning for x below roughly -709. Works element-wise on
    arrays as well as on scalars.

    Args:
        x: Input value (or array of values).

    Returns:
        Sigmoid of ``x`` in [0, 1].
    """
    # logaddexp(0, -x) == log(1 + exp(-x)), computed without overflow.
    return np.exp(-np.logaddexp(0.0, -x))


def match_local_features(
    kpts_new: np.ndarray,
    desc_new: np.ndarray,
    kpts_ref: np.ndarray,
    desc_ref: np.ndarray,
    lowe_ratio: float = 0.75,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Pair up features of two frames via brute-force Hamming kNN plus Lowe's ratio test.

    Args:
        kpts_new: New frame keypoints (Nx5 float32); columns 0-1 are used as (x, y)
        desc_new: New frame descriptors (Nx32 uint8)
        kpts_ref: Reference frame keypoints (Mx5 float32); columns 0-1 are used as (x, y)
        desc_ref: Reference frame descriptors (Mx32 uint8)
        lowe_ratio: A best match is kept only if its distance is strictly below
            ``lowe_ratio`` times the second-best distance

    Returns:
        ``(pts_new, pts_ref)``, two Kx2 float32 arrays of corresponding (x, y)
        points. Both are empty (0x2) when either descriptor set is missing or
        empty, when OpenCV raises, or when no match survives the ratio test.
    """

    def _no_matches() -> Tuple[np.ndarray, np.ndarray]:
        return np.zeros((0, 2), dtype=np.float32), np.zeros((0, 2), dtype=np.float32)

    missing_new = desc_new is None or len(desc_new) == 0
    if missing_new or desc_ref is None or len(desc_ref) == 0:
        return _no_matches()

    matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
    try:
        knn_results = matcher.knnMatch(desc_new, desc_ref, k=2)
    except cv2.error:
        return _no_matches()

    accepted = []
    for neighbours in knn_results:
        n_found = len(neighbours)
        if n_found == 1:
            # No runner-up to compare against (e.g. a single reference
            # descriptor), so the lone neighbour is accepted as-is.
            accepted.append(neighbours[0])
        elif n_found == 2:
            nearest, runner_up = neighbours
            if nearest.distance < lowe_ratio * runner_up.distance:
                accepted.append(nearest)

    if not accepted:
        return _no_matches()

    new_rows = [match.queryIdx for match in accepted]
    ref_rows = [match.trainIdx for match in accepted]
    matched_new = kpts_new[new_rows, :2].astype(np.float32)
    matched_ref = kpts_ref[ref_rows, :2].astype(np.float32)
    return matched_new, matched_ref


def compute_homography_ransac(
    pts_new: np.ndarray,
    pts_ref: np.ndarray,
    ransac_thresh: float = 3.0,
    min_inliers: int = 4,
) -> Tuple[Optional[np.ndarray], np.ndarray, float]:
    """
    Fit a 3x3 homography mapping ``pts_new`` onto ``pts_ref`` with RANSAC.

    Args:
        pts_new: Nx2 array of points in new frame
        pts_ref: Nx2 array of corresponding points in reference frame
        ransac_thresh: RANSAC reprojection threshold in pixels
        min_inliers: Minimum number of inliers required for success

    Returns:
        ``(H, inlier_mask, median_error)``:
        - H: 3x3 homography, or None on failure
        - inlier_mask: Boolean mask over the input correspondences. It is
          all-False when RANSAC did not run or produced no model, and it is the
          RANSAC mask when the model had too few inliers.
        - median_error: Median Euclidean distance between projected inliers
          and their reference points, or ``np.inf`` on failure
    """

    def _failure() -> Tuple[None, np.ndarray, float]:
        return None, np.zeros(len(pts_new), dtype=bool), np.inf

    # A homography needs at least four correspondences.
    if min(len(pts_new), len(pts_ref)) < 4:
        return _failure()

    try:
        H, raw_mask = cv2.findHomography(
            pts_new, pts_ref, cv2.RANSAC, ransacReprojThreshold=ransac_thresh
        )
    except cv2.error:
        return _failure()

    if H is None:
        return _failure()

    is_inlier = raw_mask.ravel().astype(bool)
    inlier_count = np.sum(is_inlier)
    if inlier_count < min_inliers:
        return None, is_inlier, np.inf

    src = pts_new[is_inlier]
    dst = pts_ref[is_inlier]

    # Lift to homogeneous coordinates, project with H, then divide by w.
    src_homog = np.hstack([src, np.ones((inlier_count, 1))])
    projected = (H @ src_homog.T).T
    projected_xy = projected[:, :2] / projected[:, 2:3]

    residuals = np.linalg.norm(projected_xy - dst, axis=1)
    return H, is_inlier, np.median(residuals)


def compute_affine_ransac(
    pts_new: np.ndarray,
    pts_ref: np.ndarray,
    ransac_thresh: float = 3.0,
    min_inliers: int = 3,
) -> Tuple[Optional[np.ndarray], np.ndarray, float]:
    """
    Fit a 2x3 affine transform mapping ``pts_new`` onto ``pts_ref`` with RANSAC.

    Args:
        pts_new: Nx2 array of points in new frame
        pts_ref: Nx2 array of corresponding points in reference frame
        ransac_thresh: RANSAC reprojection threshold in pixels
        min_inliers: Minimum number of inliers required for success

    Returns:
        ``(A, inlier_mask, median_error)``:
        - A: 2x3 affine matrix, or None on failure
        - inlier_mask: Boolean mask over the input correspondences. It is
          all-False when RANSAC did not run or produced no model, and it is the
          RANSAC mask when the model had too few inliers.
        - median_error: Median Euclidean distance between transformed inliers
          and their reference points, or ``np.inf`` on failure
    """

    def _failure() -> Tuple[None, np.ndarray, float]:
        return None, np.zeros(len(pts_new), dtype=bool), np.inf

    # An affine transform needs at least three correspondences.
    if min(len(pts_new), len(pts_ref)) < 3:
        return _failure()

    try:
        A, raw_mask = cv2.estimateAffine2D(
            pts_new, pts_ref, method=cv2.RANSAC, ransacReprojThreshold=ransac_thresh
        )
    except cv2.error:
        return _failure()

    if A is None:
        return _failure()

    is_inlier = raw_mask.ravel().astype(bool)
    inlier_count = np.sum(is_inlier)
    if inlier_count < min_inliers:
        return None, is_inlier, np.inf

    src = pts_new[is_inlier]
    dst = pts_ref[is_inlier]

    # Append a column of ones so the 2x3 matrix applies linear part + translation.
    src_homog = np.hstack([src, np.ones((inlier_count, 1))])
    mapped = (A @ src_homog.T).T

    residuals = np.linalg.norm(mapped - dst, axis=1)
    return A, is_inlier, np.median(residuals)


def refine_single_candidate(
    kpts_new: np.ndarray,
    desc_new: np.ndarray,
    kpts_ref: np.ndarray,
    desc_ref: np.ndarray,
    lowe_ratio: float = 0.75,
    ransac_thresh: float = 3.0,
    min_inliers: int = 20,
    conf_a: float = 0.06,
    conf_b: float = 0.8,
) -> Tuple[float, int, float, Literal["H", "A", "NONE"]]:
    """
    Score one (new, reference) frame pair by matching features and fitting geometry.

    Features are matched with ``match_local_features``. A homography is tried
    first; only if it fails is an affine transform attempted. Confidence is
    ``sigmoid(conf_a * inliers - conf_b * median_error)`` for whichever model
    succeeded.

    Args:
        kpts_new: New frame keypoints (Nx5)
        desc_new: New frame descriptors (Nx32)
        kpts_ref: Reference frame keypoints (Mx5)
        desc_ref: Reference frame descriptors (Mx32)
        lowe_ratio: Lowe ratio test threshold
        ransac_thresh: RANSAC reprojection threshold
        min_inliers: Minimum matches/inliers required
        conf_a: Confidence sigmoid parameter (inlier weight)
        conf_b: Confidence sigmoid parameter (error weight)

    Returns:
        ``(confidence, n_inliers, reproj_error, model)`` where model is "H" or
        "A"; ``(0.0, 0, inf, "NONE")`` when matching or both fits fail.
    """
    pts_new, pts_ref = match_local_features(
        kpts_new, desc_new, kpts_ref, desc_ref, lowe_ratio=lowe_ratio
    )

    # Not enough raw matches to ever reach min_inliers: skip RANSAC entirely.
    if len(pts_new) >= min_inliers:
        # Ordered by preference; the first estimator that yields a model wins.
        estimators = (
            (compute_homography_ransac, "H"),
            (compute_affine_ransac, "A"),
        )
        for estimate, label in estimators:
            model, mask, median_err = estimate(
                pts_new, pts_ref, ransac_thresh=ransac_thresh, min_inliers=min_inliers
            )
            if model is None:
                continue
            inlier_total = np.sum(mask)
            score = sigmoid(conf_a * inlier_total - conf_b * median_err)
            return score, inlier_total, median_err, label

    return 0.0, 0, np.inf, "NONE"


def _window_ref_indices(
    t_ref_candidates: pd.Series,
    ref_t_to_idx: dict,
    n_ref: int,
    refine_window: int,
) -> List[int]:
    """Collect unique ref indices within ±refine_window of each known candidate, in first-seen order."""
    ordered: List[int] = []
    seen = set()
    for t_ref_center in t_ref_candidates:
        center = ref_t_to_idx.get(t_ref_center)
        if center is None:
            # Candidate timestamp does not appear in ref_timestamps.
            continue
        start = max(0, center - refine_window)
        end = min(n_ref, center + refine_window + 1)
        for ref_idx in range(start, end):
            if ref_idx not in seen:
                seen.add(ref_idx)
                ordered.append(ref_idx)
    return ordered


def refine_pairs_locally(
    pairs_df: pd.DataFrame,
    new_timestamps: np.ndarray,
    new_kpts_list: List[np.ndarray],
    new_desc_list: List[np.ndarray],
    ref_timestamps: np.ndarray,
    ref_kpts_list: List[np.ndarray],
    ref_desc_list: List[np.ndarray],
    fps: float = 3.0,
    refine_window: int = 1,
    lowe_ratio: float = 0.75,
    ransac_thresh: float = 3.0,
    min_inliers: int = 20,
    conf_a: float = 0.06,
    conf_b: float = 0.8,
) -> pd.DataFrame:
    """
    Refine coarse pairs using local RANSAC geometry estimation.

    For each t_new, every candidate t_ref listed for it in pairs_df is expanded
    to the reference keyframes within ±refine_window (by position in
    ref_timestamps). Each distinct reference keyframe is scored once with
    ``refine_single_candidate`` (homography first, affine fallback). The
    highest-confidence result is kept; on ties, the earliest-evaluated keyframe wins.

    Args:
        pairs_df: Coarse pairs from retrieval; the ``t_new`` and ``t_ref``
            columns are read
        new_timestamps: New video timestamps
        new_kpts_list: New video keypoints, indexed like new_timestamps
        new_desc_list: New video descriptors, indexed like new_timestamps
        ref_timestamps: Reference video timestamps
        ref_kpts_list: Reference keypoints, indexed like ref_timestamps
        ref_desc_list: Reference descriptors, indexed like ref_timestamps
        fps: Keyframe extraction rate. NOTE(review): currently unused; the
            window is measured in keyframe indices, not seconds.
        refine_window: Window size in keyframes (±N around each candidate)
        lowe_ratio: Lowe ratio test threshold
        ransac_thresh: RANSAC reprojection threshold in pixels
        min_inliers: Minimum inliers required
        conf_a: Confidence sigmoid parameter (inlier weight)
        conf_b: Confidence sigmoid parameter (error weight)

    Returns:
        DataFrame with columns: t_new, t_ref, confidence, inliers, reproj_error, model,
        sorted by t_new (monotonic). A t_new for which no candidate succeeds is omitted.

    Raises:
        ValueError: If pairs_df is empty.
        RuntimeError: If no t_new produced a successful refinement.
    """
    if pairs_df.empty:
        raise ValueError("Input pairs_df is empty")

    # NOTE(review): these are exact float lookups. Timestamps in pairs_df must be
    # bit-identical to new_timestamps / ref_timestamps (e.g. not rounded by a CSV
    # round-trip), otherwise the frame or candidate is silently skipped.
    new_t_to_idx = {t: i for i, t in enumerate(new_timestamps)}
    ref_t_to_idx = {t: i for i, t in enumerate(ref_timestamps)}
    n_ref = len(ref_timestamps)

    refined_pairs = []

    for t_new, group in pairs_df.groupby("t_new", sort=True):
        new_idx = new_t_to_idx.get(t_new)
        if new_idx is None:
            continue

        kpts_new = new_kpts_list[new_idx]
        desc_new = new_desc_list[new_idx]
        if desc_new is None or len(desc_new) == 0:
            continue

        best_result = None
        best_confidence = 0.0

        # Adjacent candidates have overlapping windows. Visiting each reference
        # keyframe once avoids re-running matching + RANSAC on an identical pair.
        # This assumes re-scoring the same pair would not improve on the first
        # score, which holds if refine_single_candidate is deterministic.
        candidate_indices = _window_ref_indices(
            group["t_ref"], ref_t_to_idx, n_ref, refine_window
        )
        for ref_idx in candidate_indices:
            kpts_ref = ref_kpts_list[ref_idx]
            desc_ref = ref_desc_list[ref_idx]
            if desc_ref is None or len(desc_ref) == 0:
                continue

            confidence, n_inliers, reproj_error, model = refine_single_candidate(
                kpts_new,
                desc_new,
                kpts_ref,
                desc_ref,
                lowe_ratio=lowe_ratio,
                ransac_thresh=ransac_thresh,
                min_inliers=min_inliers,
                conf_a=conf_a,
                conf_b=conf_b,
            )

            # Strict '>' keeps the earliest keyframe on ties and ignores failed
            # refinements, which report confidence 0.0.
            if confidence > best_confidence:
                best_confidence = confidence
                best_result = {
                    "t_new": t_new,
                    "t_ref": ref_timestamps[ref_idx],
                    "confidence": confidence,
                    "inliers": n_inliers,
                    "reproj_error": reproj_error,
                    "model": model,
                }

        if best_result is not None and best_result["model"] != "NONE":
            refined_pairs.append(best_result)

    if not refined_pairs:
        raise RuntimeError("No valid refined pairs found (all candidates failed RANSAC)")

    df = pd.DataFrame(refined_pairs)
    # Sort by t_new for monotonicity.
    return df.sort_values("t_new").reset_index(drop=True)


def save_refined_pairs(df: pd.DataFrame, output_path: str) -> None:
    """
    Write the refined-pairs table to ``output_path`` as CSV.

    The DataFrame index is not written. Float columns are formatted with six
    decimal places ("%.6f"), so timestamps are rounded to microsecond
    resolution in the file.

    Args:
        df: DataFrame with columns t_new, t_ref, confidence, inliers, reproj_error, model
        output_path: Path to output CSV file
    """
    csv_options = {"index": False, "float_format": "%.6f"}
    df.to_csv(output_path, **csv_options)
