from __future__ import annotations

import argparse
import cv2
import geopandas as gpd
import numpy as np
from PIL import Image
from shapely import Polygon
from typing import Mapping, Sequence
from .utils import read_he_image as _read_he_image, read_wsi as _read_wsi, is_wsi as _is_wsi
import matplotlib.pyplot as plt
from pathlib import Path


# ------------------------------
# Minimal CV transform pipeline
# ------------------------------


class Transform:
    """Image Transform base class: Image -> Image

    Subclasses implement :meth:`apply` and register their configuration via
    :meth:`set_params`. Calling the transform casts the input to ``uint8``,
    runs :meth:`apply`, and casts the result back to ``uint8``.
    """

    # Class-level fallback only (used by __repr__ when a subclass never calls
    # set_params). Each instance receives its own dict on first set_params so
    # that transforms sharing a parameter name (e.g. ``kernel_size``) do not
    # overwrite each other's values.
    params: dict = {}

    def __repr__(self):
        params_str = ", ".join([f"{k}={v}" for k, v in self.params.items()])
        return f"{self.__class__.__name__}({params_str})"

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Apply the transform to ``image``.

        Raises:
            TypeError: if ``image`` is not a ``np.ndarray``.
        """
        if not isinstance(image, np.ndarray):
            raise TypeError(f"Input must be np.ndarray, got {type(image)}")
        processed_image = self.apply(image.astype(np.uint8))
        return processed_image.astype(np.uint8)

    def apply(self, image: np.ndarray) -> np.ndarray:
        """Transform ``image``; must be overridden by subclasses."""
        raise NotImplementedError

    def set_params(self, **params):
        """Record ``params`` on this instance, both in ``self.params`` and as attributes.

        Raises:
            ValueError: if a parameter name collides with an attribute that
                already exists on the instance (methods, ``params`` itself, or
                a parameter that was set earlier). Nothing is modified in
                that case.
        """
        # Validate first so a rejected call leaves the instance untouched.
        for k in params:
            if hasattr(self, k):
                raise ValueError(
                    f"Parameter {k} is not valid for {self.__class__.__name__}"
                )
        # Detach from the class-level dict before the first write.
        if "params" not in vars(self):
            self.params = dict(self.params)
        self.params.update(params)
        for k, v in params.items():
            setattr(self, k, v)


class Compose(Transform):
    """Chain several transforms and run them in order."""

    def __init__(self, transforms):
        # Iterable of callables; each one receives the previous step's output.
        self.pipeline = transforms

    def apply(self, image: np.ndarray) -> np.ndarray:
        """Feed ``image`` through every step of the pipeline and return the result."""
        result = image
        for step in self.pipeline:
            # Steps are invoked via __call__, not .apply.
            result = step(result)
        return result


class MedianBlur(Transform):
    """Median blur transform; the kernel size has to be odd."""

    def __init__(self, kernel_size: int = 5):
        """Store ``kernel_size``.

        Raises:
            ValueError: if ``kernel_size`` is even.
        """
        is_odd = kernel_size % 2 != 0
        if not is_odd:
            raise ValueError("kernel_size must be an odd number")
        self.set_params(kernel_size=kernel_size)

    def apply(self, image: np.ndarray) -> np.ndarray:
        """Return ``image`` filtered with ``cv2.medianBlur``."""
        aperture = self.params["kernel_size"]
        return cv2.medianBlur(image, ksize=aperture)


class BinaryThreshold(Transform):
    """Binary thresholding; optionally Otsu and inverse.

    The output is a ``uint8`` mask whose values are 0 or 255.
    """

    def __init__(self, use_otsu: bool = True, threshold: int = 0, inverse: bool = False):
        """Configure the thresholding mode.

        Args:
            use_otsu: add ``cv2.THRESH_OTSU`` to the threshold type.
            threshold: value passed as ``thresh`` to ``cv2.threshold``.
            inverse: use ``cv2.THRESH_BINARY_INV`` instead of ``cv2.THRESH_BINARY``.
        """
        self.type = cv2.THRESH_BINARY_INV if inverse else cv2.THRESH_BINARY
        if use_otsu:
            self.type |= cv2.THRESH_OTSU
        self.set_params(use_otsu=use_otsu, threshold=threshold, inverse=inverse)

    def apply(self, image: np.ndarray) -> np.ndarray:
        """Threshold a single-channel image.

        Raises:
            ValueError: if ``image`` has more than two dimensions.
        """
        if image.ndim > 2:
            raise ValueError("Must be greyscale image (H, W) or binary mask (H, W).")
        # Inversion is already encoded in ``self.type`` (THRESH_BINARY_INV).
        # No further arithmetic on the result: ``1 - out`` on a uint8 0/255
        # mask wraps around to {1, 2} and turns every pixel into foreground.
        _, out = cv2.threshold(
            src=image,
            thresh=self.params["threshold"],
            maxval=255,
            type=self.type,
        )
        return out


class ArtifactFilterThreshold(Transform):
    """Artifact-filtering threshold transform used by Lazyslide tissue detection."""

    def __init__(self, threshold: int = 0):
        self.set_params(threshold=threshold)

    def apply(self, image: np.ndarray) -> np.ndarray:
        """Return a 0/255 mask derived from how much red and blue exceed green.

        ``image`` is indexed as ``image[:, :, 0..2]`` and treated as R, G, B.
        """
        red, green, blue = (image[:, :, ch].astype(float) for ch in range(3))

        # Positive part of (red - green) and (blue - green); their product is
        # the per-pixel heatmap that gets thresholded.
        red_excess = np.maximum(red - green, 0)
        blue_excess = np.maximum(blue - green, 0)

        # NOTE(review): the product can exceed 255 and the uint8 cast does not
        # clip — confirm this matches the intended Lazyslide behaviour.
        heatmap = (red_excess * blue_excess).astype(np.uint8)

        # NOTE(review): THRESH_OTSU is always set; per the OpenCV docs the
        # threshold is then chosen automatically, so the stored ``threshold``
        # presumably has no effect — confirm.
        _, binary = cv2.threshold(
            src=heatmap,
            thresh=self.params["threshold"],
            maxval=255,
            type=(cv2.THRESH_BINARY + cv2.THRESH_OTSU),
        )
        return binary


class MorphClose(Transform):
    """Morphological closing applied to a binary mask."""

    def __init__(self, kernel_size: int = 5, n_iterations: int = 1):
        self.set_params(kernel_size=kernel_size, n_iterations=n_iterations)

    def apply(self, mask: np.ndarray) -> np.ndarray:
        """Close ``mask`` with a square all-ones kernel for the configured iterations."""
        side = self.params["kernel_size"]
        repeats = self.params["n_iterations"]
        structuring_element = np.ones((side, side), dtype=np.uint8)
        return cv2.morphologyEx(
            src=mask,
            op=cv2.MORPH_CLOSE,
            kernel=structuring_element,
            iterations=repeats,
        )


# ------------------------------
# BinaryMask and polygonization
# ------------------------------


class Mask:
    def __init__(
        self,
        mask: np.ndarray,
        prob_map: np.ndarray | None = None,
        class_names: Sequence[str] | Mapping[int, str] | None = None,
    ):
        self.mask = mask
        self.prob_map = prob_map
        if class_names is not None:
            if isinstance(class_names, Mapping):
                self.class_names = class_names
            elif isinstance(class_names, Sequence):
                self.class_names = {i: name for i, name in enumerate(class_names)}
            else:
                raise ValueError("class_name must be a Mapping or a Sequence.")
        else:
            self.class_names = None


class BinaryMask(Mask):
    """Two-dimensional mask binarised to 0/1 ``uint8`` values."""

    def __init__(
        self,
        mask: np.ndarray,
        prob_map: np.ndarray | None = None,
        class_names: Sequence[str] | Mapping[int, str] | None = None,
    ):
        """Check shapes, binarise ``mask`` with ``mask > 0`` and store everything."""
        assert mask.ndim == 2, "Binary mask must be 2D."
        assert prob_map is None or prob_map.shape == mask.shape, (
            "Probability mask must have the same shape as the binary mask."
        )
        binarised = np.asarray(mask > 0, dtype=np.uint8)
        super().__init__(binarised, prob_map, class_names)

    def to_polygons(
        self,
        min_area: float = 0,
        min_hole_area: float = 0,
        detect_holes: bool = True,
        ignore_index: int | Sequence[int] | None = None,  # noqa: ARG002
    ) -> gpd.GeoDataFrame:
        """Polygonise the stored mask via ``binary_mask_to_polygons_with_prob``.

        ``ignore_index`` is accepted but not used.
        """
        polygon_options = {
            "min_area": min_area,
            "min_hole_area": min_hole_area,
            "detect_holes": detect_holes,
        }
        return binary_mask_to_polygons_with_prob(
            self.mask, prob_map=self.prob_map, **polygon_options
        )

    # CamelCase wrapper for notebook/CLI consistency
    def toPolygons(
        self,
        minArea: float = 0,
        minHoleArea: float = 0,
        detectHoles: bool = True,
        ignoreIndex: int | Sequence[int] | None = None,  # noqa: ARG002
    ) -> gpd.GeoDataFrame:
        """CamelCase alias that forwards every argument to :meth:`to_polygons`."""
        forwarded = {
            "min_area": minArea,
            "min_hole_area": minHoleArea,
            "detect_holes": detectHoles,
            "ignore_index": ignoreIndex,
        }
        return self.to_polygons(**forwarded)


# ------------------------------
# Polygonization helpers
# ------------------------------


def binary_mask_to_polygons(
    binary_mask: np.ndarray,
    min_area: float = 0,
    min_hole_area: float = 0,
    detect_holes: bool = True,
) -> list[Polygon]:
    """Trace the contours of a binary mask and return them as shapely polygons.

    ``min_area`` and ``min_hole_area`` are pixel counts when >= 1; smaller
    values are read as a fraction of ``binary_mask.size`` and truncated to int.
    Contours are kept only when their area is strictly greater than the limit.
    With ``detect_holes`` the two-level ``RETR_CCOMP`` hierarchy is used and
    child contours become polygon holes; otherwise only external contours are
    traced and shells with fewer than 4 points are dropped.
    """
    total_pixels = binary_mask.size
    if min_area < 1:
        min_area = int(min_area * total_pixels)
    if min_hole_area < 1:
        min_hole_area = int(min_hole_area * total_pixels)

    retrieval = cv2.RETR_CCOMP if detect_holes else cv2.RETR_EXTERNAL
    contours, hierarchy = cv2.findContours(
        binary_mask, mode=retrieval, method=cv2.CHAIN_APPROX_NONE
    )

    # No hierarchy is returned when nothing was found.
    if hierarchy is None:
        return []

    if not detect_holes:
        rings = (
            np.squeeze(outline, axis=1)
            for outline in contours
            if cv2.contourArea(outline) > min_area
        )
        return [Polygon(shell=ring, holes=[]) for ring in rings if len(ring) >= 4]

    # Hierarchy rows follow OpenCV's [next, previous, first_child, parent] layout.
    nodes = hierarchy[0]
    selected = []
    for shell_ix, node in enumerate(nodes):
        # Only parentless contours are outer boundaries.
        if node[3] != -1:
            continue
        if cv2.contourArea(contours[shell_ix]) <= min_area:
            continue
        # Walk the sibling chain starting at the first child to collect holes.
        hole_ixs = []
        child = node[2]
        while child != -1:
            if cv2.contourArea(contours[child]) > min_hole_area:
                hole_ixs.append(child)
            child = nodes[child][0]
        selected.append((shell_ix, hole_ixs))

    return [
        Polygon(
            shell=np.squeeze(contours[shell_ix], axis=1),
            holes=[np.squeeze(contours[ix], axis=1) for ix in hole_ixs],
        )
        for shell_ix, hole_ixs in selected
    ]


def binary_mask_to_polygons_with_prob(
    binary_mask: np.ndarray,
    prob_map: np.ndarray | None = None,
    min_area: float = 0,
    min_hole_area: float = 0,
    detect_holes: bool = True,
) -> gpd.GeoDataFrame:
    """Convert binary mask to polygons and include probability information if provided.

    Args:
        binary_mask: mask forwarded to ``binary_mask_to_polygons``.
        prob_map: optional probabilities. A 2D map yields a ``prob`` column
            holding the mean probability inside each polygon. A 3D map is
            summed over axes 1 and 2 (one mean per entry of axis 0); the
            largest mean is stored in ``prob`` and its index in ``class``.
        min_area, min_hole_area, detect_holes: forwarded to
            ``binary_mask_to_polygons``.

    Returns:
        GeoDataFrame with one row per polygon. With no polygons an empty frame
        with ``geometry`` and ``prob`` columns is returned. A polygon whose
        rasterised footprint is empty gets ``prob`` 0.0 (and ``class`` 0 for
        3D maps).
    """
    polys = binary_mask_to_polygons(
        binary_mask,
        min_area=min_area,
        min_hole_area=min_hole_area,
        detect_holes=detect_holes,
    )

    if not polys:
        return gpd.GeoDataFrame(columns=["geometry", "prob"])

    if prob_map is None:
        return gpd.GeoDataFrame([{"geometry": poly} for poly in polys])

    is_classification = prob_map.ndim == 3
    data = []
    for poly in polys:
        # Rasterise the polygon: fill the exterior, then carve the holes out.
        poly_mask = np.zeros_like(binary_mask, dtype=np.uint8)
        points = np.array(poly.exterior.coords, dtype=np.int32)
        cv2.drawContours(poly_mask, [points], -1, 1, thickness=cv2.FILLED)
        if detect_holes:
            for hole in poly.interiors:
                hole_points = np.array(hole.coords, dtype=np.int32)
                cv2.fillPoly(poly_mask, [hole_points], 0)

        n_pixels = int(np.sum(poly_mask))
        masked_prob = prob_map * poly_mask

        if is_classification:
            if n_pixels > 0:
                class_probs = np.sum(masked_prob, axis=(1, 2)) / n_pixels
            else:
                # Keep an array here: a scalar fallback cannot be indexed by
                # the winning class below and would raise TypeError.
                class_probs = np.zeros(prob_map.shape[0], dtype=float)
            c = int(np.argmax(class_probs))
            data.append(
                {"geometry": poly, "prob": float(class_probs[c]), "class": c}
            )
        else:
            prob = float(np.sum(masked_prob)) / float(n_pixels) if n_pixels > 0 else 0.0
            data.append({"geometry": poly, "prob": prob})

    return gpd.GeoDataFrame(data)


# ------------------------------
# Public API used by he_embedder
# ------------------------------


def _tissue_mask(
    image: np.ndarray,
    to_hsv: bool,
    filter_artifacts: bool = True,
    blur_ksize: int = 17,
    threshold: int | None = 7,
    morph_ksize: int = 7,
    morph_n_iter: int = 3,
) -> np.ndarray:
    """Compute a tissue mask: median blur -> threshold -> morphological closing.

    With ``filter_artifacts`` the image is passed on unchanged and thresholded
    by ``ArtifactFilterThreshold`` (``to_hsv`` is not consulted). Otherwise the
    image is first reduced to the HSV channel at index 1 (``to_hsv``) or to
    greyscale, and ``BinaryThreshold`` is used: Otsu when ``threshold`` is
    ``None``, a fixed threshold otherwise.
    """
    if filter_artifacts:
        threshold_step = ArtifactFilterThreshold(threshold=threshold or 0)
    else:
        if to_hsv:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)[:, :, 1]
        else:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        if threshold is None:
            threshold_step = BinaryThreshold(use_otsu=True)
        else:
            threshold_step = BinaryThreshold(use_otsu=False, threshold=int(threshold))

    steps = [
        MedianBlur(kernel_size=int(blur_ksize)),
        threshold_step,
        MorphClose(kernel_size=int(morph_ksize), n_iterations=int(morph_n_iter)),
    ]
    pipeline = Compose(steps)
    return pipeline.apply(image)


# CamelCase public API wrappers

def tissueMask(
    image: np.ndarray,
    toHsv: bool = False,
    filterArtifacts: bool = True,
    blurKsize: int = 17,
    threshold: int | None = 7,
    morphKsize: int = 7,
    morphNIter: int = 3,
) -> np.ndarray:
    """CamelCase front-end for ``_tissue_mask`` (notebook/CLI naming).

    Every argument is forwarded unchanged under its snake_case name.
    """
    options = {
        "to_hsv": toHsv,
        "filter_artifacts": filterArtifacts,
        "blur_ksize": blurKsize,
        "threshold": threshold,
        "morph_ksize": morphKsize,
        "morph_n_iter": morphNIter,
    }
    return _tissue_mask(image=image, **options)


def binaryMaskToPolygons(
    binaryMask: np.ndarray,
    minArea: float = 0,
    minHoleArea: float = 0,
    detectHoles: bool = True,
) -> gpd.GeoDataFrame:
    """CamelCase front-end: polygonise a binary mask into a GeoDataFrame.

    Delegates to ``binary_mask_to_polygons_with_prob`` with no probability map.
    """
    filters = {
        "min_area": minArea,
        "min_hole_area": minHoleArea,
        "detect_holes": detectHoles,
    }
    return binary_mask_to_polygons_with_prob(
        binary_mask=binaryMask, prob_map=None, **filters
    )


def binaryMaskToPolygonsWithProb(
    binaryMask: np.ndarray,
    probMap: np.ndarray,
    minArea: float = 0,
    minHoleArea: float = 0,
    detectHoles: bool = True,
) -> gpd.GeoDataFrame:
    """CamelCase front-end: polygonise a binary mask and attach probabilities.

    Delegates to ``binary_mask_to_polygons_with_prob`` with ``probMap`` as the
    probability map.
    """
    filters = {
        "min_area": minArea,
        "min_hole_area": minHoleArea,
        "detect_holes": detectHoles,
    }
    return binary_mask_to_polygons_with_prob(
        binary_mask=binaryMask, prob_map=probMap, **filters
    )


# CLI support for tissue mask
if __name__ == "__main__":
    # Command-line entry point: read an image, compute its tissue mask, save
    # the mask to --outputPath and (unless disabled) save an overlay preview
    # as "<output stem>_overlay.png" in the same directory.
    parser = argparse.ArgumentParser(description="Generate tissue mask from an image")
    parser.add_argument("--inputPath", type=str, required=True, help="Input image path")
    parser.add_argument("--outputPath", type=str, required=True, help="Output path for mask image (e.g., .png or .tiff)")
    parser.add_argument("--toHsv", action="store_true", help="Use HSV S-channel for thresholding when artifacts are not filtered")
    # Boolean options come in on/off pairs that write to the same destination;
    # the "on" flag is the default, so only the "--no..." flag changes anything.
    parser.add_argument("--filterArtifacts", action="store_true", default=True, help="Enable artifact filtering (default: True)")
    parser.add_argument("--noFilterArtifacts", action="store_false", dest="filterArtifacts", help="Disable artifact filtering")
    parser.add_argument("--blurKsize", type=int, default=17, help="Median blur kernel size (odd)")
    parser.add_argument("--threshold", type=int, default=7, help="Threshold value; set negative to use Otsu when artifacts are not filtered")
    parser.add_argument("--morphKsize", type=int, default=7, help="Morphological closing kernel size")
    parser.add_argument("--morphNIter", type=int, default=3, help="Morphological closing iterations")
    # --minArea, --minHoleArea and --detectHoles are parsed but only referenced
    # by the commented-out polygon call further down.
    parser.add_argument("--minArea", type=float, default=0.0, help="Minimum area for polygons (fraction if <1, pixels if >=1)")
    parser.add_argument("--minHoleArea", type=float, default=0.0, help="Minimum hole area for polygons (fraction if <1, pixels if >=1)")
    parser.add_argument("--detectHoles", action="store_true", default=True, help="Detect holes in polygons (default: True)")
    parser.add_argument("--noDetectHoles", action="store_false", dest="detectHoles", help="Do not detect holes in polygons")
    parser.add_argument("--plotImage", action="store_true", default=True, help="Plot and save overlay of mask on image (default: True)")
    parser.add_argument("--noPlotImage", action="store_false", dest="plotImage", help="Disable plotting overlay")
    parser.add_argument("--plotThumbMaxSide", type=int, default=1024, help="Max side for plotted thumbnail overlay (default: 1024)")

    args = parser.parse_args()

    # Load image (RGB) using robust reader
    im = _read_he_image(args.inputPath)
    arr = np.array(im)

    # Map threshold negative value to None for Otsu when artifacts are not filtered
    thr = None if (args.threshold is not None and int(args.threshold) < 0) else args.threshold

    mask = tissueMask(
        image=arr,
        toHsv=bool(args.toHsv),
        filterArtifacts=bool(args.filterArtifacts),
        blurKsize=int(args.blurKsize),
        threshold=thr,
        morphKsize=int(args.morphKsize),
        morphNIter=int(args.morphNIter),
    )

    # Optionally compute polygons (not saved by default)
    # polys_gdf = binaryMaskToPolygons(mask, minArea=float(args.minArea), minHoleArea=float(args.minHoleArea), detectHoles=bool(args.detectHoles))

    # Save mask
    Image.fromarray(mask).save(args.outputPath)
    print(f"Saved tissue mask to: {args.outputPath}")

    # Plot/save overlay visualization if requested
    if bool(args.plotImage):
        # Build thumbnail for display to avoid high-resolution plotting
        thumb_max = int(args.plotThumbMaxSide)
        base_im = Image.fromarray(arr)
        base_im.thumbnail((thumb_max, thumb_max))
        arr_thumb = np.array(base_im)
        w_t, h_t = base_im.size
        # Nearest-neighbour resize of the mask to the thumbnail size
        mask_thumb = cv2.resize(mask, (w_t, h_t), interpolation=cv2.INTER_NEAREST)
        # Create RGBA overlay: red where mask>0 with alpha transparency
        overlay = np.zeros((h_t, w_t, 4), dtype=np.uint8)
        overlay[..., 0] = 255  # red channel
        overlay[..., 3] = (mask_thumb > 0).astype(np.uint8) * int(255 * 0.4)
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(arr_thumb)
        ax.imshow(overlay)
        # Add polygon outlines over the thumbnail using GeoDataFrame boundary (high z-order for visibility)
        tgdf_thumb = BinaryMask(mask_thumb).to_polygons(detect_holes=True)
        tgdf_thumb.boundary.plot(ax=ax, color="black", linewidth=2.0, zorder=5)
        ax.set_axis_off()
        fig.tight_layout()
        # The overlay is written next to the mask as "<stem>_overlay.png"
        out_path = Path(args.outputPath)
        overlay_path = out_path.parent / f"{out_path.stem}_overlay.png"
        fig.savefig(str(overlay_path), dpi=200, bbox_inches="tight", pad_inches=0)
        plt.close(fig)
        print(f"Saved overlay visualization to: {overlay_path}")


def generateTissueMask(
    inputPath: str,
    toHsv: bool = False,
    filterArtifacts: bool = True,
    blurKsize: int = 17,
    threshold: int | None = 7,
    morphKsize: int = 7,
    morphNIter: int = 3,
    plotImage: bool = True,
    plotThumbMaxSide: int = 1024,
) -> np.ndarray:
    """Load the image at ``inputPath`` and return its tissue mask.

    The image is read with the package's ``read_he_image`` helper and passed to
    ``tissueMask`` together with the processing options. When ``plotImage`` is
    true, a thumbnail (longest side at most ``plotThumbMaxSide``) is shown with
    the mask as a translucent red layer and black polygon outlines.
    """
    rgb = np.array(_read_he_image(inputPath))
    mask = tissueMask(
        image=rgb,
        toHsv=toHsv,
        filterArtifacts=filterArtifacts,
        blurKsize=blurKsize,
        threshold=threshold,
        morphKsize=morphKsize,
        morphNIter=morphNIter,
    )
    if not plotImage:
        return mask

    # Down-scaled copy of the image for display.
    side = int(plotThumbMaxSide)
    preview = Image.fromarray(rgb)
    preview.thumbnail((side, side))
    thumb_w, thumb_h = preview.size
    small_mask = cv2.resize(mask, (thumb_w, thumb_h), interpolation=cv2.INTER_NEAREST)

    # RGBA layer: solid red, visible (alpha = 40% of 255) only on mask pixels.
    tint = np.zeros((thumb_h, thumb_w, 4), dtype=np.uint8)
    tint[..., 0] = 255
    tint[..., 3] = (small_mask > 0).astype(np.uint8) * int(255 * 0.4)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(np.array(preview))
    ax.imshow(tint)
    # Polygon outlines are drawn above both image layers (zorder=5).
    outlines = BinaryMask(small_mask).to_polygons(detect_holes=True)
    outlines.boundary.plot(ax=ax, color="black", linewidth=2.0, zorder=5)
    ax.set_axis_off()
    fig.tight_layout()
    plt.show()
    plt.close(fig)
    return mask

# Thumbnail helpers and SpatialData overlay (moved from overlay_tissue_thumbnail)

def build_thumbnail_wsi(src_path: Path, max_side: int):
    """Open a whole-slide image and build an RGB thumbnail.

    Returns ``(thumbnail, (level0_width, level0_height), slide)``; the
    thumbnail is requested with a ``(max_side, max_side)`` bounding box.
    """
    # The second value returned by the reader is not needed here.
    slide, _ = _read_wsi(str(src_path))
    full_w, full_h = slide.level_dimensions[0]
    bounding_box = (max_side, max_side)
    thumbnail = slide.get_thumbnail(bounding_box).convert("RGB")
    return thumbnail, (full_w, full_h), slide


def build_thumbnail_image(src_path: Path, max_side: int, target_size=None):
    """Read a regular image and shrink it to fit ``(max_side, max_side)``.

    When ``target_size`` (width, height) is given, the image is first resized
    to that size with bilinear resampling so it matches polygons expressed in
    resampled coordinates.
    """
    picture = _read_he_image(str(src_path))
    if target_size is not None:
        resampled_size = (int(target_size[0]), int(target_size[1]))
        picture = picture.resize(resampled_size, resample=Image.BILINEAR)
    picture.thumbnail((max_side, max_side))
    return picture


def _load_tissue_provenance(sd, zarr_path: Path) -> tuple[dict, dict]:
    """Return ``(slide_properties, tile_spec)`` from provenance.json or the first table."""
    prov_path = zarr_path / "provenance.json"
    if prov_path.exists():
        import json

        with open(prov_path, "r", encoding="utf-8") as f:
            prov = json.load(f)
        props = {
            "input_path": prov.get("input_path", None),
            "source_mpp": prov.get("source_mpp", None),
            "image_size_resampled": prov.get("image_size_resampled", None),
        }
        return props, {"effective_mpp": prov.get("effective_mpp", None)}
    if len(sd.tables) > 0:
        # Assume single table stored
        tbl = next(iter(sd.tables.values()))
        return tbl.uns.get("slide_properties", {}), tbl.uns.get("tile_spec", {})
    return {}, {}


def _mean_openslide_mpp(slide) -> float | None:
    """Return the mean of the slide's openslide.mpp-x/-y properties, or None if unavailable."""
    try:
        mppx = float(slide.properties.get("openslide.mpp-x", "nan"))
        mppy = float(slide.properties.get("openslide.mpp-y", "nan"))
    except Exception:
        # Best effort: missing or unparsable properties mean "unknown".
        return None
    # NaN is the only value that differs from itself.
    if mppx != mppx or mppy != mppy:
        return None
    return float((mppx + mppy) / 2.0)


