"""
segment.py — Standalone Cellpose segmentation with tissue masking and patch stitching
-----------------------------------------------------------------------------------
Loads a large image (regular TIFF/PNG or WSI via OpenSlide), detects tissue regions,
segments only tissue patches using Cellpose, and stitches results back into a full-size
label image matching the input dimensions.

Notes:
- Designed as a standalone script; later can be integrated into the pipeline.
- For very large data, consider Cellpose's distributed module which operates on Zarr arrays:
  see https://cellpose.readthedocs.io/en/latest/distributed.html#distributed-cellpose-for-larger-than-memory-data
- Current implementation processes patches sequentially and assembles a single label mask in memory.

Usage:
  python segment.py --inputPath path/to/image.tif --outputPath out_mask.tiff \
    --patchSize 512 --stride 512 --pretrainedModel cpsam

"""

import argparse
import sys
from pathlib import Path
from typing import Tuple, List, Optional
from types import SimpleNamespace

import numpy as np

try:
    from PIL import Image
except Exception as _e:
    print("[error] PIL is required: pip install pillow", file=sys.stderr)
    raise

# Optional libraries
HAS_OPENSLIDE = False
try:
    import openslide  # type: ignore
    HAS_OPENSLIDE = True
except Exception:
    HAS_OPENSLIDE = False

HAS_TIFFILE = False
try:
    import tifffile as tiff  # type: ignore
    HAS_TIFFILE = True
except Exception:
    HAS_TIFFILE = False

def _load_cellpose_models_quiet():
    """Return the ``cellpose.models`` module, discarding anything printed while importing it.

    stdout and stderr are redirected into throwaway buffers for the duration
    of the import, and the ``cellpose.models`` logger is then raised to ERROR.

    Raises:
        Exception: whatever the import raised, re-raised unchanged after an
            installation hint has been printed to stderr.
    """
    import contextlib
    import io
    import logging
    import sys as _sys

    sink_out = io.StringIO()
    sink_err = io.StringIO()
    try:
        with contextlib.ExitStack() as silenced:
            silenced.enter_context(contextlib.redirect_stdout(sink_out))
            silenced.enter_context(contextlib.redirect_stderr(sink_err))
            from cellpose import models as cellpose_models  # type: ignore
        logging.getLogger("cellpose.models").setLevel(logging.ERROR)
    except Exception:
        # The redirection has already been undone here, so the hint reaches the real stderr.
        print("[error] cellpose is required: pip install cellpose", file=_sys.stderr)
        raise
    return cellpose_models

# Reuse tiling helper
from .image_patcher import generate_patch_grid
from histotuner.utils import read_he_image as _read_he_image, read_wsi as _read_wsi, is_wsi as _is_wsi


# (removed duplicate local is_wsi; using histotuner.utils.is_wsi)


