"""
image_patcher.py — Modular tiling and embedding helpers
-------------------------------------------------------
Provides reusable functions to generate patch grids and extract embeddings
for WSIs (via OpenSlide handle) and regular PIL Images. Supports both
non-streaming (NumPy) and streaming (Dask) variants.

Functions:
- generate_patch_grid(width, height, patch_size, stride)
- infer_embeddings_wsi(slide, level, downsample, patch_size, stride, vit, cfg, batch_size, device=None)
- infer_embeddings_wsi_stream(slide, level, downsample, patch_size, stride, vit, cfg, batch_size, device=None, tissue_polys=None)
- infer_embeddings_image(im, patch_size, stride, vit, cfg, batch_size, device=None)
- infer_embeddings_image_stream(im, patch_size, stride, vit, cfg, batch_size, device=None)
"""

from typing import List, Tuple, Optional, Dict, Any
import time

import numpy as np
import torch

# Optional deps
HAS_DASK = False
try:
    import dask.array as da
    HAS_DASK = True
except Exception:
    HAS_DASK = False

HAS_SHAPELY = False
try:
    import shapely.geometry as sg
    HAS_SHAPELY = True
except Exception:
    HAS_SHAPELY = False


def generate_patch_grid(width: int, height: int, patch_size: int, stride: Optional[int] = None) -> List[Tuple[int, int, int, int]]:
    """Return square patch boxes that cover a ``width`` x ``height`` frame.

    Patches are laid out on a regular grid starting at (0, 0). When the grid
    does not land exactly on the right or bottom edge, one extra column/row
    is appended flush with that edge, so the last tiles may overlap their
    neighbours by more than the regular stride would.

    Args:
        width: Frame width in pixels.
        height: Frame height in pixels.
        patch_size: Side length of each square patch; must be positive.
        stride: Step between consecutive patch origins; must be positive.
            ``None`` means ``patch_size`` (non-overlapping tiles).

    Returns:
        A list of ``(left, top, right, bottom)`` boxes in row-major order
        (all x positions for the first y, then the next y, ...). The list is
        empty when the frame is smaller than one patch in either dimension.

    Raises:
        ValueError: If ``patch_size`` or ``stride`` is not positive.
    """
    if stride is None:
        stride = patch_size
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    # A non-positive stride would either make range() fail or produce an
    # empty range and an IndexError on the tail check below.
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")
    if width < patch_size or height < patch_size:
        return []

    right_start = width - patch_size
    bottom_start = height - patch_size
    # Primary grid
    xs = list(range(0, right_start + 1, stride))
    ys = list(range(0, bottom_start + 1, stride))
    # Ensure rightmost and bottommost coverage by adding tail tiles if needed
    if xs[-1] != right_start:
        xs.append(right_start)
    if ys[-1] != bottom_start:
        ys.append(bottom_start)
    return [(x, y, x + patch_size, y + patch_size) for y in ys for x in xs]