def overlay_tissue(zarr_path: Path, max_side: int = 2048, save_png: Path | None = None, image_path: Path | None = None, target_mpp_override: float | None = None, image_obj: Image.Image | None = None):
    """Draw the 'tissue' polygons of a SpatialData store over a thumbnail of the source image.

    Args:
        zarr_path: path of the SpatialData ``.zarr`` store; a ``provenance.json``
            inside it is used when present, otherwise the first table's ``uns``.
        max_side: bounding-box side of the thumbnail.
        save_png: when given, the figure is saved there (parent directories are
            created); otherwise it is shown with ``plt.show()``.
        image_path: source image; defaults to ``input_path`` from provenance.
        target_mpp_override: replaces the ``effective_mpp`` of the tile spec.
        image_obj: in-memory PIL image; takes precedence over any path and is
            always treated as a regular (non-WSI) image. It is not modified.

    Raises:
        RuntimeError: if spatialdata or shapely cannot be imported, the store
            has no 'tissue' shapes, or no source image can be determined.
    """
    try:
        from spatialdata import read_zarr
    except Exception as e:
        raise RuntimeError("spatialdata is required to read .zarr outputs.") from e
    try:
        from shapely.affinity import scale as scale_geom
    except Exception as e:
        raise RuntimeError("shapely is required for geometry scaling.") from e

    sd = read_zarr(str(zarr_path))
    if "tissue" not in sd.shapes:
        raise RuntimeError("No 'tissue' layer found in SpatialData shapes.")

    tgdf = sd.shapes["tissue"]
    props, tile_spec = _load_tissue_provenance(sd, zarr_path)

    # Derive source image path or object
    src_path = image_path if image_path is not None else (Path(props.get("input_path")) if props.get("input_path") else None)
    # Decide WSI vs regular image via unified reader
    if image_obj is not None:
        is_wsi = False
    elif src_path is not None:
        is_wsi = _is_wsi(str(src_path))
    else:
        raise RuntimeError("No input image provided. Pass imagePath or a PIL Image.")

    # Get source and target mpp
    source_mpp = props.get("source_mpp", None)
    target_mpp = target_mpp_override if target_mpp_override is not None else tile_spec.get("effective_mpp", None)

    if is_wsi:
        # NOTE(review): the slide handle returned here is never closed —
        # confirm whether the reader expects the caller to release it.
        thumb, (W0, H0), slide = build_thumbnail_wsi(src_path, max_side)
        # If mpp not in provenance, read from OpenSlide properties
        if source_mpp is None:
            source_mpp = _mean_openslide_mpp(slide)
        if target_mpp is None:
            target_mpp = source_mpp
        # Map: X_thumb = X_T * (source_mpp/target_mpp) * (thumb.width / W0)
        xfact = thumb.width / float(W0)
        yfact = thumb.height / float(H0)
        # Without mpp information the polygons are treated as level-0 pixels.
        if source_mpp is not None and target_mpp is not None:
            mpp_ratio = float(source_mpp) / float(target_mpp)
            xfact *= mpp_ratio
            yfact *= mpp_ratio
    else:
        # Regular image branch: polygons are in the resampled image pixel space
        target_size = props.get("image_size_resampled", None)
        if image_obj is not None:
            im = image_obj
            # Ensure RGB to avoid grayscale thumbnails
            try:
                if getattr(im, "mode", None) != "RGB":
                    im = im.convert("RGB")
            except Exception:
                # Best effort: keep the image in its original mode.
                pass
        else:
            im = _read_he_image(str(src_path))

        # The polygon coordinate space is the resampled size when recorded,
        # otherwise the image's own size. It must be captured BEFORE the
        # in-place thumbnail call below shrinks the image.
        if target_size is not None:
            w0, h0 = int(target_size[0]), int(target_size[1])
            im = im.resize((w0, h0), resample=Image.BILINEAR)
        else:
            w0, h0 = im.size

        # Image.thumbnail works in place; never shrink the caller's object.
        if im is image_obj:
            im = im.copy()
        im.thumbnail((max_side, max_side))
        thumb = im
        xfact = thumb.width / float(w0)
        yfact = thumb.height / float(h0)

    tgdf_plot = tgdf.copy()
    tgdf_plot["geometry"] = tgdf_plot["geometry"].apply(
        lambda g: scale_geom(g, xfact=xfact, yfact=yfact, origin=(0, 0))
    )

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(thumb)
        tgdf_plot.boundary.plot(ax=ax, color="red", linewidth=1.0)
        ax.set_axis_off()
        fig.tight_layout()

        if save_png is not None:
            save_png.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(str(save_png), dpi=200, bbox_inches="tight", pad_inches=0)
            print(f"Saved overlay to {save_png}")
        else:
            plt.show()
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