def parse_args():
    """Parse the segmentation command line from ``sys.argv``.

    Returns:
        argparse.Namespace with camelCase attributes named after the flags.
        ``--noNormalize`` is stored as ``normalize`` (True unless the flag is given).
    """
    parser = argparse.ArgumentParser(description="Cellpose segmentation with tissue masking and stitching")
    # One (flag, add_argument keyword arguments) row per option, in help-output order.
    option_table = [
        ("--inputPath", dict(required=True, help="Path to input image (WSI or regular)")),
        ("--outputPath", dict(required=True, help="Path to output mask (TIFF)")),
        ("--patchSize", dict(type=int, default=512, help="Patch size in pixels at full resolution")),
        ("--stride", dict(type=int, default=512, help="Stride in pixels; use <patchSize> for non-overlapping patches")),
        ("--modelType", dict(default="cyto3", help="[deprecated in v4] legacy model type; ignored in v4")),
        ("--pretrainedModel", dict(default="cpsam", help="Cellpose weights to use (e.g., cpsam or custom path)")),
        ("--diameter", dict(type=float, default=None, help="Cellpose object diameter (optional)")),
        ("--tissueThreshold", dict(type=float, default=240.0, help="Mean RGB threshold below which pixel is tissue")),
        ("--maxPatches", dict(type=int, default=None, help="Process at most N patches (debug)")),
        ("--batchSize", dict(type=int, default=8, help="Number of 256x256 tiles to run simultaneously on GPU")),
        ("--progressInterval", dict(type=int, default=100, help="Print progress every N processed patches")),
        ("--verbose", dict(action="store_true", help="Print more detailed progress information")),
        ("--tileOverlap", dict(type=float, default=0.1, help="Internal Cellpose tile overlap within eval (0-1)")),
        ("--mergeOverlaps", dict(action="store_true", help="Merge labels across overlapping patch seams to reduce splits")),
        ("--cellprobThreshold", dict(type=float, default=None, help="Lower to detect more seeds (e.g., -3 to 0)")),
        ("--flowThreshold", dict(type=float, default=None, help="Lower to allow more masks (e.g., 0.1–0.4)")),
        # Normalization is on by default; the flag switches it off.
        ("--noNormalize", dict(action="store_false", dest="normalize", help="Disable intensity normalization inside Cellpose eval (default: enabled)")),
        # Tissue gating sources: SpatialData polygons or a mask image.
        ("--tissueZarr", dict(default=None, help="Optional SpatialData .zarr path containing shapes['tissue'] for gating")),
        ("--regionName", dict(default="tissue", help="Shapes region name in SpatialData (default 'tissue')")),
        ("--tissueMaskPath", dict(default=None, help="Optional tissue mask image path (TIFF/PNG). If provided, used for gating")),
    ]
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    parser.set_defaults(normalize=True)
    return parser.parse_args()


def compute_tissue_mask_rgb(im: Image.Image, mean_thresh: float = 240.0) -> np.ndarray:
    """Threshold-based tissue detector.

    The image is converted to RGB, the three channels are averaged per pixel,
    and every pixel whose average is strictly below ``mean_thresh`` is marked
    as tissue. No speckle removal or other clean-up is applied.

    Returns:
        Boolean array of shape (H, W); True marks tissue.
    """
    channels = np.asarray(im.convert("RGB"))
    per_pixel_mean = channels.mean(axis=2)
    cutoff = float(mean_thresh)
    return per_pixel_mean < cutoff


def segment_patch(model, patch_rgb: np.ndarray, diameter: Optional[float], batch_size: int = 8, tile_overlap: float = 0.1, cellprob_threshold: Optional[float] = None, flow_threshold: Optional[float] = None, normalize: bool = True):
    """Run Cellpose on a single RGB patch and return its label mask.

    Only keyword arguments that the installed ``model.eval`` actually declares
    are forwarded, so the same call works across Cellpose versions.

    Args:
        model: object exposing a Cellpose-style ``eval`` method.
        patch_rgb: image patch in HWC layout.
        diameter: expected object diameter, or None to let Cellpose decide.
        batch_size: forwarded to ``eval`` as ``batch_size``.
        tile_overlap: forwarded when ``eval`` accepts ``tile_overlap``.
        cellprob_threshold: forwarded when not None and accepted by ``eval``.
        flow_threshold: forwarded when not None and accepted by ``eval``.
        normalize: forwarded (as a bool) when ``eval`` accepts ``normalize``.

    Returns:
        Label mask of shape (H, W) with dtype uint32; 0 is background.
    """
    import inspect

    accepted = inspect.signature(model.eval).parameters
    kwargs = {"diameter": diameter, "batch_size": batch_size}
    if "tile_overlap" in accepted:
        kwargs["tile_overlap"] = tile_overlap
    if cellprob_threshold is not None and "cellprob_threshold" in accepted:
        kwargs["cellprob_threshold"] = cellprob_threshold
    if flow_threshold is not None and "flow_threshold" in accepted:
        kwargs["flow_threshold"] = flow_threshold
    if "normalize" in accepted:
        kwargs["normalize"] = bool(normalize)
    if "channel_axis" in accepted:
        kwargs["channel_axis"] = 2  # HWC
    elif "channels" in accepted:
        kwargs["channels"] = [0, 0]  # fallback for older versions
    result = model.eval([patch_rgb], **kwargs)
    # CellposeModel.eval returns (masks, flows, styles) while the legacy
    # Cellpose.eval appends a fourth element (diams); unpacking exactly three
    # names would raise on the latter, so only the first element is taken.
    masks = result[0]
    # eval returns list outputs when given list inputs
    mask = masks[0] if isinstance(masks, (list, tuple)) else masks
    return np.asarray(mask).astype(np.uint32, copy=False)


