"""
extractEmbedding.py — Modular embedding extraction for ViT-based models
-----------------------------------------------------------------------
Provides a single entry point `extractEmbedding` focused on computing patch
or token-level embeddings from H&E inputs (WSI via OpenSlide or regular
images via PIL). The module is model-agnostic via a simple adapter
interface and supports TIMM ViT models out of the box.

This module reuses tiling & batching utilities from image_patcher.py.
"""
from __future__ import annotations

from typing import Optional, Tuple, List, Dict, Any, Union

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from pathlib import Path

# Optional heavy deps for creating models
_HAS_TIMM = True
try:
    import timm  # type: ignore
except Exception:
    _HAS_TIMM = False

# Local I/O helpers
from .utils import open_he_input as _open_he_input
# Reuse tiling + inference helpers (expects an object with .preprocess, .forward_tokens, .forward_patch, and .to(device))
from .image_patcher import (
    infer_embeddings_wsi,
    infer_embeddings_wsi_stream,
    infer_embeddings_image,
    infer_embeddings_image_stream,
)


def _get_device(name: str) -> torch.device:
    """
    Resolve a device name to a ``torch.device``.

    An empty or ``None`` name is treated as ``"cpu"``. The name is matched
    case-insensitively. ``"auto"`` picks CUDA when available, then MPS when
    the backend exists and reports available, and otherwise the CPU. Any
    other value is handed to ``torch.device`` as-is (lower-cased).
    """
    requested = name.lower() if name else "cpu"

    # Explicit device strings need no probing.
    if requested != "auto":
        return torch.device(requested)

    if torch.cuda.is_available():
        chosen = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return torch.device(chosen)


