"""
Utility functions for image I/O across histotuner.
- H&E reader: handles regular RGB images (PIL) with tifffile fallback and whole-slide images via OpenSlide.
- Multiplexed reader: opens OME-TIFF/TIFF as Zarr-backed arrays when available, falling back to NumPy arrays.
- Unified HEInput wrapper and region reader to abstract over WSI vs regular images.
"""
from __future__ import annotations

from pathlib import Path
from typing import Optional, Sequence, Tuple, Union, Any

import numpy as np

# PIL for regular image reading
from PIL import Image, ImageFile
# Global Pillow configuration applied at import time; it changes Pillow's behaviour for
# every caller in the process, not only for this module.
# Remove Pillow's maximum-pixel-count limit (per Pillow docs, the decompression-bomb check)
# so very large H&E images can be opened.
# NOTE(review): this also removes that safeguard for untrusted inputs.
Image.MAX_IMAGE_PIXELS = None
# Allow Pillow to load truncated image files rather than failing on them (per Pillow docs).
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Optional heavy dependencies. Each HAS_* flag records whether the corresponding import
# succeeded; any exception during import (not only ImportError) marks it unavailable.
try:
    import openslide  # type: ignore
except Exception:
    HAS_OPENSLIDE = False
else:
    HAS_OPENSLIDE = True

try:
    import tifffile as tiff  # type: ignore
except Exception:
    HAS_TIFFILE = False
else:
    HAS_TIFFILE = True

try:
    import zarr  # type: ignore
except Exception:
    HAS_ZARR = False
else:
    HAS_ZARR = True

# Lower-cased file extensions (as produced by ``Path.suffix``) that are routed to OpenSlide.
WSI_SUFFIXES = {".bif", ".mrxs", ".ndpi", ".scn", ".svs", ".svslide"}


def is_wsi(path: str) -> bool:
    """Report whether *path* carries a whole-slide extension listed in ``WSI_SUFFIXES``.

    Only the final extension is inspected (case-insensitively); the file is not opened.
    """
    extension = Path(path).suffix.lower()
    return extension in WSI_SUFFIXES


def read_wsi(path: str):
    """Open *path* as a whole-slide image with OpenSlide.

    Returns:
        ``(slide, source_mpp)``: the ``openslide.OpenSlide`` handle, and the mean of the
        ``openslide.mpp-x`` / ``openslide.mpp-y`` slide properties as a float, or ``None``
        when either property is absent or cannot be parsed as a float.

    Raises:
        RuntimeError: if the ``openslide`` package cannot be imported.
    """
    try:
        import openslide  # type: ignore
    except Exception as e:
        raise RuntimeError("OpenSlide is required for WSI inputs but is not available.") from e
    slide = openslide.OpenSlide(path)

    props = slide.properties
    raw_values = (props.get("openslide.mpp-x"), props.get("openslide.mpp-y"))
    if any(value is None for value in raw_values):
        return slide, None
    try:
        mpp_x, mpp_y = map(float, raw_values)
    except Exception:
        # Unparseable resolution metadata is reported as "unknown".
        return slide, None
    return slide, (mpp_x + mpp_y) / 2.0