def _merge_overlap_labels(roi: np.ndarray, remapped: np.ndarray, overlap_pixels: int, min_pixels: int = 50) -> np.ndarray:
    """Relabel new-patch objects that coincide with already-stitched objects.

    Only pixels inside the top and left overlap bands (each ``overlap_pixels``
    thick) are inspected, because those are the sides shared with patches
    processed earlier in row-major order.

    Args:
        roi: global labels already present in the output window.
        remapped: labels of the new patch, already offset into global ID space.
        overlap_pixels: band thickness (patch_size - stride); <= 0 disables merging.
        min_pixels: a new label is merged only if it overlaps existing labels
            on at least this many band pixels.

    Returns:
        ``remapped`` itself when nothing is merged, otherwise a relabelled copy.
    """
    from collections import Counter

    if overlap_pixels <= 0:
        return remapped

    n_rows, n_cols = remapped.shape
    in_band = np.zeros((n_rows, n_cols), dtype=bool)
    in_band[:, :overlap_pixels] = True  # shared with the left neighbour
    in_band[:overlap_pixels, :] = True  # shared with the top neighbour

    contested = in_band & (roi > 0) & (remapped > 0)
    if not contested.any():
        return remapped

    # Column 0: existing global label, column 1: new label, one row per contested pixel.
    label_pairs = np.column_stack((roi[contested], remapped[contested]))
    old_ids = label_pairs[:, 0]
    new_ids = label_pairs[:, 1]

    relabel = {}
    for candidate in np.unique(new_ids):
        votes = old_ids[new_ids == candidate]
        if votes.size < min_pixels:
            continue
        # Majority vote; ties go to the label met first in scan order.
        winner, _count = Counter(votes).most_common(1)[0]
        relabel[int(candidate)] = int(winner)

    if not relabel:
        return remapped

    merged = remapped.copy()
    for source_id, target_id in relabel.items():
        merged[merged == source_id] = target_id
    return merged


def _gating_disabled(args) -> bool:
    """Return True when every patch must be segmented (no tissue gating).

    Gating is off only when the threshold is >= 255 and neither a SpatialData
    store nor a tissue mask file was supplied.
    """
    threshold = args.tissueThreshold
    return (
        threshold is not None
        and float(threshold) >= 255.0
        and not args.tissueZarr
        and not args.tissueMaskPath
    )