def infer_embeddings_wsi(
    slide,
    level: int,
    downsample: float,
    patch_size: int,
    stride: int,
    vit,
    cfg: Dict[str, Any],
    batch_size: int,
    device: Optional[torch.device] = None,
):
    """Embed a regular grid of tiles read from level 0 of a slide.

    The grid is laid out in a virtual "target" frame whose size is the
    level-0 size divided by ``scale``. Each tile is read from level 0 with a
    side of ``round(patch_size * scale)`` pixels and passed to
    ``vit.preprocess``; this function does not resize the tile itself.
    NOTE(review): presumably ``vit.preprocess`` resizes to the model's input
    size — confirm.

    Args:
        slide: Object exposing ``level_dimensions`` and
            ``read_region(location, level, size)`` (described as an
            OpenSlide handle in the module docstring).
        level: Not used by this function; tiles are always read from level 0.
        downsample: Fallback scale used when ``cfg`` lacks either
            ``"target_mpp"`` or ``"_source_mpp_for_run"``. Non-positive
            values fall back to 1.0.
        patch_size: Tile side in target-frame pixels.
        stride: Step between tile origins in target-frame pixels.
        vit: Model wrapper exposing ``to``, ``eval``, ``preprocess``,
            ``forward_tokens`` and ``forward_patch``.
        cfg: Run configuration. ``cfg["output_type"]`` is required;
            ``"token"`` selects ``vit.forward_tokens``, anything else
            selects ``vit.forward_patch``.
        batch_size: Number of tiles per forward pass; must be positive.
        device: Torch device; defaults to CPU.

    Returns:
        ``(coords, embeds, boxes_T, token_dim)`` where ``coords`` is an int32
        array of ``(left, top)`` tile origins in the target frame, ``embeds``
        is the per-batch outputs concatenated along axis 0 (token outputs are
        flattened to ``(B, T * D)``), ``boxes_T`` is the list of target-frame
        boxes, and ``token_dim`` is ``D`` from the token output or ``None``
        in patch mode.

    Raises:
        ValueError: If ``batch_size`` or an mpp value is not positive, or if
            no tiles fit in the target frame.
        KeyError: If ``cfg`` has no ``"output_type"`` entry.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    device = device or torch.device("cpu")
    vit = vit.to(device).eval()
    # Always read from level 0 and rescale to target_mpp
    source_mpp = cfg.get("_source_mpp_for_run", None)
    target_mpp = cfg.get("target_mpp", None)
    if target_mpp is not None and source_mpp is not None:
        # Guard against a zero/negative scale, which would otherwise surface
        # as a ZeroDivisionError or a nonsensical negative frame size.
        if float(source_mpp) <= 0 or float(target_mpp) <= 0:
            raise ValueError(
                f"mpp values must be positive, got source={source_mpp}, target={target_mpp}"
            )
        scale = float(target_mpp) / float(source_mpp)
    else:
        scale = float(downsample) if downsample > 0 else 1.0
    W0, H0 = slide.level_dimensions[0]
    width_T = int(W0 / scale)
    height_T = int(H0 / scale)
    boxes_T = generate_patch_grid(width_T, height_T, patch_size, stride)
    if not boxes_T:
        raise ValueError("No patches generated at target-mpp frame. Check patch size/stride.")
    total = len(boxes_T)
    print(f"Patching at level 0 (virtual target-mpp frame) → {total} tiles")
    all_embeds: List[np.ndarray] = []
    to_tensor = vit.preprocess
    want_tokens = cfg["output_type"] == "token"
    t0 = time.time()
    token_dim = None
    side0 = max(1, int(round(patch_size * scale)))
    read_size = (side0, side0)
    for i in range(0, total, batch_size):
        batch_boxes = boxes_T[i : i + batch_size]
        patches = []
        for (l, t, _r, _b) in batch_boxes:
            # Map the target-frame origin back to level-0 pixel coordinates.
            loc0 = (int(round(l * scale)), int(round(t * scale)))
            patch = slide.read_region(loc0, 0, read_size).convert("RGB")
            patches.append(to_tensor(patch))
        batch = torch.stack(patches, dim=0).to(device)
        # Inference only: skip autograd bookkeeping so no graph is retained
        # and the result can be converted to NumPy.
        with torch.no_grad():
            if want_tokens:
                toks = vit.forward_tokens(batch)
                B, T, D = toks.shape
                token_dim = D
                out = toks.reshape(B, T * D)
            else:
                out = vit.forward_patch(batch)
        all_embeds.append(out.detach().cpu().numpy())
        done = i + len(batch_boxes)
        if (i // batch_size) % 10 == 0 or i + batch_size >= total:
            pct = min(100, int(100 * done / total))
            elapsed = time.time() - t0
            print(f"Embedding tiles: {pct}% ({done}/{total})  elapsed {elapsed:.1f}s", end="\r")
    print()
    coords = np.array([(l, t) for (l, t, _r, _b) in boxes_T], dtype=np.int32)
    embeds = np.concatenate(all_embeds, axis=0)
    return coords, embeds, boxes_T, token_dim


def infer_embeddings_wsi_stream(
    slide,
    level: int,
    downsample: float,
    patch_size: int,
    stride: int,
    vit,
    cfg: Dict[str, Any],
    batch_size: int,
    device: Optional[torch.device] = None,
    tissue_polys: Optional[List[Any]] = None,
):
    """Embed level-0 slide tiles and return the result as a Dask array.

    Works like the non-streaming WSI variant (virtual target frame, tiles
    read from level 0 at ``round(patch_size * scale)`` pixels), with two
    differences: tiles can be restricted to those whose centre lies inside a
    tissue polygon, and each batch's output is wrapped with
    ``da.from_array`` as a single chunk before the chunks are concatenated.
    Each batch is still computed eagerly and held as a NumPy array.

    Args:
        slide: Object exposing ``level_dimensions`` and
            ``read_region(location, level, size)``.
        level: Not used by this function; tiles are always read from level 0.
        downsample: Fallback scale used when ``cfg`` lacks either
            ``"target_mpp"`` or ``"_source_mpp_for_run"``. Non-positive
            values fall back to 1.0.
        patch_size: Tile side in target-frame pixels.
        stride: Step between tile origins in target-frame pixels.
        vit: Model wrapper exposing ``to``, ``eval``, ``preprocess``,
            ``forward_tokens`` and ``forward_patch``.
        cfg: Run configuration. ``cfg["output_type"]`` is required;
            ``cfg.get("filter_to_tissue", True)`` toggles tissue filtering.
        batch_size: Number of tiles per forward pass; must be positive.
        device: Torch device; defaults to CPU.
        tissue_polys: Optional tissue geometry: an object with a
            ``geometry`` attribute (iterated for polygons), a single object
            with ``contains``, or an iterable of such objects. Tile centres
            are tested in target-frame coordinates.
            NOTE(review): assumes the polygons are expressed in the
            target-mpp frame — confirm with the caller.

    Returns:
        ``(coords, embeds_da, boxes_T, token_dim)``; ``coords`` is an int32
        array of ``(left, top)`` origins of the kept tiles, ``embeds_da`` a
        Dask array with one chunk per batch, ``boxes_T`` the kept boxes, and
        ``token_dim`` the token width ``D`` or ``None`` in patch mode.

    Raises:
        RuntimeError: If dask is not installed.
        ValueError: If ``batch_size`` or an mpp value is not positive, if no
            tiles fit in the target frame, or if tissue filtering removes
            every tile.
        KeyError: If ``cfg`` has no ``"output_type"`` entry.
    """
    if not HAS_DASK:
        raise RuntimeError("Streaming write requires dask; please install 'dask[array]'.")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    device = device or torch.device("cpu")
    vit = vit.to(device).eval()
    source_mpp = cfg.get("_source_mpp_for_run", None)
    target_mpp = cfg.get("target_mpp", None)
    if target_mpp is not None and source_mpp is not None:
        # Guard against a zero/negative scale, which would otherwise surface
        # as a ZeroDivisionError or a nonsensical negative frame size.
        if float(source_mpp) <= 0 or float(target_mpp) <= 0:
            raise ValueError(
                f"mpp values must be positive, got source={source_mpp}, target={target_mpp}"
            )
        scale = float(target_mpp) / float(source_mpp)
    else:
        scale = float(downsample) if downsample > 0 else 1.0
    W0, H0 = slide.level_dimensions[0]
    width_T = int(W0 / scale)
    height_T = int(H0 / scale)
    boxes_T = generate_patch_grid(width_T, height_T, patch_size, stride)
    if not boxes_T:
        raise ValueError("No patches generated at target-mpp frame. Check patch size/stride.")

    # Tissue filtering is skipped (all tiles kept) when shapely is missing,
    # when it is disabled in cfg, or when there are no polygons.
    if tissue_polys is not None and HAS_SHAPELY and bool(cfg.get("filter_to_tissue", True)):
        if hasattr(tissue_polys, "geometry"):
            polys = list(tissue_polys.geometry)
        elif hasattr(tissue_polys, "contains"):
            polys = [tissue_polys]
        else:
            polys = list(tissue_polys)

        def _center_in_tissue(box) -> bool:
            l, t, r, b = box
            pt = sg.Point(float(l + (r - l) * 0.5), float(t + (b - t) * 0.5))
            for poly in polys:
                # Best effort: an entry that cannot answer contains() (for
                # example a missing geometry) is skipped, not fatal.
                try:
                    if poly.contains(pt):
                        return True
                except Exception:
                    continue
            return False

        if polys:
            boxes_T = [box for box in boxes_T if _center_in_tissue(box)]
            if not boxes_T:
                raise ValueError(
                    "No patches left after tissue filtering. Check tissue polygons, patch size, and stride."
                )

    total = len(boxes_T)
    print(f"Patching at level 0 (virtual target-mpp frame) → {total} tiles [stream]")
    chunks = []  # one dask array per batch
    to_tensor = vit.preprocess
    want_tokens = cfg["output_type"] == "token"
    t0 = time.time()
    token_dim = None
    side0 = max(1, int(round(patch_size * scale)))
    read_size = (side0, side0)
    for i in range(0, total, batch_size):
        batch_boxes = boxes_T[i : i + batch_size]
        patches = []
        for (l, t, _r, _b) in batch_boxes:
            # Map the target-frame origin back to level-0 pixel coordinates.
            loc0 = (int(round(l * scale)), int(round(t * scale)))
            patch = slide.read_region(loc0, 0, read_size).convert("RGB")
            patches.append(to_tensor(patch))
        batch = torch.stack(patches, dim=0).to(device)
        # Inference only: skip autograd bookkeeping so no graph is retained
        # and the result can be converted to NumPy.
        with torch.no_grad():
            if want_tokens:
                toks = vit.forward_tokens(batch)
                B, T, D = toks.shape
                token_dim = D
                out = toks.reshape(B, T * D)
            else:
                out = vit.forward_patch(batch)
        arr = out.detach().cpu().numpy()
        chunks.append(da.from_array(arr, chunks=arr.shape))
        done = i + len(batch_boxes)
        if (i // batch_size) % 10 == 0 or i + batch_size >= total:
            pct = min(100, int(100 * done / total))
            elapsed = time.time() - t0
            print(f"Embedding tiles: {pct}% ({done}/{total})  elapsed {elapsed:.1f}s", end="\r")
    print()
    coords = np.array([(l, t) for (l, t, _r, _b) in boxes_T], dtype=np.int32)
    embeds_da = da.concatenate(chunks, axis=0)
    return coords, embeds_da, boxes_T, token_dim


def infer_embeddings_image(
    im,
    patch_size: int,
    stride: int,
    vit,
    cfg: Dict[str, Any],
    batch_size: int,
    device: Optional[torch.device] = None,
):
    """Embed a regular grid of crops taken from an in-memory image.

    Args:
        im: Image exposing ``size`` as ``(width, height)`` and
            ``crop((left, top, right, bottom))`` (a PIL Image per the module
            docstring). Crops are passed to ``vit.preprocess`` as-is; no
            mode conversion is done here.
        patch_size: Tile side in pixels.
        stride: Step between tile origins in pixels.
        vit: Model wrapper exposing ``to``, ``eval``, ``preprocess``,
            ``forward_tokens`` and ``forward_patch``.
        cfg: Run configuration. ``cfg["output_type"]`` is required
            (``"token"`` selects ``vit.forward_tokens``). If
            ``cfg["_tissue_polys_image"]`` is truthy, shapely is available
            and ``cfg.get("filter_to_tissue", True)`` is true, only tiles
            whose centre lies inside one of those polygons are kept.
        batch_size: Number of tiles per forward pass; must be positive.
        device: Torch device; defaults to CPU.

    Returns:
        ``(coords, embeds, token_dim)``; ``coords`` is an int32 array of
        ``(left, top)`` origins of the kept tiles, ``embeds`` the per-batch
        outputs concatenated along axis 0 (token outputs flattened to
        ``(B, T * D)``), and ``token_dim`` the token width ``D`` or ``None``
        in patch mode.

    Raises:
        ValueError: If ``batch_size`` is not positive or no tiles remain.
        KeyError: If ``cfg`` has no ``"output_type"`` entry.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    device = device or torch.device("cpu")
    vit = vit.to(device).eval()
    width, height = im.size
    boxes = generate_patch_grid(width, height, patch_size, stride)
    tissue_polys = cfg.get("_tissue_polys_image", None)
    if tissue_polys and bool(cfg.get("filter_to_tissue", True)) and HAS_SHAPELY:
        # Materialize once so a one-shot iterable is not exhausted after the
        # first tile.
        polys = list(tissue_polys)
        kept = []
        for (l, t, r, b) in boxes:
            center = sg.Point(float(l + (r - l) * 0.5), float(t + (b - t) * 0.5))
            if any(poly.contains(center) for poly in polys):
                kept.append((l, t, r, b))
        boxes = kept
    if not boxes:
        raise ValueError("No patches generated. Check image size, patch size, and stride.")
    total = len(boxes)
    print(f"Patching regular image → {total} tiles")
    all_embeds: List[np.ndarray] = []
    to_tensor = vit.preprocess
    want_tokens = cfg["output_type"] == "token"
    t0 = time.time()
    token_dim = None
    for i in range(0, total, batch_size):
        batch_boxes = boxes[i : i + batch_size]
        patches = [to_tensor(im.crop(box)) for box in batch_boxes]
        batch = torch.stack(patches, dim=0).to(device)
        # Inference only: skip autograd bookkeeping so no graph is retained
        # and the result can be converted to NumPy.
        with torch.no_grad():
            if want_tokens:
                toks = vit.forward_tokens(batch)
                B, T, D = toks.shape
                token_dim = D
                out = toks.reshape(B, T * D)
            else:
                out = vit.forward_patch(batch)
        all_embeds.append(out.detach().cpu().numpy())
        done = i + len(batch_boxes)
        if (i // batch_size) % 10 == 0 or i + batch_size >= total:
            pct = min(100, int(100 * done / total))
            elapsed = time.time() - t0
            print(f"Embedding tiles: {pct}% ({done}/{total})  elapsed {elapsed:.1f}s", end="\r")
    print()
    coords = np.array([(l, t) for (l, t, _r, _b) in boxes], dtype=np.int32)
    embeds = np.concatenate(all_embeds, axis=0)
    return coords, embeds, token_dim


def infer_embeddings_image_stream(
    im,
    patch_size: int,
    stride: int,
    vit,
    cfg: Dict[str, Any],
    batch_size: int,
    device: Optional[torch.device] = None,
):
    """Embed image crops and return the result as a Dask array.

    Same tiling and filtering as the non-streaming image variant; each
    batch's output is wrapped with ``da.from_array`` as a single chunk and
    the chunks are concatenated along axis 0. Each batch is still computed
    eagerly and held as a NumPy array.

    Args:
        im: Image exposing ``size`` as ``(width, height)`` and
            ``crop((left, top, right, bottom))``. Crops are passed to
            ``vit.preprocess`` as-is; no mode conversion is done here.
        patch_size: Tile side in pixels.
        stride: Step between tile origins in pixels.
        vit: Model wrapper exposing ``to``, ``eval``, ``preprocess``,
            ``forward_tokens`` and ``forward_patch``.
        cfg: Run configuration. ``cfg["output_type"]`` is required
            (``"token"`` selects ``vit.forward_tokens``). If
            ``cfg["_tissue_polys_image"]`` is truthy, shapely is available
            and ``cfg.get("filter_to_tissue", True)`` is true, only tiles
            whose centre lies inside one of those polygons are kept.
        batch_size: Number of tiles per forward pass; must be positive.
        device: Torch device; defaults to CPU.

    Returns:
        ``(coords, embeds_da, token_dim)``; ``coords`` is an int32 array of
        ``(left, top)`` origins of the kept tiles, ``embeds_da`` a Dask array
        with one chunk per batch, and ``token_dim`` the token width ``D`` or
        ``None`` in patch mode.

    Raises:
        RuntimeError: If dask is not installed.
        ValueError: If ``batch_size`` is not positive or no tiles remain.
        KeyError: If ``cfg`` has no ``"output_type"`` entry.
    """
    if not HAS_DASK:
        raise RuntimeError("Streaming write requires dask; please install 'dask[array]'.")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    device = device or torch.device("cpu")
    vit = vit.to(device).eval()
    width, height = im.size
    boxes = generate_patch_grid(width, height, patch_size, stride)
    tissue_polys = cfg.get("_tissue_polys_image", None)
    if tissue_polys and bool(cfg.get("filter_to_tissue", True)) and HAS_SHAPELY:
        # Materialize once so a one-shot iterable is not exhausted after the
        # first tile.
        polys = list(tissue_polys)
        kept = []
        for (l, t, r, b) in boxes:
            center = sg.Point(float(l + (r - l) * 0.5), float(t + (b - t) * 0.5))
            if any(poly.contains(center) for poly in polys):
                kept.append((l, t, r, b))
        boxes = kept
    if not boxes:
        raise ValueError("No patches generated. Check image size, patch size, and stride.")
    total = len(boxes)
    print(f"Patching regular image → {total} tiles [stream]")
    chunks = []  # one dask array per batch
    to_tensor = vit.preprocess
    want_tokens = cfg["output_type"] == "token"
    t0 = time.time()
    token_dim = None
    for i in range(0, total, batch_size):
        batch_boxes = boxes[i : i + batch_size]
        patches = [to_tensor(im.crop(box)) for box in batch_boxes]
        batch = torch.stack(patches, dim=0).to(device)
        # Inference only: skip autograd bookkeeping so no graph is retained
        # and the result can be converted to NumPy.
        with torch.no_grad():
            if want_tokens:
                toks = vit.forward_tokens(batch)
                B, T, D = toks.shape
                token_dim = D
                out = toks.reshape(B, T * D)
            else:
                out = vit.forward_patch(batch)
        arr = out.detach().cpu().numpy()
        chunks.append(da.from_array(arr, chunks=arr.shape))
        done = i + len(batch_boxes)
        if (i // batch_size) % 10 == 0 or i + batch_size >= total:
            pct = min(100, int(100 * done / total))
            elapsed = time.time() - t0
            print(f"Embedding tiles: {pct}% ({done}/{total})  elapsed {elapsed:.1f}s", end="\r")
    print()
    coords = np.array([(l, t) for (l, t, _r, _b) in boxes], dtype=np.int32)
    embeds_da = da.concatenate(chunks, axis=0)
    return coords, embeds_da, token_dim