def _tiff_array_to_rgb(arr: np.ndarray) -> np.ndarray:
    """Coerce an array read by tifffile towards an (H, W, 3) uint8 RGB layout.

    - 2-D (grayscale) input is replicated into three channels.
    - 3-D input with 3 or 4 leading channels (and not 3/4 trailing) is moved to channel-last.
    - A fourth (alpha) channel is dropped.
    - Non-uint8 data is rescaled to 0..255: integer dtypes by their dtype maximum, other
      dtypes by the largest finite value (1.0 if there is none or it is not positive).
      Negative values clip to 0; NaN maps to 0 and +inf to full scale.
    Other shapes are returned unconverted; ``Image.fromarray`` in the caller rejects them.
    """
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    elif arr.ndim == 3:
        # transpose if channel-first
        if arr.shape[0] in (3, 4) and arr.shape[-1] not in (3, 4):
            arr = np.moveaxis(arr, 0, -1)
        # drop alpha if present
        if arr.shape[-1] == 4:
            arr = arr[..., :3]
    if arr.dtype != np.uint8:
        data = arr.astype(np.float32)
        if np.issubdtype(arr.dtype, np.integer):
            maxv = float(np.iinfo(arr.dtype).max)
        else:
            # Pick the scale from finite values only: a NaN/inf maximum previously turned the
            # whole image into garbage, and a non-positive maximum inverted or saturated it.
            finite = data[np.isfinite(data)]
            maxv = float(finite.max()) if finite.size else 0.0
            if not maxv > 0.0:
                maxv = 1.0
            data = np.nan_to_num(data, nan=0.0, posinf=maxv, neginf=0.0)
        arr = (np.clip(data, 0.0, maxv) / maxv * 255.0).astype(np.uint8)
    return arr


def _read_tiff_as_rgb(path: str) -> Image.Image:
    """Read *path* with ``tifffile.imread`` and return it as an RGB PIL image (raises on failure)."""
    arr = np.asarray(tiff.imread(path))
    return Image.fromarray(_tiff_array_to_rgb(arr)).convert("RGB")


def read_he_image(path: str) -> Image.Image:
    """Read an H&E image into a PIL Image (RGB), preferring tifffile for TIFF/OME-TIFF/SVS.

    Order of attempts:
    - If the suffix is .tif/.tiff/.svs and tifffile is available, try tifffile first.
    - Fall back to ``PIL.Image.open(...).convert("RGB")``.
    - If PIL fails and tifffile was not already tried, try tifffile for any suffix.
    tifffile results are normalised to uint8 and any alpha channel is dropped.

    Raises:
        FileNotFoundError: if *path* does not exist (as raised by PIL).
        Exception: PIL's original error when every reader fails.
    """
    # Path.suffix only yields the last extension, so "x.ome.tif" is matched by ".tif".
    suffix = Path(path).suffix.lower()
    tried_tiff = False
    if HAS_TIFFILE and suffix in {".svs", ".tif", ".tiff"}:
        tried_tiff = True
        try:
            return _read_tiff_as_rgb(path)
        except Exception:
            pass  # fall through to PIL
    try:
        # Context manager releases the file handle Pillow may otherwise keep for multi-frame
        # files; convert() returns a fully loaded copy, so it stays valid after close.
        with Image.open(path) as im:
            return im.convert("RGB")
    except FileNotFoundError:
        raise
    except Exception:
        # Secondary fallback: tifffile for suffixes not already attempted above (re-reading a
        # file tifffile already rejected would fail the same way and can be very expensive).
        if HAS_TIFFILE and not tried_tiff:
            try:
                return _read_tiff_as_rgb(path)
            except Exception:
                pass
        # Re-raise PIL's error (the inner handler above has completed).
        raise

# --- Restored helpers used by Spatial Feature Table ---

def image_as_zarr(input_path: str):
    """Open an OME-TIFF/TIFF image, preferring a lazily read Zarr-backed array.

    Returns ``(array, shape, axes)``:
    - ``array`` is a Zarr array over the first series' full-resolution level when zarr is
      installed and the store opens; otherwise a NumPy ndarray (may be memory heavy).
    - ``shape`` is ``array.shape`` (or None if absent); ``axes`` is the tifffile series
      ``axes`` string (or None).
    Returns ``(None, None, None)`` when tifffile is unavailable or the file cannot be read.

    The underlying TiffFile stays open only when a Zarr-backed array is returned (the store
    reads from it lazily); on every other path it is closed before returning.
    """
    if not HAS_TIFFILE:
        return None, None, None
    try:
        tf = tiff.TiffFile(input_path)
    except Exception:
        return None, None, None
    try:
        series = tf.series[0]
    except Exception:
        tf.close()
        return None, None, None
    axes = getattr(series, "axes", None)
    # Prefer Zarr-backed for chunked reads
    if HAS_ZARR:
        try:
            # level=0: for pyramidal files the default store is a multiscale *group*, which
            # has no array shape; select the full-resolution level explicitly.
            arr = zarr.open(series.aszarr(level=0), mode="r")  # type: ignore
            return arr, getattr(arr, "shape", None), axes
        except Exception:
            pass
    # Fallback: NumPy array (may be memory heavy)
    try:
        data = series.asarray()
    except Exception:
        data = None
    # Data (if any) is now in memory, so the file handle is no longer needed.
    tf.close()
    if data is None:
        try:
            data = tiff.imread(input_path)
        except Exception:
            return None, None, None
    return data, getattr(data, "shape", None), axes