def _read_tissue_geometries(args):
    """Load tissue polygons and scaling metadata from ``args.tissueZarr``.

    Returns:
        ``(geoms, props, tile_spec)``: the geometries of
        ``shapes[args.regionName]`` plus two metadata dicts. They come from
        the first table's ``uns`` when the store has tables, otherwise
        ``props`` is filled from a ``provenance.json`` inside the store (if
        present) and ``tile_spec`` stays empty.

    Raises:
        RuntimeError: the requested shapes region does not exist.
        Import and I/O errors propagate; callers catch them and fall back.
    """
    from spatialdata import read_zarr

    sd = read_zarr(str(args.tissueZarr))
    if args.regionName not in sd.shapes:
        raise RuntimeError(f"No shapes['{args.regionName}'] found in SpatialData")
    geoms = list(sd.shapes[args.regionName]["geometry"].values)
    props, tile_spec = {}, {}
    if len(sd.tables) > 0:
        tbl = next(iter(sd.tables.values()))
        props = tbl.uns.get("slide_properties", {})
        tile_spec = tbl.uns.get("tile_spec", {})
    else:
        # tissue-only path: no table, metadata is read from a sidecar JSON
        import json
        import os

        prov_path = os.path.join(str(args.tissueZarr), "provenance.json")
        if os.path.exists(prov_path):
            with open(prov_path, "r", encoding="utf-8") as f:
                prov = json.load(f)
            props = {
                key: prov.get(key, None)
                for key in ("image_size_resampled", "source_mpp", "effective_mpp")
            }
    return geoms, props, tile_spec


def _make_polygon_gate(geoms, xfact=None, yfact=None):
    """Build ``gate(cx, cy) -> bool`` testing whether a point lies inside the polygons.

    When both factors are given, the polygons are first scaled about the
    origin so they live in the pixel space of the image being segmented.
    """
    import shapely.geometry as sg
    from shapely.affinity import scale as scale_geom
    from shapely.ops import unary_union
    from shapely.prepared import prep

    if xfact is not None and yfact is not None:
        geoms = [scale_geom(g, xfact=xfact, yfact=yfact, origin=(0, 0)) for g in geoms]
    prepared = prep(unary_union(geoms))
    return lambda cx, cy: prepared.contains(sg.Point(float(cx), float(cy)))


def _make_mask_gate(mask, scale=1.0):
    """Build ``gate(cx, cy) -> bool`` that looks a point up in a boolean (H, W) mask.

    ``scale`` is the number of image pixels per mask pixel. Points that fall
    outside the mask array are reported as non-tissue instead of raising.
    """
    rows, cols = mask.shape[:2]

    def gate(cx, cy):
        ix, iy = int(cx / scale), int(cy / scale)
        return 0 <= ix < cols and 0 <= iy < rows and bool(mask[iy, ix])

    return gate


def _box_center(box):
    """Return the (x, y) centre of a (left, top, right, bottom) box."""
    l, t, r, b = box
    return l + (r - l) * 0.5, t + (b - t) * 0.5


def _select_tissue_boxes(boxes, gate):
    """Keep the boxes whose centre passes ``gate`` and print how many survived."""
    selected = [box for box in boxes if gate(*_box_center(box))]
    share = len(selected) / len(boxes) * 100.0 if len(boxes) > 0 else 0
    print(f"[segment] Tissue patches: {len(selected)}/{len(boxes)} ({share:.2f}%)")
    return selected


def _save_label_mask(path, labels):
    """Write a label image to ``path``; tifffile is used when available, Pillow otherwise."""
    if HAS_TIFFILE:
        tiff.imwrite(path, labels, dtype=labels.dtype)
        return
    if labels.size and int(labels.max()) > np.iinfo(np.uint16).max:
        # Casting to uint16 would silently wrap IDs above 65535 and alias
        # distinct objects; Pillow's 32-bit integer mode keeps them apart.
        print("[segment] tifffile not installed and label IDs exceed 65535; saving as 32-bit integers.", file=sys.stderr)
        Image.fromarray(labels.astype(np.int32)).save(path)
    else:
        Image.fromarray(labels.astype(np.uint16)).save(path)