class BaseViTAdapter(nn.Module):
    """
    Minimal adapter wrapping a TIMM ViT-like model.

    Exposes:
    - preprocess: torchvision transform pipeline (resize -> tensor -> normalize)
    - forward_tokens: returns (B, T, D) tokens per patch
    - forward_patch: returns (B, D) pooled patch embedding
    """

    def __init__(self, model: nn.Module, input_size: Tuple[int, int], mean: List[float], std: List[float]):
        """
        Args:
            model: The wrapped network; stored as ``self.model``.
            input_size: ``(height, width)`` every input is resized to.
            mean: Per-channel normalization mean.
            std: Per-channel normalization standard deviation.
        """
        super().__init__()
        self.model = model
        H, W = int(input_size[0]), int(input_size[1])
        self.preprocess = T.Compose([
            T.Resize((H, W)),
            T.ToTensor(),
            T.Normalize(mean=list(map(float, mean)), std=list(map(float, std))),
        ])

    def _pos_embed_resize(self, x: torch.Tensor, pe: torch.Tensor, has_cls: bool) -> torch.Tensor:
        """
        Add positional embeddings ``pe`` to tokens ``x`` on a best-effort basis.

        When the token counts match, ``pe`` is added directly. Otherwise the
        grid part of ``pe`` is bicubically interpolated to the grid size of
        ``x``. If either grid is not a square number of tokens, or the tensor
        operations fail, ``x`` is returned unchanged.

        Args:
            x: Tokens, indexed as ``(batch, tokens, dim)``.
            pe: Positional embeddings, indexed as ``(1, tokens, dim)``.
            has_cls: Whether position 0 of both tensors is a class token that
                must be kept out of the spatial interpolation.
        """
        if x.shape[1] == pe.shape[1]:
            return x + pe

        n_prefix = 1 if has_cls else 0
        n_target = x.shape[1] - n_prefix
        n_source = pe.shape[1] - n_prefix
        if n_target <= 0 or n_source <= 0:
            return x
        side = int(np.sqrt(n_target))
        src_side = int(np.sqrt(n_source))
        # Only square grids can be interpolated; anything else is left untouched
        # (explicit check instead of relying on a swallowed shape error).
        if side * side != n_target or src_side * src_side != n_source:
            return x

        try:
            import torch.nn.functional as F

            grid_pe = pe[:, n_prefix:]
            # (1, N, D) -> (1, D, h, w) so the spatial dims can be interpolated.
            grid_pe = grid_pe.reshape(1, src_side, src_side, -1).permute(0, 3, 1, 2)
            grid_pe = F.interpolate(grid_pe, size=(side, side), mode="bicubic", align_corners=False)
            grid_pe = grid_pe.permute(0, 2, 3, 1).reshape(1, side * side, -1)
            if has_cls:
                grid_pe = torch.cat((pe[:, :1], grid_pe), dim=1)
            return x + grid_pe
        except (RuntimeError, ValueError):
            # Best effort: shape/dtype incompatibilities leave the tokens as they are.
            return x

    @staticmethod
    def _pool_tensor(out: Any) -> Optional[torch.Tensor]:
        """
        Reduce a model output tensor to ``(B, D)``, or return ``None``.

        2-D tensors are returned as-is, 3-D tensors yield the token at
        position 0, and 4-D tensors are averaged over their last two dims.
        """
        if not isinstance(out, torch.Tensor):
            return None
        if out.ndim == 2:
            return out
        if out.ndim == 3:
            return out[:, 0, :]  # token at position 0 (CLS for many ViT models)
        if out.ndim == 4:
            return out.mean(dim=[2, 3])
        return None

    @torch.inference_mode()
    def forward_tokens(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Return per-token features for ``batch``.

        Uses ``model.forward_features`` when present (taking the ``"x"`` entry
        of a dict result), else ``model.forward_tokens``, else a plain forward
        call. If the specialised method raises, a plain forward call is used.

        Returns:
            A 3-D tensor as produced by the model (plus ``pos_embed`` when the
            model has one), or a 4-D ``(B, D, H, W)`` output flattened to
            ``(B, H*W, D)``.

        Raises:
            RuntimeError: If the model output is not a 3-D or 4-D tensor.
        """
        m = self.model
        has_features = hasattr(m, "forward_features")
        has_tokens = hasattr(m, "forward_tokens")
        if not (has_features or has_tokens):
            x = m(batch)
        else:
            try:
                if has_features:
                    out = m.forward_features(batch)
                    x = out["x"] if isinstance(out, dict) and "x" in out else out
                else:
                    x = m.forward_tokens(batch)
            except Exception:
                # Specialised entry point failed; fall back to the plain forward.
                x = m(batch)

        if isinstance(x, torch.Tensor):
            if x.ndim == 3:
                # x includes CLS at position 0 for many ViT models.
                # NOTE(review): timm ViTs appear to add pos_embed inside
                # forward_features already, so adding it again to the output
                # tokens may be unintended — confirm before relying on it.
                # Behavior kept as-is to avoid changing existing embeddings.
                pe = getattr(m, "pos_embed", None)
                if isinstance(pe, torch.Tensor):
                    x = self._pos_embed_resize(x, pe, has_cls=True)
                return x
            if x.ndim == 4:
                # conv-stem style: (B, D, H, W) -> (B, H*W, D)
                B, D, H, W = x.shape
                return x.permute(0, 2, 3, 1).reshape(B, H * W, D)
        raise RuntimeError("Unsupported tensor shape from model for tokens.")

    @torch.inference_mode()
    def forward_patch(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Return one pooled embedding per input in ``batch``.

        ``model.forward_features`` is tried first. A dict result is read from
        its ``"pooled"`` entry, else from its ``"x"`` entry; tensor results are
        reduced by ``_pool_tensor``. If that call raises or yields nothing
        usable, the plain forward output is reduced the same way.

        Raises:
            RuntimeError: If no pooled embedding can be derived.
        """
        m = self.model
        pooled: Optional[torch.Tensor] = None
        features_failed = False
        try:
            feats = m.forward_features(batch)
            if isinstance(feats, dict):
                if isinstance(feats.get("pooled"), torch.Tensor):
                    pooled = feats["pooled"]
                else:
                    pooled = self._pool_tensor(feats.get("x"))
            else:
                pooled = self._pool_tensor(feats)
        except Exception:
            features_failed = True
        if pooled is not None:
            return pooled

        if features_failed:
            # forward_features is missing or broken: errors from the plain
            # forward call propagate unchanged.
            pooled = self._pool_tensor(m(batch))
        else:
            # forward_features ran but produced nothing usable. Try the plain
            # forward too, and report any failure as the same RuntimeError
            # callers already get for this situation.
            try:
                pooled = self._pool_tensor(m(batch))
            except Exception as err:
                raise RuntimeError(
                    "Could not derive pooled patch embedding from model output."
                ) from err
        if pooled is None:
            raise RuntimeError("Could not derive pooled patch embedding from model output.")
        return pooled


def make_vit_adapter(
    model_name: str,
    input_size: Tuple[int, int],
    mean: List[float],
    std: List[float],
    timm_kwargs: Optional[Dict[str, Any]] = None,
) -> BaseViTAdapter:
    """
    Create a timm model by name and wrap it in a ``BaseViTAdapter``.

    Args:
        model_name: Name passed to ``timm.create_model``.
        input_size: ``(height, width)`` forwarded to the adapter.
        mean: Normalization mean forwarded to the adapter.
        std: Normalization standard deviation forwarded to the adapter.
        timm_kwargs: Extra keyword arguments for ``timm.create_model``;
            ``None`` means no extra arguments.

    Returns:
        The adapter wrapping the freshly created model.

    Raises:
        RuntimeError: If timm could not be imported.
    """
    if not _HAS_TIMM:
        raise RuntimeError("timm is required to create ViT models.")
    model = timm.create_model(model_name, **(timm_kwargs or {}))
    # No structural check (patch_embed / blocks / norm) is enforced here: those
    # attributes are not required for all variants, and the adapter accepts
    # both token-shaped and feature-map-shaped model outputs.
    return BaseViTAdapter(model, input_size=input_size, mean=mean, std=std)


def extractEmbedding(
    input_obj: Union[str, "Image.Image", Any],
    model_name: str,
    input_size: Tuple[int, int],
    mean: List[float],
    std: List[float],
    patch_size: int,
    stride: Optional[int] = None,
    output_type: Union[str, List[str]] = "patch",  # "patch" | "token" | ["patch","token"]
    batch_size: int = 32,
    device: str = "auto",
    stream_write: bool = False,
    target_mpp: Optional[float] = None,
    source_mpp: Optional[float] = None,
    tissue_polys: Optional[List[Any]] = None,
    filter_to_tissue: bool = False,
    timm_kwargs: Optional[Dict[str, Any]] = None,
    slide_id: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Compute patch and/or token embeddings for a WSI or a regular image.

    The input is opened with ``open_he_input``; tiling and batched inference
    are delegated to the ``image_patcher`` helpers (streaming variants when
    ``stream_write`` is true), once per requested output type.

    Args:
        input_obj: Path or already-opened input accepted by ``open_he_input``.
        model_name: timm model name.
        input_size: ``(height, width)`` the model preprocess resizes to.
        mean: Normalization mean.
        std: Normalization standard deviation.
        patch_size: Tile size handed to the helpers.
        stride: Tile stride; falls back to ``patch_size`` when ``None`` or 0.
        output_type: ``"patch"``, ``"token"``, or a list/tuple of those
            (case-insensitive).
        batch_size: Batch size handed to the helpers.
        device: Device name; ``"auto"`` selects CUDA, then MPS, then CPU.
        stream_write: Use the ``*_stream`` helpers instead of the in-memory ones.
        target_mpp: Stored in the helper cfg as ``"target_mpp"``.
        source_mpp: Overrides the mpp detected from a WSI when given.
        tissue_polys: Passed to the streaming WSI helper, or stored in the cfg
            as ``"_tissue_polys_image"`` for regular images.
        filter_to_tissue: Stored in the helper cfg as ``"filter_to_tissue"``.
        timm_kwargs: Extra keyword arguments for ``timm.create_model``.
        slide_id: Explicit slide id; otherwise derived from the file name.

    Returns:
        A dict with ``"embeddings_patch"`` and/or ``"embeddings_token"`` (plus
        ``"token_dim"`` for tokens), ``"coords"``, ``"boxes"`` (WSI only),
        ``"frame"``, ``"meta"``, and — when exactly one output type was
        requested — the backwards-compatible ``"embeddings"`` key.

    Raises:
        ValueError: If an unsupported output type is requested.
        AssertionError: If the input is neither a WSI nor an image.
    """
    # Normalize and validate the requested output types first so that bad
    # input fails before the (expensive) model is created.
    if isinstance(output_type, (list, tuple)):
        req = [str(t).lower() for t in output_type]
    else:
        req = [str(output_type).lower()]
    allowed = {"patch", "token"}
    for t in req:
        if t not in allowed:
            raise ValueError(f"Unsupported output_type '{t}'. Allowed: {allowed}")

    # Build adapter
    vit = make_vit_adapter(model_name, input_size=input_size, mean=mean, std=std, timm_kwargs=timm_kwargs)
    # Modules are constructed in training mode; embeddings must be computed
    # with dropout / norm layers in inference behavior.
    vit.eval()
    dev = _get_device(device)

    step = stride or patch_size

    # Prepare cfg expected by image_patcher helpers
    base_cfg: Dict[str, Any] = {
        "target_mpp": target_mpp,
        "filter_to_tissue": bool(filter_to_tissue),
    }

    # Interpret input
    he_input = _open_he_input(input_obj)
    is_wsi = he_input.kind == "wsi"
    slide_handle = he_input.slide_handle if is_wsi else None
    image_obj = he_input.image_obj if he_input.kind == "image" else None
    detected_mpp = he_input.source_mpp if is_wsi else None
    source_mpp_used = source_mpp if source_mpp is not None else detected_mpp

    # Derive slide_id if not provided: input path first, then the slide's filename.
    auto_slide_id: Optional[str] = None
    if isinstance(input_obj, str):
        try:
            auto_slide_id = Path(input_obj).stem
        except Exception:
            auto_slide_id = None
    if auto_slide_id is None and slide_handle is not None:
        try:
            fname = getattr(slide_handle, "filename", None)
            if isinstance(fname, str):
                auto_slide_id = Path(fname).stem
        except Exception:
            auto_slide_id = None
    slide_id_out = slide_id if slide_id is not None else auto_slide_id

    results: Dict[str, Any] = {}

    def _record(kind: str, embeds: Any, token_dim: Any) -> None:
        # Store one run's embeddings under the key matching its output type.
        if kind == "patch":
            results["embeddings_patch"] = embeds
        else:
            results["embeddings_token"] = embeds
            results["token_dim"] = token_dim

    # WSI path
    if slide_handle is not None:
        base_cfg["_source_mpp_for_run"] = source_mpp_used
        if target_mpp is not None and source_mpp_used is None:
            print("[mpp] target_mpp provided but source_mpp missing; defaulting to no rescale.")
        # Always operate from level 0 with downsample=1.0; image_patcher computes target-mpp frame
        coords_ref, boxes_ref = None, None
        for t in req:
            cfg = dict(base_cfg)
            cfg["output_type"] = t
            if stream_write:
                coords, embeds, boxes_T, token_dim = infer_embeddings_wsi_stream(
                    slide_handle, 0, 1.0, patch_size, step, vit, cfg, batch_size, device=dev,
                    tissue_polys=tissue_polys,
                )
            else:
                # NOTE(review): unlike the streaming call, tissue_polys is not
                # forwarded here — confirm whether infer_embeddings_wsi accepts it.
                coords, embeds, boxes_T, token_dim = infer_embeddings_wsi(
                    slide_handle, 0, 1.0, patch_size, step, vit, cfg, batch_size, device=dev,
                )
            _record(t, embeds, token_dim)
            # Coordinates/boxes of the first run serve as the shared reference.
            if coords_ref is None:
                coords_ref, boxes_ref = coords, boxes_T
        # Common outputs
        results.update({
            "coords": coords_ref,
            "boxes": boxes_ref,
            "frame": "target_mpp",
            "meta": {
                "model_name": model_name,
                "patch_size": patch_size,
                "stride": step,
                "target_mpp": target_mpp,
                "source_mpp": source_mpp_used,
                "slide_id": slide_id_out,
            },
        })
    else:
        # Regular image path
        if image_obj is None:
            # Raised explicitly (rather than via `assert`) so the check also
            # runs under `python -O`, with the same exception type as before.
            raise AssertionError("Input must be a WSI or a PIL Image.")
        if tissue_polys is not None:
            base_cfg["_tissue_polys_image"] = tissue_polys
        coords_ref = None
        for t in req:
            cfg = dict(base_cfg)
            cfg["output_type"] = t
            if stream_write:
                coords, embeds, token_dim = infer_embeddings_image_stream(
                    image_obj, patch_size, step, vit, cfg, batch_size, device=dev,
                )
            else:
                coords, embeds, token_dim = infer_embeddings_image(
                    image_obj, patch_size, step, vit, cfg, batch_size, device=dev,
                )
            _record(t, embeds, token_dim)
            if coords_ref is None:
                coords_ref = coords
        results.update({
            "coords": coords_ref,
            "frame": "image_pixels",
            "meta": {
                "model_name": model_name,
                "patch_size": patch_size,
                "stride": step,
                "target_mpp": None,
                "source_mpp": None,
                "slide_id": slide_id_out,
            },
        })

    # Backwards-compatible single-output key
    if len(req) == 1:
        if req[0] == "patch":
            results["embeddings"] = results.get("embeddings_patch")
        else:
            results["embeddings"] = results.get("embeddings_token")
    return results


def embedImage(
    inputObj: Union[str, "Image.Image", Any],
    modelName: str,
    inputSize: Tuple[int, int],
    mean: List[float],
    std: List[float],
    patchSize: int,
    stride: Optional[int] = None,
    outputType: Union[str, List[str]] = "patch",
    batchSize: int = 32,
    device: str = "auto",
    streamWrite: bool = False,
    targetMpp: Optional[float] = None,
    sourceMpp: Optional[float] = None,
    tissuePolys: Optional[List[Any]] = None,
    filterToTissue: bool = False,
    timmKwargs: Optional[Dict[str, Any]] = None,
    slideId: Optional[str] = None,
) -> Dict[str, Any]:
    """
    camelCase front end for ``extractEmbedding`` (the preferred notebook name).

    Every argument is forwarded unchanged under its snake_case equivalent and
    the result of ``extractEmbedding`` is returned as-is.
    """
    # camelCase parameter -> snake_case keyword of extractEmbedding
    snake_kwargs: Dict[str, Any] = {
        "input_obj": inputObj,
        "model_name": modelName,
        "input_size": inputSize,
        "mean": mean,
        "std": std,
        "patch_size": patchSize,
        "stride": stride,
        "output_type": outputType,
        "batch_size": batchSize,
        "device": device,
        "stream_write": streamWrite,
        "target_mpp": targetMpp,
        "source_mpp": sourceMpp,
        "tissue_polys": tissuePolys,
        "filter_to_tissue": filterToTissue,
        "timm_kwargs": timmKwargs,
        "slide_id": slideId,
    }
    return extractEmbedding(**snake_kwargs)


def extractEmbeddingCamel(
    inputObj: Union[str, "Image.Image", Any],
    modelName: str,
    inputSize: Tuple[int, int],
    mean: List[float],
    std: List[float],
    patchSize: int,
    stride: Optional[int] = None,
    outputType: Union[str, List[str]] = "patch",
    batchSize: int = 32,
    device: str = "auto",
    streamWrite: bool = False,
    targetMpp: Optional[float] = None,
    sourceMpp: Optional[float] = None,
    tissuePolys: Optional[List[Any]] = None,
    filterToTissue: bool = False,
    timmKwargs: Optional[Dict[str, Any]] = None,
    slideId: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Legacy name kept for backwards compatibility; new code should call
    ``embedImage``, to which this simply delegates with identical arguments.
    """
    # The two signatures are identical, so the required leading arguments can
    # be handed over positionally and the optional ones by keyword.
    return embedImage(
        inputObj,
        modelName,
        inputSize,
        mean,
        std,
        patchSize,
        stride=stride,
        outputType=outputType,
        batchSize=batchSize,
        device=device,
        streamWrite=streamWrite,
        targetMpp=targetMpp,
        sourceMpp=sourceMpp,
        tissuePolys=tissuePolys,
        filterToTissue=filterToTissue,
        timmKwargs=timmKwargs,
        slideId=slideId,
    )