def mask_as_zarr(mask_path: str):
    """Open a segmentation mask TIFF, preferring a lazily read Zarr-backed array.

    Returns a Zarr array over the first series' full-resolution level when zarr is installed
    and the store opens, otherwise a NumPy ndarray; returns None if tifffile is unavailable
    or the file cannot be read.

    The underlying TiffFile stays open only when a Zarr-backed array is returned; on every
    other path it is closed before returning.
    """
    if not HAS_TIFFILE:
        return None
    try:
        tf = tiff.TiffFile(mask_path)
    except Exception:
        return None
    try:
        series = tf.series[0]
    except Exception:
        tf.close()
        return None
    if HAS_ZARR:
        try:
            # level=0 so pyramidal files yield an array rather than a multiscale group.
            return zarr.open(series.aszarr(level=0), mode="r")  # type: ignore
        except Exception:
            pass
    try:
        data = series.asarray()
    except Exception:
        data = None
    # Data (if any) is now in memory, so the file handle is no longer needed.
    tf.close()
    if data is None:
        try:
            data = tiff.imread(mask_path)
        except Exception:
            return None
    return data


class HEInput:
    """Plain container describing an opened H&E source, independent of its backing format.

    Attributes (stored as given; no validation is performed):
        kind: 'wsi' for OpenSlide-backed sources or 'image' for in-memory PIL images.
        slide_handle: the slide object when kind == 'wsi', otherwise None.
        image_obj: the RGB PIL image when kind == 'image', otherwise None.
        source_mpp: source resolution in microns per pixel if known, otherwise None.
        size: (width, height) — level-0 dimensions for a WSI, ``image.size`` for an image.
    """

    def __init__(
        self,
        kind: str,
        slide_handle=None,
        image_obj: Image.Image | None = None,
        source_mpp: Optional[float] = None,
        size: Tuple[int, int] | None = None,
    ):
        self.kind, self.slide_handle, self.image_obj = kind, slide_handle, image_obj
        self.source_mpp, self.size = source_mpp, size

    def __repr__(self) -> str:
        kind, size, mpp = self.kind, self.size, self.source_mpp
        return f"HEInput(kind={kind}, size={size}, source_mpp={mpp})"


def _openslide_average_mpp(properties) -> Optional[float]:
    """Return the mean of the 'openslide.mpp-x' / 'openslide.mpp-y' entries of *properties*.

    Returns None when either entry is missing/empty or cannot be parsed as a float.
    """
    mpp_x = properties.get("openslide.mpp-x")
    mpp_y = properties.get("openslide.mpp-y")
    if not (mpp_x and mpp_y):
        return None
    try:
        return (float(mpp_x) + float(mpp_y)) / 2.0
    except (TypeError, ValueError):
        return None