# CamelCase public API for notebook + CLI

def overlayTissueThumbnail(
    zarrPath: str | Path,
    maxSide: int = 2048,
    savePng: str | Path | None = None,
    imagePath: str | Path | Image.Image | None = None,
    targetMppOverride: float | None = None,
):
    """CamelCase front-end for ``overlay_tissue``.

    ``imagePath`` may be a path (str/Path) or a PIL image; it is routed to
    ``image_path`` or ``image_obj`` accordingly. ``savePng`` is converted to a
    ``Path`` when it is a str/Path and passed as ``None`` otherwise.
    """

    def _as_path(value):
        # Only str/Path values become a Path; anything else maps to None.
        return Path(value) if isinstance(value, (str, Path)) else None

    image_is_pil = isinstance(imagePath, Image.Image)
    return overlay_tissue(
        zarr_path=Path(zarrPath),
        max_side=maxSide,
        save_png=_as_path(savePng),
        image_path=_as_path(imagePath),
        target_mpp_override=targetMppOverride,
        image_obj=imagePath if image_is_pil else None,
    )

# Backward-compatibility alias: code that imports tissueMaskFromPath keeps
# working because the name is bound to the same function object as
# generateTissueMask. The alias will be removed in a future major version.
tissueMaskFromPath = generateTissueMask