def _build_cellpose_model(args):
    """Instantiate the Cellpose model class offered by the installed version.

    Constructor keywords (``gpu``, ``device``, ``pretrained_model``) are passed
    only when the class declares them. CUDA is used whenever torch reports it.
    """
    import contextlib
    import inspect
    import io

    models = _load_cellpose_models_quiet()
    model_cls = getattr(models, "Cellpose", None) or getattr(models, "CellposeModel", None)
    if model_cls is None:
        raise RuntimeError("Installed cellpose does not expose Cellpose/CellposeModel")
    import torch

    use_cuda = torch.cuda.is_available()
    device_obj = torch.device("cuda") if use_cuda else torch.device("cpu")
    accepted = inspect.signature(model_cls.__init__).parameters
    kwargs = {}
    if "gpu" in accepted:
        kwargs["gpu"] = use_cuda
    if "device" in accepted:
        kwargs["device"] = device_obj
    if "pretrained_model" in accepted and args.pretrainedModel:
        kwargs["pretrained_model"] = args.pretrainedModel
    # Suppress verbose version banners printed during model initialization
    with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
        model = model_cls(**kwargs)
    print(f"[segment] Using device: {device_obj.type} (torch.cuda.is_available()={use_cuda})")
    return model


def _stitch_patch(out, mask, box, next_label, overlap_pixels=None):
    """Paste one patch's labels into ``out`` using globally unique IDs.

    The patch mask is cropped to the part of ``box`` that lies inside ``out``,
    so a patch read at a fixed size near the image border cannot cause a shape
    mismatch. Pixels already labelled in ``out`` are never overwritten.
    ``overlap_pixels`` of None disables seam merging.

    Returns:
        ``(next_label, wrote)``: the next free global ID and whether the
        patch contributed any object.
    """
    l, t, r, b = (int(v) for v in box)
    roi = out[t:b, l:r]  # view: writes below land directly in ``out``
    rows = min(roi.shape[0], mask.shape[0])
    cols = min(roi.shape[1], mask.shape[1])
    roi = roi[:rows, :cols]
    mask = mask[:rows, :cols]
    if mask.size == 0 or mask.max() == 0:
        return next_label, False
    max_label = int(mask.max())
    mask_nonzero = mask > 0
    # Offset local labels so they cannot collide with IDs handed out earlier.
    remapped = mask.copy()
    remapped[mask_nonzero] = remapped[mask_nonzero] + (next_label - 1)
    if overlap_pixels is not None:
        remapped = _merge_overlap_labels(roi, remapped, overlap_pixels)
    write_where = (roi == 0) & mask_nonzero
    roi[write_where] = remapped[write_where]
    return next_label + max_label, True


def _segment_and_stitch(args, model, boxes, out, read_patch):
    """Segment every box in ``boxes`` and stitch the results into ``out``.

    ``read_patch(box)`` must return the pixels for one box as an array.
    Processing stops early once ``args.maxPatches`` patches have been handled.

    Returns:
        The highest global label ID allocated (0 when nothing was segmented).
    """
    import time

    total = len(boxes)
    overlap_pixels = max(0, args.patchSize - args.stride) if args.mergeOverlaps else None
    next_label = 1
    processed = 0
    segmented_patches = 0
    t0 = time.time()
    for box in boxes:
        if args.maxPatches is not None and processed >= args.maxPatches:
            break
        mask = segment_patch(
            model,
            read_patch(box),
            args.diameter,
            args.batchSize,
            args.tileOverlap,
            args.cellprobThreshold,
            args.flowThreshold,
            args.normalize,
        )
        next_label, wrote = _stitch_patch(out, mask, box, next_label, overlap_pixels)
        if wrote:
            segmented_patches += 1
        processed += 1
        if args.verbose:
            pct = min(100, int(100 * processed / total))
            elapsed = time.time() - t0
            sys.stdout.write(f"\r[segment] Processing tiles: {pct}% ({processed}/{total})  elapsed {elapsed:.1f}s  seg={segmented_patches} labels={next_label-1}")
            sys.stdout.flush()
        elif processed % max(1, args.progressInterval) == 0:
            print(f"[segment] Progress: processed={processed}/{total}, segmented_patches={segmented_patches}, next_global_label={next_label}")
    if args.verbose:
        sys.stdout.write("\n")
    return next_label - 1