def open_he_input(src: Union[str, Image.Image, Any]) -> HEInput:
    """Auto-detect and open an H&E input.

    - str or os.PathLike path with a WSI suffix: opened with OpenSlide (kind='wsi'); if that
      fails for any reason, the path is read as a regular image instead.
    - Any other path: read into an RGB PIL Image (kind='image').
    - OpenSlide slide object (``openslide.AbstractSlide``, e.g. ``OpenSlide``): kind='wsi',
      with source_mpp taken from the slide's mpp-x/mpp-y properties when available.
    - PIL Image: converted to RGB if needed (kind='image').

    Raises:
        ValueError: for unsupported input types, or images that cannot be converted to RGB.
        Errors from ``read_he_image`` (e.g. FileNotFoundError) propagate for path inputs.
    """
    import os

    # Accept pathlib.Path and other path-like objects, not only str.
    if isinstance(src, os.PathLike):
        src = os.fspath(src)

    # Path input
    if isinstance(src, str):
        if is_wsi(src):
            try:
                slide, source_mpp = read_wsi(src)
                W0, H0 = slide.level_dimensions[0]
                return HEInput(kind="wsi", slide_handle=slide, source_mpp=source_mpp, size=(W0, H0))
            except Exception:
                # Best-effort: OpenSlide missing or unable to open this file -> regular reader.
                pass
        # Non-WSI image or fallback from failed OpenSlide
        im = read_he_image(src)
        if im.mode != "RGB":
            try:
                im = im.convert("RGB")
            except Exception as e:
                raise ValueError(f"Image at {src} is not RGB and cannot be converted: {e}") from e
        return HEInput(kind="image", image_obj=im, source_mpp=None, size=im.size)

    # Slide-object input. Only the import is guarded, so real errors while reading a slide's
    # metadata surface instead of being reported as "Unsupported input type".
    try:
        import openslide  # type: ignore
    except Exception:
        openslide = None
    if openslide is not None:
        # AbstractSlide also covers openslide.ImageSlide; fall back to OpenSlide if absent.
        slide_type = getattr(openslide, "AbstractSlide", openslide.OpenSlide)
        if isinstance(src, slide_type):
            W0, H0 = src.level_dimensions[0]
            source_mpp = _openslide_average_mpp(src.properties)
            return HEInput(kind="wsi", slide_handle=src, source_mpp=source_mpp, size=(W0, H0))

    if isinstance(src, Image.Image):
        im = src
        if im.mode != "RGB":
            try:
                im = im.convert("RGB")
            except Exception as e:
                raise ValueError(f"Provided PIL Image is not RGB and cannot be converted: {e}") from e
        return HEInput(kind="image", image_obj=im, source_mpp=None, size=im.size)

    raise ValueError("Unsupported input type for H&E loading. Provide a path, OpenSlide handle, or PIL Image.")


def read_he_region(inp: HEInput, x0: int, y0: int, w: int, h: int) -> Image.Image:
    """Read a rectangular region from an HEInput as an RGB PIL image.

    - kind='wsi': ``slide_handle.read_region`` at level 0 with exactly the requested origin
      and size, converted to RGB.
    - kind='image': crops ``image_obj``. The origin is clamped into the image and the far
      edge is clipped to the image bounds, so the result may be smaller than (w, h) but is
      at least 1x1 for any non-empty image.

    Raises:
        ValueError: if w or h is not positive, or ``inp`` lacks the handle/image for its kind.
    """
    if w <= 0 or h <= 0:
        raise ValueError("Requested region must have positive width and height.")
    if inp.kind == "wsi" and inp.slide_handle is not None:
        return inp.slide_handle.read_region((int(x0), int(y0)), 0, (int(w), int(h))).convert("RGB")
    if inp.kind == "image" and inp.image_obj is not None:
        im = inp.image_obj
        W, H = im.size
        # Clamp the origin into [0, W-1] x [0, H-1]. Previously an origin at/after the far
        # edge gave an empty crop or a PIL "right < left" error instead of a non-empty region.
        x0c = max(0, min(int(x0), W - 1))
        y0c = max(0, min(int(y0), H - 1))
        # max(1, ...) keeps sub-pixel sizes (e.g. w=0.5 -> int 0) non-empty.
        x1 = min(W, x0c + max(1, int(w)))
        y1 = min(H, y0c + max(1, int(h)))
        return im.crop((x0c, y0c, x1, y1)).convert("RGB")
    raise ValueError("HEInput is not properly initialized for region reading.")