"""
spatialfeaturetable.py — Per-cell feature table from multiplexed images and mask
-------------------------------------------------------------------------------
Reads a multiplexed OME-TIFF/TIFF along with a segmentation mask and produces a
cell-level feature table (mean or median per channel) with centroids, written as
an AnnData table inside a SpatialData (.zarr) store.

Memory-efficient approach:
- Reads image one channel at a time via OME-Zarr or dask-backed chunks
- Streams over chunks to accumulate per-cell statistics
- Default statistic is mean; median uses streaming histogram approximation

Outputs:
- AnnData with X = features (n_cells × n_channels), dtype configurable (default float16)
- obs: X_centroid, Y_centroid, CellId, imageid
- var: marker names resolved from OME, CSV, or fallback marker-1..N
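- Written to a SpatialData (.zarr) store (anndata and spatialdata are both required);
  approximate per-cell ellipse polygons are added under shapes['segmentation_mask']
  when shapely/geopandas are available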

CLI example:
python spatialfeaturetable.py --input image.ome.tiff --mask mask.tiff \
  --stat mean --markers-csv markers.csv --output out_dir
"""

from __future__ import annotations

import argparse
import csv
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Any
import sys
import time

import numpy as np
from histotuner.utils import image_as_zarr as _image_as_zarr_utils, mask_as_zarr as _mask_as_zarr_utils

HAS_TIFFFILE = False
try:
    import tifffile as tiff  # type: ignore
    HAS_TIFFFILE = True
except Exception:
    HAS_TIFFFILE = False


HAS_ANNDATA = False
try:
    import anndata as ad  # type: ignore
    HAS_ANNDATA = True
except Exception:
    HAS_ANNDATA = False

HAS_SPATIALDATA = False
try:
    import spatialdata as sdata  # type: ignore
    from spatialdata import SpatialData  # type: ignore
    HAS_SPATIALDATA = True
except Exception:
    HAS_SPATIALDATA = False

# Optional geometry deps for shapes writing
HAS_SHAPELY = False
try:
    import shapely.geometry as sg  # type: ignore
    from shapely.ops import unary_union  # type: ignore
    import geopandas as gpd  # type: ignore
    from spatialdata.models import ShapesModel, TableModel  # type: ignore
    HAS_SHAPELY = True
except Exception:
    HAS_SHAPELY = False

# Modular writers (robust import for package-relative, absolute, and script execution)
try:
    from .spatial_writer import write_spatialdata_table, update_spatialdata_table  # type: ignore
except Exception:
    try:
        from histotuner.spatial_writer import write_spatialdata_table, update_spatialdata_table  # type: ignore
    except Exception:
        try:
            from spatial_writer import write_spatialdata_table, update_spatialdata_table  # type: ignore
        except Exception:
            write_spatialdata_table = None
            update_spatialdata_table = None

def _warn(msg: str) -> None:
    print(f"[warn] {msg}")


def _info(msg: str) -> None:
    print(f"[info] {msg}")


def _read_markers_csv(path: Optional[str]) -> Optional[List[str]]:
    if not path:
        return None
    p = Path(path)
    if not p.exists():
        _warn(f"Markers CSV not found: {p}")
        return None
    names: List[str] = []
    try:
        with p.open("r", newline="") as f:
            reader = csv.DictReader(f)
            if "marker_name" not in reader.fieldnames:
                _warn("Markers CSV missing 'marker_name' column; ignoring.")
                return None
            for row in reader:
                nm = str(row.get("marker_name", "")).strip()
                if nm:
                    names.append(nm)
    except Exception as e:
        _warn(f"Failed reading markers CSV: {e}")
        return None
    return names if names else None


def _parse_ome_channel_names(tf: "tiff.TiffFile") -> Optional[List[str]]:
    # Best-effort OME-XML parsing without requiring ome-types
    try:
        ome_xml = getattr(tf, "ome_metadata", None)
    except Exception:
        ome_xml = None
    if not ome_xml or not isinstance(ome_xml, str):
        return None
    # naive extraction of Name="..." under Channel tags
    names: List[str] = []
    try:
        import re
        pattern = re.compile(r"<Channel[^>]*?Name=\"([^\"]+)\"", re.IGNORECASE)
        names = pattern.findall(ome_xml)
        if names:
            return names
    except Exception:
        return None
    return None




def _resolve_channel_axis(shape: Sequence[int], axes: Optional[str]) -> Tuple[int, int, int]:
    # Returns (c_axis, y_axis, x_axis) indices
    if axes:
        axes = axes.upper()
        # Common axes: TCZYX, CZYX, CYX, YXS, etc.
        # Prefer explicit channel axis 'C'; otherwise use samples 'S'.
        if ("C" in axes or "S" in axes) and "Y" in axes and "X" in axes:
            c_axis = axes.index("C") if "C" in axes else axes.index("S")
            y_axis = axes.index("Y")
            x_axis = axes.index("X")
            return c_axis, y_axis, x_axis
    # Fallback heuristics
    if len(shape) == 3:
        # Heuristic: if last dim small (<=8), treat last as channels (Y, X, C)
        if int(shape[2]) <= 8:
            return 2, 0, 1
        # else assume (C, Y, X)
        return 0, 1, 2
    elif len(shape) == 2:
        # single-channel image
        return -1, 0, 1
    elif len(shape) >= 4:
        # assume (..., C, Y, X)
        return len(shape) - 3, len(shape) - 2, len(shape) - 1
    raise ValueError(f"Unsupported image shape: {shape}")


def _iter_channel_blocks(arr: "zarr.core.Array", c_index: int, y_index: int, x_index: int, channel: int) -> Iterable[Tuple[Tuple[slice, ...], np.ndarray]]:
    # Build a slicer selecting specific channel and iterating over chunked Y,X blocks
    chunks = getattr(arr, "chunks", None)
    if chunks is None:
        # No chunking info; read whole channel
        slicer = [slice(None)] * arr.ndim
        if c_index >= 0:
            slicer[c_index] = slice(channel, channel + 1)
        data = np.asarray(arr[tuple(slicer)]).squeeze()
        yield (tuple(slicer), data)
        return
    # iterate over Y,X chunk grid
    y_chunks = chunks[y_index]
    x_chunks = chunks[x_index]
    y_len = arr.shape[y_index]
    x_len = arr.shape[x_index]
    y_starts = list(range(0, y_len, y_chunks))
    x_starts = list(range(0, x_len, x_chunks))
    for ys in y_starts:
        ye = min(ys + y_chunks, y_len)
        for xs in x_starts:
            xe = min(xs + x_chunks, x_len)
            slicer = [slice(None)] * arr.ndim
            if c_index >= 0:
                slicer[c_index] = slice(channel, channel + 1)
            slicer[y_index] = slice(ys, ye)
            slicer[x_index] = slice(xs, xe)
            data = np.asarray(arr[tuple(slicer)]).squeeze()
            yield (tuple(slicer), data)


def _scan_mask_labels_and_centroids(mask_arr: "zarr.core.Array") -> Tuple[
    np.ndarray,
    Dict[int, Tuple[float, float, int]],
    Dict[int, Tuple[int, int, int, int]],
    Dict[int, Tuple[float, float, float]],
]:
    # Returns sorted labels, dict: label -> (sum_x, sum_y, count), bounds dict: label -> (minx, miny, maxx, maxy),
    # and moments dict: label -> (sum_xx, sum_yy, sum_xy)
    labels_set: set[int] = set()
    accum: Dict[int, Tuple[float, float, int]] = {}
    bounds: Dict[int, Tuple[int, int, int, int]] = {}
    moments: Dict[int, Tuple[float, float, float]] = {}
    chunks = getattr(mask_arr, "chunks", None)
    if chunks is None:
        data = np.asarray(mask_arr[:]).astype(np.int64)
        ys, xs = np.nonzero(data)
        labs = data[ys, xs]
        for y, x, lab in zip(ys, xs, labs):
            if lab == 0:
                continue
            lab_i = int(lab)
            labels_set.add(lab_i)
            sX, sY, cnt = accum.get(lab_i, (0.0, 0.0, 0))
            accum[lab_i] = (sX + float(x), sY + float(y), cnt + 1)
            sXX, sYY, sXY = moments.get(lab_i, (0.0, 0.0, 0.0))
            moments[lab_i] = (sXX + float(x) * float(x), sYY + float(y) * float(y), sXY + float(x) * float(y))
            bx = bounds.get(lab_i)
            if bx is None:
                bounds[lab_i] = (int(x), int(y), int(x), int(y))
            else:
                minx, miny, maxx, maxy = bx
                bounds[lab_i] = (min(minx, int(x)), min(miny, int(y)), max(maxx, int(x)), max(maxy, int(y)))
    else:
        y_chunks = chunks[0]
        x_chunks = chunks[1]
        y_len = mask_arr.shape[0]
        x_len = mask_arr.shape[1]
        for ys in range(0, y_len, y_chunks):
            ye = min(ys + y_chunks, y_len)
            for xs in range(0, x_len, x_chunks):
                xe = min(xs + x_chunks, x_len)
                block = np.asarray(mask_arr[ys:ye, xs:xe]).astype(np.int64)
                ys_b, xs_b = np.nonzero(block)
                labs = block[ys_b, xs_b]
                for yb, xb, lab in zip(ys_b, xs_b, labs):
                    if lab == 0:
                        continue
                    lab_i = int(lab)
                    labels_set.add(lab_i)
                    sX, sY, cnt = accum.get(lab_i, (0.0, 0.0, 0))
                    gx = int(xs + xb)
                    gy = int(ys + yb)
                    accum[lab_i] = (sX + float(gx), sY + float(gy), cnt + 1)
                    sXX, sYY, sXY = moments.get(lab_i, (0.0, 0.0, 0.0))
                    moments[lab_i] = (sXX + float(gx) * float(gx), sYY + float(gy) * float(gy), sXY + float(gx) * float(gy))
                    bx = bounds.get(lab_i)
                    if bx is None:
                        bounds[lab_i] = (gx, gy, gx, gy)
                    else:
                        minx, miny, maxx, maxy = bx
                        bounds[lab_i] = (min(minx, gx), min(miny, gy), max(maxx, gx), max(maxy, gy))
    sorted_labels = np.array(sorted(labels_set), dtype=np.int64)
    return sorted_labels, accum, bounds, moments


def _prepare_output_table(n_cells: int, n_channels: int, dtype: str = "float16") -> np.ndarray:
    dt = np.float16 if dtype.lower() == "float16" else np.float32
    return np.zeros((n_cells, n_channels), dtype=dt)


def _accumulate_mean_for_channel(mask_arr: "zarr.core.Array", img_arr: "zarr.core.Array", c_index: int, y_index: int, x_index: int, channel: int, label_to_row: Dict[int, int]) -> Tuple[np.ndarray, np.ndarray]:
    """Stream one image channel and accumulate per-cell intensity sums and pixel counts.

    Parameters
    ----------
    mask_arr : 2-D (Y, X) label mask aligned with the image's Y/X axes.
    img_arr : image array; tiles come from ``_iter_channel_blocks``.
    c_index, y_index, x_index : channel/Y/X axis positions in ``img_arr``.
    channel : channel index to read.
    label_to_row : mask label -> output row. Rows are assumed to be in
        ``0..len(label_to_row)-1``.

    Returns
    -------
    (sums, counts) : float64 and int64 arrays of length ``len(label_to_row)``.
        Pixels whose mask value is <= 0, or whose label is not in
        ``label_to_row``, are ignored.
    """
    n_cells = len(label_to_row)
    sums = np.zeros(n_cells, dtype=np.float64)
    counts = np.zeros(n_cells, dtype=np.int64)
    if n_cells == 0:
        return sums, counts

    # Sorted lookup table so each tile's labels are mapped to rows with one
    # searchsorted call instead of a per-pixel Python dict lookup.
    items = sorted((int(k), int(v)) for k, v in label_to_row.items())
    keys_sorted = np.array([k for k, _ in items], dtype=np.int64)
    rows_sorted = np.array([v for _, v in items], dtype=np.int64)

    for slicer, img_block in _iter_channel_blocks(img_arr, c_index, y_index, x_index, channel):
        # Align the mask tile with the same Y/X window
        ysl = slicer[y_index] if isinstance(slicer[y_index], slice) else slice(0, mask_arr.shape[0])
        xsl = slicer[x_index] if isinstance(slicer[x_index], slice) else slice(0, mask_arr.shape[1])
        mb = np.asarray(mask_arr[ysl, xsl])
        img_block = np.asarray(img_block)
        if img_block.ndim != 2 and img_block.size == mb.size:
            # A squeezed 1-pixel-high/wide (or 1x1) tile lost an axis;
            # restore the mask tile's 2-D shape so the two align.
            img_block = img_block.reshape(mb.shape)
        hb, wb = img_block.shape
        if mb.shape != img_block.shape:
            mb = mb[:hb, :wb]

        valid = mb > 0
        if not np.any(valid):
            continue
        labs = mb[valid].astype(np.int64)
        pix = img_block[valid].astype(np.float64)

        pos = np.minimum(np.searchsorted(keys_sorted, labs), n_cells - 1)
        known = keys_sorted[pos] == labs
        if not np.any(known):
            continue
        rows_use = rows_sorted[pos[known]]
        vals_use = pix[known]

        # Group by row within the tile, then add into the running totals
        unique_rows, inverse = np.unique(rows_use, return_inverse=True)
        sums[unique_rows] += np.bincount(inverse, weights=vals_use)
        counts[unique_rows] += np.bincount(inverse)
    return sums, counts


def _accumulate_median_for_channel(mask_arr: "zarr.core.Array", img_arr: "zarr.core.Array", c_index: int, y_index: int, x_index: int, channel: int, label_to_row: Dict[int, int], bins: int = 256) -> np.ndarray:
    """Approximate the per-cell median of one channel with streaming per-cell histograms.

    Binning:

    - Intensities go into ``bins`` equal-width bins (at least 2).
    - For integer dtypes, the bins span the dtype's full representable range.
    - For other dtypes, they span the channel's observed finite min/max,
      found with an extra streaming pass.
    - Non-finite values are ignored.

    Reported value:

    - Each cell gets the lower median: the lower edge of the bin holding the
      element of rank ceil(n/2).
    - Cells with no pixels get 0.0, matching the mean path.

    Parameters mirror ``_accumulate_mean_for_channel``.

    Returns
    -------
    float32 array of length ``len(label_to_row)``, indexed by row.
    """
    n_cells = len(label_to_row)
    med = np.zeros(n_cells, dtype=np.float32)
    if n_cells == 0:
        return med
    # The bin -> intensity mapping divides by (bins - 1)
    bins = max(int(bins), 2)

    items = sorted((int(k), int(v)) for k, v in label_to_row.items())
    keys_sorted = np.array([k for k, _ in items], dtype=np.int64)
    rows_sorted = np.array([v for _, v in items], dtype=np.int64)

    # Determine dtype without reading data when the array exposes it
    dtype = getattr(img_arr, "dtype", None)
    if dtype is None:
        try:
            first = next(iter(_iter_channel_blocks(img_arr, c_index, y_index, x_index, channel)))[1]
            dtype = np.asarray(first).dtype
        except Exception:
            dtype = np.float32
    dtype = np.dtype(dtype)

    if np.issubdtype(dtype, np.integer):
        info = np.iinfo(dtype)
        minv, maxv = float(info.min), float(info.max)
    else:
        # Floats have no fixed range: scan the channel for finite min/max
        lo, hi = np.inf, -np.inf
        for _, blk in _iter_channel_blocks(img_arr, c_index, y_index, x_index, channel):
            v = np.asarray(blk, dtype=np.float64)
            v = v[np.isfinite(v)]
            if v.size:
                lo = min(lo, float(v.min()))
                hi = max(hi, float(v.max()))
        minv, maxv = (lo, hi) if lo <= hi else (0.0, 1.0)
    rng = maxv - minv if maxv > minv else 1.0

    hist = np.zeros((n_cells, bins), dtype=np.int64)
    hist_flat = hist.reshape(-1)  # view: hist is freshly allocated and contiguous
    for slicer, img_block in _iter_channel_blocks(img_arr, c_index, y_index, x_index, channel):
        ysl = slicer[y_index] if isinstance(slicer[y_index], slice) else slice(0, mask_arr.shape[0])
        xsl = slicer[x_index] if isinstance(slicer[x_index], slice) else slice(0, mask_arr.shape[1])
        mask_block = np.asarray(mask_arr[ysl, xsl])
        img_block = np.asarray(img_block)
        if img_block.ndim != 2 and img_block.size == mask_block.size:
            # Restore a squeezed 1-pixel-high/wide tile to the mask tile's shape
            img_block = img_block.reshape(mask_block.shape)
        hb, wb = img_block.shape
        if mask_block.shape != img_block.shape:
            mask_block = mask_block[:hb, :wb]

        valid = mask_block > 0
        if not np.any(valid):
            continue
        labs = mask_block[valid].astype(np.int64)
        vals = img_block[valid].astype(np.float64)

        pos = np.minimum(np.searchsorted(keys_sorted, labs), n_cells - 1)
        keep = (keys_sorted[pos] == labs) & np.isfinite(vals)
        if not np.any(keep):
            continue
        rows_use = rows_sorted[pos[keep]]
        scaled = (vals[keep] - minv) / rng * (bins - 1)
        bins_use = np.clip(np.floor(scaled), 0, bins - 1).astype(np.int64)

        # Vectorised histogram update (flat indices are unique after np.unique)
        flat_idx, flat_cnt = np.unique(rows_use * bins + bins_use, return_counts=True)
        hist_flat[flat_idx] += flat_cnt

    totals = hist.sum(axis=1)
    target = (totals + 1) // 2  # 1-based rank of the lower median
    cum = np.cumsum(hist, axis=1)
    # First bin whose cumulative count reaches the target rank
    k = np.count_nonzero(cum < target[:, None], axis=1)
    med = (minv + (k / float(bins - 1)) * rng).astype(np.float32)
    med[totals == 0] = 0.0
    return med


# Table-only writes are delegated to the spatial_writer module imported at the top of this file.


def _read_existing_spatialdata(path: str) -> Optional[Any]:
    """Best-effort read of an existing SpatialData zarr store.

    Tries ``spatialdata.read_zarr``, then ``spatialdata.io.read_zarr``,
    whichever the installed version provides. Returns ``None`` if neither
    exists or reading fails; a failure is reported via ``_warn``.
    """
    try:
        if hasattr(sdata, "read_zarr"):
            return sdata.read_zarr(path)  # type: ignore
        if hasattr(sdata, "io") and hasattr(sdata.io, "read_zarr"):
            return sdata.io.read_zarr(path)  # type: ignore
    except Exception as e:
        _warn(f"Could not read existing SpatialData at {path}: {e}")
    return None


def _write_spatialdata_overwrite(sd: Any, path: str) -> None:
    """Write ``sd`` to ``path``, overwriting, using whichever writer API exists.

    Raises
    ------
    RuntimeError
        If no supported writer is found, so callers never report a save that
        did not happen.
    """
    if hasattr(sdata, "write_zarr"):
        sdata.write_zarr(sd, path, overwrite=True)  # type: ignore
    elif hasattr(sdata, "io") and hasattr(sdata.io, "write_zarr"):
        sdata.io.write_zarr(sd, path, overwrite=True)  # type: ignore
    elif hasattr(sd, "write"):
        try:
            sd.write(path, overwrite=True)
        except TypeError:
            # Presumably a spatialdata version whose write() has no `overwrite` keyword
            sd.write(path)
    else:
        raise RuntimeError(
            "No SpatialData writer API found (write_zarr / SpatialData.write); nothing was saved."
        )


def _save_outputs(
    features: np.ndarray,
    labels: np.ndarray,
    centroids_xy: np.ndarray,
    marker_names: List[str],
    image_id: str,
    output_dir: str,
    output_dtype: str = "float16",
    seg_polys: Optional[List[Any]] = None,
) -> None:
    """Assemble the per-cell AnnData table and write it into a SpatialData zarr store.

    Parameters
    ----------
    features : ``(n_cells, n_channels)`` per-cell statistics.
    labels : ``(n_cells,)`` mask label per row.
        Used as obs names, ``CellId``, and ``instance_id`` when shapes are written.
    centroids_xy : ``(n_cells, 2)`` array of (x, y) centroids.
    marker_names : one name per channel; becomes ``var_names``.
    image_id : names the store (``<image_id>.zarr``) and the table
        (``<image_id>_cells``).
    output_dir : directory for the store, or an explicit ``*.zarr`` path.
    output_dtype : "float16"; any other value gives float32.
    seg_polys : optional per-row geometries.
        When given and shapely/geopandas are importable, they are written as
        ``shapes['segmentation_mask']`` and the table is linked to them via
        ``region``/``instance_id``.

    Behaviour by case:

    - Existing store, shapes path: the table and shapes are replaced, and
      the other elements of the existing store are carried over when it can
      be read.
    - Existing store, table-only path: delegated to
      ``spatial_writer.update_spatialdata_table``.
    - No existing store: a new store is created.

    Raises
    ------
    RuntimeError
        In any of these cases:

        - anndata or spatialdata is missing;
        - the spatial_writer module is unavailable when a table-only write is
          needed;
        - no spatialdata writer API can be found.
    """
    if not HAS_ANNDATA or not HAS_SPATIALDATA:
        raise RuntimeError("Both AnnData and SpatialData are required; output is SpatialData only as requested.")

    # Build the AnnData table (rows = cells, columns = markers)
    X = features.astype(np.float16 if output_dtype.lower() == "float16" else np.float32)
    adata = ad.AnnData(X=X)
    adata.var_names = marker_names
    adata.obs_names = [str(int(l)) for l in labels]
    adata.obs["CellId"] = labels.astype(np.int64)
    adata.obs["imageid"] = np.array([image_id] * len(labels), dtype=object)
    adata.obs["X_centroid"] = centroids_xy[:, 0].astype(np.float32)
    adata.obs["Y_centroid"] = centroids_xy[:, 1].astype(np.float32)
    adata.obsm["spatial"] = centroids_xy.astype(np.float32)
    adata.uns["all_markers"] = list(marker_names)
    # Snapshot of the table as built here, exposed as adata.raw (adata.raw.X)
    adata.raw = adata.copy()

    with_shapes = seg_polys is not None and HAS_SHAPELY
    if with_shapes:
        # Columns TableModel uses to link each row to its polygon
        adata.obs["region"] = np.array(["segmentation_mask"] * adata.n_obs, dtype=object)
        adata.obs["instance_id"] = labels.astype(np.int64)

    # Output location: an explicit *.zarr path, or <output_dir>/<image_id>.zarr
    target = Path(output_dir)
    if str(target).lower().endswith(".zarr"):
        zarr_dir = target
        zarr_dir.parent.mkdir(parents=True, exist_ok=True)
    else:
        target.mkdir(parents=True, exist_ok=True)
        zarr_dir = target / f"{image_id}.zarr"
    table_name = f"{image_id}_cells"
    # An existing store is updated rather than recreated
    update_spatial_path = str(zarr_dir) if zarr_dir.exists() else None

    if with_shapes:
        shapes_layer = None
        try:
            gdf = gpd.GeoDataFrame(
                {
                    "geometry": seg_polys,
                    "region": ["segmentation_mask"] * len(seg_polys),
                    "instance_id": labels.astype(np.int64),
                },
                geometry="geometry",
                index=adata.obs_names.copy(),
            )
            shapes_layer = ShapesModel.parse(gdf)
        except Exception as e:
            _warn(f"Failed to build shapes layer from segmentation polygons: {e}; writing table only.")

        table_layer = None
        if shapes_layer is not None:
            try:
                table_layer = TableModel.parse(
                    adata,
                    region="segmentation_mask",
                    region_key="region",
                    instance_key="instance_id",
                )
            except Exception as e:
                _warn(f"Failed to link table to segmentation shapes: {e}; writing unlinked table.")
        table_to_write = table_layer if table_layer is not None else adata

        if update_spatial_path:
            existing = _read_existing_spatialdata(update_spatial_path)
            if existing is None:
                _warn(
                    f"Existing SpatialData at {update_spatial_path} could not be loaded; "
                    "it will be replaced and its other elements will not be preserved."
                )
            shapes_dict = dict(getattr(existing, "shapes", {}) or {})
            if shapes_layer is not None:
                shapes_dict["segmentation_mask"] = shapes_layer
            tables_dict = dict(getattr(existing, "tables", {}) or {})
            tables_dict[table_name] = table_to_write

            sd = SpatialData(
                shapes=shapes_dict if shapes_dict else None,
                images=getattr(existing, "images", None) or None,
                labels=getattr(existing, "labels", None) or None,
                points=getattr(existing, "points", None) or None,
                tables=tables_dict if tables_dict else None,
            )
            # NOTE(review): `existing` may lazily reference data inside this same
            # store; overwriting it in place may fail or lose data depending on the
            # spatialdata version. Consider incremental element writes. Verify.
            _write_spatialdata_overwrite(sd, update_spatial_path)
            _info(f"Updated SpatialData (table + segmentation shapes) at: {update_spatial_path}")
            return

        # New store with shapes (when available) + table
        if shapes_layer is not None:
            sd = SpatialData(shapes={"segmentation_mask": shapes_layer}, tables={table_name: table_to_write})
        else:
            sd = SpatialData(tables={table_name: adata})
        _write_spatialdata_overwrite(sd, str(zarr_dir))
        _info(f"Saved SpatialData (table + segmentation shapes): {zarr_dir}")
        return

    # Table-only output via the spatial_writer module
    if write_spatialdata_table is None or update_spatialdata_table is None:
        raise RuntimeError("spatial_writer module is required for output; not found.")
    if update_spatial_path:
        update_spatialdata_table(adata, update_spatial_path, table_name=table_name)
        _info(f"Updated SpatialData at: {update_spatial_path}")
        return
    write_spatialdata_table(adata, str(zarr_dir), table_name=table_name)
    _info(f"Saved SpatialData table: {zarr_dir}")


def _zarr_store_path(output_dir: str, image_id: str) -> str:
    """Return the SpatialData store path for ``output_dir``.

    An ``output_dir`` ending in ``.zarr`` (case-insensitive) is used as-is;
    otherwise the store is ``<output_dir>/<image_id>.zarr``.
    """
    target = Path(output_dir)
    if str(target).lower().endswith(".zarr"):
        return str(target)
    return str(target / f"{image_id}.zarr")


def _resolve_marker_names(input_path: str, markers_csv_path: Optional[str], n_channels: int) -> List[str]:
    """Choose one marker name per channel.

    Sources are tried in this order:

    1. OME-XML channel names from the input TIFF.
    2. The ``marker_name`` column of ``markers_csv_path``.
    3. Generic ``marker-1..N`` names.

    A source is used only if it supplies exactly ``n_channels`` names;
    otherwise a warning is printed and the next source is tried.
    """
    ome_names: Optional[List[str]] = None
    try:
        # Context manager releases the TIFF file handle after reading metadata
        with tiff.TiffFile(input_path) as tf:
            ome_names = _parse_ome_channel_names(tf)
    except Exception as e:
        _warn(f"Could not read OME metadata from {input_path}: {e}")
        ome_names = None
    if ome_names and len(ome_names) == n_channels:
        return list(ome_names)
    if ome_names:
        _warn(
            f"OME channel name count {len(ome_names)} does not match channel count {n_channels}; "
            "ignoring OME names."
        )

    csv_names = _read_markers_csv(markers_csv_path)
    if csv_names and len(csv_names) == n_channels:
        return csv_names
    if csv_names:
        _warn(f"CSV marker count {len(csv_names)} does not match channel count {n_channels}; ignoring CSV.")

    _warn("Using generic marker names; OME metadata missing and no valid CSV provided.")
    return [f"marker-{i+1}" for i in range(n_channels)]


def spatialFeatureTable(
    inputPath: str,
    maskPath: str,
    markersCsvPath: Optional[str] = None,
    stat: str = "mean",
    outputPath: Optional[str] = None,
    outputDtype: str = "float16",
    verbose: bool = False,
    addSegmentationShapes: bool = True,
) -> str:
    """
    Notebook-friendly function to build a per-cell feature table from a multiplexed image and a segmentation mask,
    writing a SpatialData table. Optionally writes approximate per-cell polygons (ellipses fitted to each label's
    moments) under shapes['segmentation_mask'] when geometry dependencies are available. Returns the path to the
    created/updated SpatialData zarr. If the output zarr already exists, it will be updated; otherwise, a new one
    will be created.

    Parameters (camelCase):
    - inputPath: Path to input multiplexed OME-TIFF/TIFF image
    - maskPath: Path to segmentation mask TIFF/OME-TIFF with integer labels; 0 is background
    - markersCsvPath: Optional CSV with 'marker_name' column to name channels (used when OME channel names are
      missing or their count does not match the channel count)
    - stat: "mean" (default) or "median" (case-insensitive); median uses a streaming histogram approximation
    - outputPath: Output directory or explicit .zarr path; defaults to sibling directory named by image with _mIF suffix
    - outputDtype: "float16" (default) or "float32"
    - verbose: If True, prints progress
    - addSegmentationShapes: If True (default), attempts to build per-cell polygons and save them in SpatialData
      under shapes['segmentation_mask'] (requires shapely + geopandas).

    Raises:
    - ValueError: unsupported ``stat``
    - RuntimeError: missing tifffile, unreadable image/mask, or mask/image XY size mismatch
    """
    stat_key = str(stat).strip().lower()
    if stat_key not in ("mean", "median"):
        raise ValueError(f"Unsupported stat {stat!r}; expected 'mean' or 'median'.")
    if not HAS_TIFFFILE:
        raise RuntimeError("tifffile is required; zarr is optional for chunked streaming.")

    img_arr, img_shape, img_axes = _image_as_zarr_utils(inputPath)
    if img_arr is None or img_shape is None:
        raise RuntimeError("Failed to open input image as Zarr; ensure OME-TIFF/TIFF.")
    mask_arr = _mask_as_zarr_utils(maskPath)
    if mask_arr is None:
        raise RuntimeError("Failed to open mask as Zarr; ensure TIFF with same XY size.")

    c_axis, y_axis, x_axis = _resolve_channel_axis(img_shape, img_axes)
    n_channels = 1 if c_axis == -1 else int(img_shape[c_axis])
    img_y = int(img_shape[y_axis])
    img_x = int(img_shape[x_axis])
    if mask_arr.shape[0] != img_y or mask_arr.shape[1] != img_x:
        raise RuntimeError(f"Mask size {mask_arr.shape} does not match image XY {(img_y, img_x)}")

    marker_names = _resolve_marker_names(inputPath, markersCsvPath, n_channels)

    # Image id and output location
    image_id = f"{Path(inputPath).stem}_mIF"
    output_dir = outputPath or str(Path(inputPath).parent / image_id)
    zarr_path = _zarr_store_path(output_dir, image_id)

    # Pass 1: scan mask for labels, centroids, bounds and second moments
    if verbose:
        _info("Scanning mask for labels and centroids...")
    sorted_labels, centroid_accum, bounds_accum, moments_accum = _scan_mask_labels_and_centroids(mask_arr)
    if sorted_labels.size == 0:
        _warn("Mask contains no labelled cells; writing an empty feature table.")
        _save_outputs(
            _prepare_output_table(0, n_channels, dtype=outputDtype),
            np.array([], dtype=np.int64),
            np.zeros((0, 2), dtype=np.float32),
            marker_names,
            image_id,
            output_dir,
            outputDtype,
            seg_polys=None,
        )
        return zarr_path

    n_cells = len(sorted_labels)
    label_to_row: Dict[int, int] = {int(l): i for i, l in enumerate(sorted_labels.tolist())}
    centroids = np.zeros((n_cells, 2), dtype=np.float32)
    for row, lab in enumerate(sorted_labels.tolist()):
        sX, sY, cnt = centroid_accum[int(lab)]
        if cnt > 0:
            centroids[row, 0] = sX / float(cnt)
            centroids[row, 1] = sY / float(cnt)

    # Feature matrix kept in float32 here; _save_outputs casts to outputDtype
    features = np.zeros((n_cells, n_channels), dtype=np.float32)

    # Pass 2: per-channel accumulation with single-line streaming progress
    t0 = time.time()
    for c in range(n_channels):
        if stat_key == "mean":
            sums, counts = _accumulate_mean_for_channel(
                mask_arr, img_arr, c_axis, y_axis, x_axis, c, label_to_row
            )
            with np.errstate(divide="ignore", invalid="ignore"):
                ch_feat = np.where(counts > 0, sums / counts, 0.0).astype(np.float32)
        else:
            ch_feat = _accumulate_median_for_channel(
                mask_arr, img_arr, c_axis, y_axis, x_axis, c, label_to_row, bins=256
            )
        features[:, c] = ch_feat
        if verbose:
            elapsed = time.time() - t0
            pct = min(100, int(100 * (c + 1) / n_channels))
            sys.stdout.write(
                f"\r[info] Processing channels: {c+1}/{n_channels} ({pct}%)  elapsed {elapsed:.1f}s"
            )
            sys.stdout.flush()
    if verbose:
        sys.stdout.write("\n")

    # Optional approximate segmentation polygons
    seg_polys: Optional[List[Any]] = None
    if addSegmentationShapes and HAS_SHAPELY:
        try:
            seg_polys = _build_label_ellipse_polygons(
                centroid_accum, moments_accum, bounds_accum, sorted_labels, scale_sigma=2.0, n_points=16
            )
        except Exception as e:
            seg_polys = None
            _warn(f"Could not build segmentation polygons (ellipse): {e}")
    _save_outputs(
        features,
        sorted_labels.astype(np.int64),
        centroids,
        marker_names,
        image_id,
        output_dir,
        outputDtype,
        seg_polys=seg_polys,
    )
    return zarr_path


def spatialFeatureTableFromDict(cfg: Dict) -> str:
    """Dictionary-based camelCase wrapper for notebook/CLI interop.

    Required keys: inputPath, maskPath
    Optional keys: markersCsvPath, stat, outputPath, outputDtype, verbose, addSegmentationShapes
    """
    inputPath = cfg.get("inputPath")
    maskPath = cfg.get("maskPath")
    if not inputPath or not maskPath:
        raise ValueError("Both 'inputPath' and 'maskPath' are required.")
    return spatialFeatureTable(
        inputPath=inputPath,
        maskPath=maskPath,
        markersCsvPath=cfg.get("markersCsvPath"),
        stat=str(cfg.get("stat", "mean")),
        outputPath=cfg.get("outputPath"),
        outputDtype=str(cfg.get("outputDtype", "float16")),
        verbose=bool(cfg.get("verbose", False)),
        addSegmentationShapes=bool(cfg.get("addSegmentationShapes", True)),
    )


def _build_label_bbox_polygons(bounds: Dict[int, Tuple[int, int, int, int]], labels: np.ndarray) -> List[Any]:
    """Build one axis-aligned rectangle per label from its pixel bounds.

    For a label with inclusive bounds ``(minx, miny, maxx, maxy)``, the
    rectangle spans ``[minx, maxx + 1] x [miny, maxy + 1]`` in pixel
    coordinates, so it covers the full extent of the edge pixels. Labels
    missing from ``bounds`` get an empty Polygon. The output order follows
    ``labels``.

    Raises RuntimeError when shapely/geopandas are unavailable.
    """
    if not HAS_SHAPELY:
        raise RuntimeError("Shapely/GeoPandas are required to build segmentation polygons.")

    def _box(label: int) -> Any:
        extent = bounds.get(int(label))
        if extent is None:
            return sg.Polygon()
        x0, y0, x1, y1 = extent
        left, top = float(x0), float(y0)
        right, bottom = float(x1 + 1), float(y1 + 1)
        return sg.Polygon([(left, top), (right, top), (right, bottom), (left, bottom)])

    return [_box(label) for label in labels.tolist()]


def _build_label_ellipse_polygons(
    centroid_accum: Dict[int, Tuple[float, float, int]],
    moments_accum: Dict[int, Tuple[float, float, float]],
    bounds_accum: Dict[int, Tuple[int, int, int, int]],
    labels: np.ndarray,
    scale_sigma: float = 2.0,
    n_points: int = 16,
) -> List[Any]:
    """Approximate each label's outline with an ellipse built from pixel moments.

    Construction:

    - The ellipse is centred on the label's centroid (from ``centroid_accum``).
    - Its covariance comes from the accumulated second moments.
    - The radii are ``scale_sigma * sqrt(eigenvalue)`` (each eigenvalue is
      floored at 1e-8), oriented along the covariance eigenvectors.
    - The outline is sampled at ``max(8, n_points)`` angles.

    A bounding-box polygon is used instead when any of these holds:

    - the label has <= 1 pixel;
    - the radii are non-finite or essentially zero;
    - any step raises.

    Raises RuntimeError when shapely/geopandas are unavailable.
    """
    if not HAS_SHAPELY:
        raise RuntimeError("Shapely/GeoPandas are required to build segmentation polygons.")

    def _bbox(label: int) -> Any:
        return _build_label_bbox_polygons(bounds_accum, np.array([label]))[0]

    polygons: List[Any] = []
    for label in (int(v) for v in labels.tolist()):
        sum_x, sum_y, n_px = centroid_accum.get(label, (0.0, 0.0, 0))
        if n_px <= 1:
            # Too few pixels for a meaningful covariance
            polygons.append(_bbox(label))
            continue
        mx = sum_x / float(n_px)
        my = sum_y / float(n_px)
        sxx, syy, sxy = moments_accum.get(label, (0.0, 0.0, 0.0))
        # Population covariance via E[v^2] - E[v]^2; negatives from rounding clamp to 0
        var_x = max(sxx / float(n_px) - mx * mx, 0.0)
        var_y = max(syy / float(n_px) - my * my, 0.0)
        cov_xy = sxy / float(n_px) - mx * my
        covariance = np.array([[var_x, cov_xy], [cov_xy, var_y]], dtype=np.float64)
        try:
            evals, evecs = np.linalg.eigh(covariance)
            radius0 = float(scale_sigma * np.sqrt(max(evals[0], 1e-8)))
            radius1 = float(scale_sigma * np.sqrt(max(evals[1], 1e-8)))
            if not np.isfinite(radius0 + radius1) or (radius0 + radius1) < 1e-6:
                polygons.append(_bbox(label))
                continue
            angles = np.linspace(0.0, 2.0 * np.pi, num=max(8, int(n_points)), endpoint=False)
            axis0 = evecs[:, 0].astype(np.float64)
            axis1 = evecs[:, 1].astype(np.float64)
            ring = [
                (
                    float(mx + radius0 * (axis0[0] * np.cos(t)) + radius1 * (axis1[0] * np.sin(t))),
                    float(my + radius0 * (axis0[1] * np.cos(t)) + radius1 * (axis1[1] * np.sin(t))),
                )
                for t in angles
            ]
            polygons.append(sg.Polygon(ring))
        except Exception:
            polygons.append(_bbox(label))
    return polygons


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Command-line entry point: parse arguments and run ``spatialFeatureTable``.

    Parameters
    ----------
    argv : optional argument list.
        ``None`` (the default) makes argparse read ``sys.argv[1:]``, so
        ``main()`` behaves as before. Passing a list allows programmatic use
        and testing.
    """
    ap = argparse.ArgumentParser(description="Build per-cell feature table from multiplexed image and mask.")
    ap.add_argument("--input", required=True, help="Input multiplexed OME-TIFF/TIFF image")
    ap.add_argument("--mask", required=True, help="Segmentation mask TIFF/OME-TIFF with integer labels; 0 is background")
    ap.add_argument(
        "--markers-csv",
        default=None,
        help="Optional CSV with 'marker_name' column to name channels",
    )
    ap.add_argument(
        "--stat",
        choices=["mean", "median"],
        default="mean",
        help="Statistic to compute per channel; median is approximate via hist",
    )
    ap.add_argument(
        "--output",
        default=None,
        help="Output directory or explicit .zarr path; defaults to sibling directory named by image with _mIF suffix",
    )
    ap.add_argument(
        "--output-dtype",
        choices=["float16", "float32"],
        default="float16",
        help="Output feature dtype",
    )
    ap.add_argument("--verbose", action="store_true", help="Show verbose progress lines")
    ap.add_argument(
        "--no-segmentation-shapes",
        dest="add_segmentation_shapes",
        action="store_false",
        help="Do not write approximate per-cell polygons under shapes['segmentation_mask']",
    )

    args = ap.parse_args(argv)

    # Delegate to the camelCase notebook-friendly function
    zarr_dir = spatialFeatureTable(
        inputPath=str(args.input),
        maskPath=str(args.mask),
        markersCsvPath=args.markers_csv,
        stat=args.stat,
        outputPath=args.output,
        outputDtype=args.output_dtype,
        verbose=args.verbose,
        addSegmentationShapes=args.add_segmentation_shapes,
    )
    _info(f"Saved SpatialData table: {zarr_dir}")


# Script entry point, e.g. `python spatialfeaturetable.py --input image.ome.tiff --mask mask.tiff`
if __name__ == "__main__":
    main()