def run_on_regular_image(args):
    """Segment a regular (non-WSI) image patch by patch and save the label mask.

    Tissue gating, in order of preference: polygons from ``args.tissueZarr``,
    the mask image at ``args.tissueMaskPath``, then mean-RGB thresholding of
    the image itself. A source that fails to load is reported on stderr and
    the next one is tried. A patch is segmented when its centre is in tissue.

    Args:
        args: namespace with the attributes produced by ``parse_args`` /
            ``segmentImage``.

    Side effects:
        Writes a (height, width) label image to ``args.outputPath`` and prints
        progress to stdout.
    """
    print(f"[segment] Loading image: {args.inputPath}")
    im = _read_he_image(args.inputPath)
    width, height = im.size
    print(f"[segment] Image size: {width}x{height}")
    if _gating_disabled(args):
        print("[segment] Tissue gating disabled; segmenting all patches (threshold>=255 and no polygons/mask).")
        boxes = generate_patch_grid(width, height, args.patchSize, args.stride)
        print(f"[segment] Generated {len(boxes)} patches (patch={args.patchSize}, stride={args.stride})")
        selected = list(boxes)
    else:
        gate = None
        if args.tissueZarr:
            try:
                geoms, props, _tile_spec = _read_tissue_geometries(args)
                # Polygons live in the resampled image's pixel space; stretch
                # them to the size of the image actually being segmented.
                target_size = props.get("image_size_resampled", None)
                xfact = yfact = None
                if target_size:
                    xfact = float(width) / float(target_size[0])
                    yfact = float(height) / float(target_size[1])
                gate = _make_polygon_gate(geoms, xfact, yfact)
                print("[segment] Using tissue polygons from SpatialData for gating.")
            except Exception as e:
                print(f"[segment] Failed to use SpatialData polygons: {e}; falling back.", file=sys.stderr)
                gate = None
        if gate is None and args.tissueMaskPath:
            try:
                with Image.open(args.tissueMaskPath) as raw_mask:
                    # Single channel, so a multi-channel mask still yields a 2-D boolean array.
                    mask_im = raw_mask.convert("L")
                if mask_im.size != (width, height):
                    mask_im = mask_im.resize((width, height), resample=Image.NEAREST)
                gate = _make_mask_gate(np.array(mask_im) > 0)
                print("[segment] Using provided tissue mask for gating.")
            except Exception as e:
                print(f"[segment] Failed to load tissue mask: {e}; falling back.", file=sys.stderr)
                gate = None
        if gate is None:
            print(f"[segment] Detecting tissue (mean RGB < {args.tissueThreshold})")
            gate = _make_mask_gate(compute_tissue_mask_rgb(im, mean_thresh=args.tissueThreshold))
        boxes = generate_patch_grid(width, height, args.patchSize, args.stride)
        print(f"[segment] Generated {len(boxes)} patches (patch={args.patchSize}, stride={args.stride})")
        selected = _select_tissue_boxes(boxes, gate)
        if not selected:
            print("[segment] No tissue patches detected; writing empty mask and exiting.")
            _save_label_mask(args.outputPath, np.zeros((height, width), dtype=np.uint32))
            return

    model = _build_cellpose_model(args)
    out = np.zeros((height, width), dtype=np.uint32)
    n_labels = _segment_and_stitch(
        args, model, selected, out,
        lambda box: np.asarray(im.crop(tuple(box))),
    )
    _save_label_mask(args.outputPath, out)
    print(f"[segment] Saved segmentation mask: {args.outputPath}  shape={out.shape}  labels={n_labels}")


