"""Feature descriptor extraction using ORB."""

import json
from datetime import datetime, timezone
from typing import Dict, List, Tuple

import cv2
import numpy as np


def _validate_frame(index: int, frame: np.ndarray) -> None:
    """Raise ValueError unless ``frame`` is a non-empty uint8 image ORB can consume."""
    if frame is None or frame.size == 0:
        raise ValueError(f"Frame {index} is invalid (None or empty)")
    # ORB only operates on 8-bit images; anything else would fail deep inside
    # OpenCV with a cv2.error instead of the documented ValueError.
    if frame.dtype != np.uint8:
        raise ValueError(f"Frame {index} has dtype {frame.dtype}; expected uint8")
    if frame.ndim == 2:
        return
    if frame.ndim == 3 and frame.shape[2] in (1, 3, 4):
        return
    raise ValueError(
        f"Frame {index} has unsupported shape {frame.shape}; "
        "expected HxW, HxWx1, HxWx3 or HxWx4"
    )


def _to_gray(frame: np.ndarray) -> np.ndarray:
    """Return a single-channel view/copy of an already-validated frame."""
    if frame.ndim == 2:
        return frame
    channels = frame.shape[2]
    if channels == 1:
        # COLOR_BGR2GRAY rejects 1-channel input, so just drop the channel axis.
        return frame.reshape(frame.shape[:2])
    code = cv2.COLOR_BGR2GRAY if channels == 3 else cv2.COLOR_BGRA2GRAY
    return cv2.cvtColor(frame, code)


def compute_orb_descriptors(
    frames: List[np.ndarray],
    n_features: int = 1500,
    use_clahe: bool = True,
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """
    Compute ORB descriptors for a list of frames.

    All frames are validated before any detection work starts. An invalid frame
    anywhere in the list therefore fails fast, without partial processing.

    Args:
        frames: List of uint8 frames. Accepted layouts are HxW (grayscale),
            HxWx1 (grayscale with a channel axis), HxWx3 (BGR) and HxWx4 (BGRA).
        n_features: Maximum number of features to detect (default: 1500)
        use_clahe: Whether to apply CLAHE preprocessing (clip limit 3.0,
            8x8 tiles) before detection (default: True)

    Returns:
        Tuple of (kpts_list, desc_list), with one entry per input frame:
        - kpts_list: List of Nx5 float32 arrays [x, y, size, angle, response]
          Empty 0x5 array if no keypoints found
        - desc_list: List of Nx32 uint8 descriptor arrays
          Empty 0x32 array if no keypoints found

    Raises:
        ValueError: If frames list is empty or contains invalid frames
            (None, empty, non-uint8, or an unsupported shape)
    """
    if not frames:
        raise ValueError("Frames list is empty")

    for i, frame in enumerate(frames):
        _validate_frame(i, frame)

    # Everything except nfeatures mirrors OpenCV's ORB defaults. The values are
    # spelled out so the index is reproducible if those defaults ever change.
    orb = cv2.ORB_create(
        nfeatures=n_features,
        scaleFactor=1.2,
        nlevels=8,
        edgeThreshold=31,
        firstLevel=0,
        WTA_K=2,
        scoreType=cv2.ORB_HARRIS_SCORE,
        patchSize=31,
        fastThreshold=20,
    )

    # One CLAHE instance is reused for all frames.
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)) if use_clahe else None

    kpts_list: List[np.ndarray] = []
    desc_list: List[np.ndarray] = []

    for frame in frames:
        gray = _to_gray(frame)
        if clahe is not None:
            gray = clahe.apply(gray)

        keypoints, descriptors = orb.detectAndCompute(gray, None)

        # detectAndCompute returns descriptors=None when nothing is found;
        # normalise to correctly-shaped empty arrays so callers never see None.
        if not keypoints or descriptors is None:
            kpts_list.append(np.zeros((0, 5), dtype=np.float32))
            desc_list.append(np.zeros((0, 32), dtype=np.uint8))
            continue

        kpts_list.append(
            np.array(
                [
                    [kp.pt[0], kp.pt[1], kp.size, kp.angle, kp.response]
                    for kp in keypoints
                ],
                dtype=np.float32,
            )
        )
        desc_list.append(descriptors)

    return kpts_list, desc_list


def _pack_per_frame(arrays: List[np.ndarray]) -> np.ndarray:
    """Pack per-frame arrays into a 1-D object array with exactly one element per frame."""
    # np.array(arrays, dtype=object) would silently build an (N, K, C) array of
    # boxed scalars whenever every frame has the same row count K. Assigning
    # element by element always keeps each frame's ndarray intact.
    packed = np.empty(len(arrays), dtype=object)
    for i, arr in enumerate(arrays):
        packed[i] = arr
    return packed


def save_reference_index(
    out_path: str,
    t_ref: np.ndarray,
    kpts_list: List[np.ndarray],
    desc_list: List[np.ndarray],
    meta: Dict,
) -> None:
    """
    Save reference index with timestamps, keypoints, descriptors, and metadata.

    The file is written with ``np.savez_compressed`` and has these keys:
    - ``t_ref``: the timestamps as given
    - ``kpts`` / ``desc``: 1-D object arrays holding one per-frame array each.
      Reading them back therefore requires ``np.load(..., allow_pickle=True)``.
    - ``meta``: the merged metadata as a JSON string

    Args:
        out_path: Path to output .npz file (NumPy appends ".npz" if missing)
        t_ref: Array of timestamps (seconds) for each frame
        kpts_list: List of keypoint arrays (Nx5 float32)
        desc_list: List of descriptor arrays (Nx32 uint8)
        meta: Metadata dictionary (will be JSON-serialized). Its keys override
            the automatically added "version", "created_utc" and
            "opencv_version" fields.

    Raises:
        ValueError: If the three inputs have different frame counts, or if any
            frame's keypoint count differs from its descriptor count.
    """
    n_frames = len(t_ref)
    if n_frames != len(kpts_list) or n_frames != len(desc_list):
        raise ValueError(
            f"Inconsistent lengths: t_ref={len(t_ref)}, "
            f"kpts={len(kpts_list)}, desc={len(desc_list)}"
        )

    # Each keypoint row must pair with exactly one descriptor row.
    for i, (kpts, desc) in enumerate(zip(kpts_list, desc_list)):
        if len(kpts) != len(desc):
            raise ValueError(
                f"Frame {i}: {len(kpts)} keypoints but {len(desc)} descriptors"
            )

    # Add required metadata fields if not present
    required_meta = {
        "version": "1.0",
        "created_utc": datetime.now(timezone.utc).isoformat(),
        "opencv_version": cv2.__version__,
    }

    # Merge with provided meta (user meta takes precedence)
    full_meta = {**required_meta, **meta}

    # Stored as a JSON string so the metadata survives without pickling.
    meta_json = json.dumps(full_meta)

    np.savez_compressed(
        out_path,
        t_ref=t_ref,
        kpts=_pack_per_frame(kpts_list),
        desc=_pack_per_frame(desc_list),
        meta=meta_json,
    )