def _segment_open_slide(args, slide):
    """Segment an already opened slide at level 0; the caller owns and closes ``slide``."""
    W0, H0 = slide.level_dimensions[0]
    print(f"[segment] WSI level-0 size: {W0}x{H0}")
    if _gating_disabled(args):
        boxes = generate_patch_grid(W0, H0, args.patchSize, args.stride)
        print("[segment] Tissue gating disabled; segmenting all patches (threshold>=255 and no polygons/mask).")
        print(f"[segment] Generated {len(boxes)} patches (patch={args.patchSize}, stride={args.stride})")
        selected = list(boxes)
    else:
        gate = None
        if args.tissueZarr:
            try:
                geoms, props, tile_spec = _read_tissue_geometries(args)
                source_mpp = props.get("source_mpp", None)
                target_mpp = tile_spec.get("effective_mpp", None) or props.get("effective_mpp", None)
                factor = None
                if (
                    source_mpp is not None
                    and target_mpp is not None
                    and float(source_mpp) != 0.0
                    and float(target_mpp) != 0.0
                ):
                    # Scale polygons from target-mpp virtual space to level-0
                    # pixel space. One target pixel spans target_mpp/source_mpp
                    # level-0 pixels, so that ratio is the factor (the same
                    # full/resampled direction the regular-image path uses).
                    factor = float(target_mpp) / float(source_mpp)
                gate = _make_polygon_gate(geoms, factor, factor)
                print("[segment] Using tissue polygons from SpatialData for gating (WSI level-0).")
            except Exception as e:
                print(f"[segment] Failed to use SpatialData polygons: {e}; falling back.", file=sys.stderr)
                gate = None
        if gate is None:
            # Mask-based gating is evaluated on a thumbnail at most 2048 px wide.
            thumb_w = min(2048, W0)
            scale = W0 / float(thumb_w)
            thumb_h = max(1, int(H0 / scale))
            tissue_thumb = None
            if args.tissueMaskPath:
                try:
                    with Image.open(args.tissueMaskPath) as raw_mask:
                        note = "level-0 scaled to thumbnail" if raw_mask.size == (W0, H0) else "resized to thumbnail"
                        mask_im = raw_mask.convert("L").resize((thumb_w, thumb_h), resample=Image.NEAREST)
                    tissue_thumb = np.array(mask_im) > 0
                    print(f"[segment] Using provided tissue mask ({note}) for gating.")
                except Exception as e:
                    print(f"[segment] Failed to load tissue mask: {e}; falling back.", file=sys.stderr)
                    tissue_thumb = None
            if tissue_thumb is None:
                print(f"[segment] Building thumbnail for tissue detection: {thumb_w}x{thumb_h} (scale={scale:.3f})")
                thumb = slide.get_thumbnail((thumb_w, thumb_h)).convert("RGB")
                print(f"[segment] Detecting tissue on thumbnail (mean RGB < {args.tissueThreshold})")
                tissue_thumb = compute_tissue_mask_rgb(thumb, mean_thresh=args.tissueThreshold)
            coverage = float(tissue_thumb.mean()) * 100.0
            print(f"[segment] Tissue detected: {'yes' if coverage>0 else 'no'} (coverage={coverage:.2f}%)")
            # The gate bounds-checks against the real array shape, which may be
            # a pixel smaller than the requested thumbnail size.
            gate = _make_mask_gate(tissue_thumb, scale)
        # Prepare tiling in level-0 pixels
        boxes = generate_patch_grid(W0, H0, args.patchSize, args.stride)
        print(f"[segment] Generated {len(boxes)} patches (patch={args.patchSize}, stride={args.stride})")
        selected = _select_tissue_boxes(boxes, gate)
        if not selected:
            print("[segment] No tissue patches detected; writing empty mask and exiting.")
            _save_label_mask(args.outputPath, np.zeros((H0, W0), dtype=np.uint32))
            return

    model = _build_cellpose_model(args)
    out = np.zeros((H0, W0), dtype=np.uint32)
    read_size = (args.patchSize, args.patchSize)
    print("[segment] Running segmentation...")

    def read_patch(box):
        # Always read a full patch; _stitch_patch crops the result at the slide border.
        region = slide.read_region((int(box[0]), int(box[1])), 0, read_size)
        return np.asarray(region.convert("RGB"))

    n_labels = _segment_and_stitch(args, model, selected, out, read_patch)
    _save_label_mask(args.outputPath, out)
    print(f"[segment] Saved segmentation mask: {args.outputPath}  shape={out.shape}  labels={n_labels}")


def run_on_wsi(args):
    """Segment a whole-slide image at level 0 and save the label mask.

    Tissue gating, in order of preference: polygons from ``args.tissueZarr``,
    the mask image at ``args.tissueMaskPath`` (resized to a thumbnail), then
    mean-RGB thresholding of a slide thumbnail. A source that fails to load is
    reported on stderr and the next one is tried.

    Args:
        args: namespace with the attributes produced by ``parse_args`` /
            ``segmentImage``.

    Raises:
        RuntimeError: OpenSlide is not installed.

    Side effects:
        Writes a level-0 sized label image to ``args.outputPath`` and prints
        progress to stdout. The slide handle is closed even on failure.
    """
    if not HAS_OPENSLIDE:
        raise RuntimeError("OpenSlide is required to segment WSIs. pip install openslide-python")
    print(f"[segment] Opening WSI: {args.inputPath}")
    slide, _source_mpp = _read_wsi(args.inputPath)
    try:
        _segment_open_slide(args, slide)
    finally:
        slide.close()


def segmentImage(
    inputPath: str,
    outputPath: str,
    patchSize: int = 512,
    stride: int = 512,
    pretrainedModel: str = "cpsam",
    diameter: Optional[float] = None,
    tissueThreshold: float = 240.0,
    maxPatches: Optional[int] = None,
    batchSize: int = 8,
    progressInterval: int = 100,
    verbose: bool = False,
    tileOverlap: float = 0.1,
    mergeOverlaps: bool = False,
    cellprobThreshold: Optional[float] = None,
    flowThreshold: Optional[float] = None,
    normalize: bool = True,
    modelType: str = "cyto3",
    tissueZarr: Optional[str] = None,
    tissueMaskPath: Optional[str] = None,
    regionName: str = "tissue",
) -> str:
    """Programmatic (notebook-friendly) entry point for segmentation.

    The keyword arguments mirror the command-line flags. They are bundled into
    a namespace and handed to ``run_on_wsi`` when ``inputPath`` is recognised
    as a whole-slide image, otherwise to ``run_on_regular_image``.

    Returns:
        ``outputPath``, unchanged.
    """
    # Must stay the first statement: at this point locals() holds exactly the
    # call parameters, which is the attribute set the runners expect.
    args = SimpleNamespace(**locals())
    runner = run_on_wsi if _is_wsi(str(inputPath)) else run_on_regular_image
    runner(args)
    return outputPath


def main():
    """Command-line entry point: parse the flags and run the segmentation.

    Every parsed option is forwarded to ``segmentImage``, which routes to the
    WSI or regular-image implementation based on the input path. The tissue
    gating options (``--tissueZarr``, ``--tissueMaskPath``, ``--regionName``)
    are forwarded too; without that they would be accepted but have no effect.
    """
    args = parse_args()
    segmentImage(
        inputPath=args.inputPath,
        outputPath=args.outputPath,
        patchSize=args.patchSize,
        stride=args.stride,
        pretrainedModel=args.pretrainedModel,
        diameter=args.diameter,
        tissueThreshold=args.tissueThreshold,
        maxPatches=args.maxPatches,
        batchSize=args.batchSize,
        progressInterval=args.progressInterval,
        verbose=args.verbose,
        tileOverlap=args.tileOverlap,
        mergeOverlaps=args.mergeOverlaps,
        cellprobThreshold=args.cellprobThreshold,
        flowThreshold=args.flowThreshold,
        normalize=args.normalize,
        modelType=args.modelType,
        tissueZarr=args.tissueZarr,
        tissueMaskPath=args.tissueMaskPath,
        regionName=args.regionName,
    )


if __name__ == "__main__":
    main()