
#!/usr/bin/env python3
"""
Updated he_embedder.py (no hard-coded patch size)
-------------------------------------------------
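This script extracts ViT patch or token embeddings from H&E images and writes the
output as a SpatialData (.zarr) dataset or, if SpatialData is not available,
as an AnnData (.h5ad) table. It supports both patch‑level pooled outputs and
token‑level outputs; in the latter case, each token is treated as its own
observation, with coordinates derived from subdividing each patch into a grid.
Optional, config-driven stages run around the embedding step: tissue detection
(before extraction, to gate which patches are embedded), cell segmentation, and
a per-cell spatial feature table with token→cell mapping (after extraction).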

Key changes (relative to the previous version):
* Patch size and token size computation in writers now strictly use cfg["patch_size"]
  and never fall back to a hard-coded constant. This avoids latent "224" defaults.
"""

import argparse
import sys
import traceback
from pathlib import Path
from typing import Optional, List, Dict, Any, Union

import numpy as np
from PIL import Image, ImageFile

# Prevent DecompressionBombWarning for large tiles
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

# Optional heavy imports (torch, timm) are required for inference
# Make torch import optional so config loading doesn't fail at package import time
HAS_TORCH = False
try:
    import torch  # type: ignore
    HAS_TORCH = True
except Exception:
    HAS_TORCH = False
try:
    from histotuner.utils import read_he_image as _read_he_image
except Exception:
    _read_he_image = None

# Geometry libraries (used by tissue detection helpers)
try:
    import shapely.geometry as sg  # for polygons
    HAS_SHAPELY = True
except Exception:
    HAS_SHAPELY = False

# Import local tissue.py pipeline (self-contained transforms + BinaryMask)
try:
    from . import tissue as tissue_mod  # provides _tissue_mask and BinaryMask
    from .tissue import BinaryMask  # polygonization
    HAS_TISSUEPY = True
except Exception:
    try:
        import tissue as tissue_mod  # fallback when executed as a script
        from tissue import BinaryMask
        HAS_TISSUEPY = True
    except Exception as _e:
        print("[tissue import warning]", repr(_e), file=sys.stderr)
        HAS_TISSUEPY = False

# Spatial writer helpers
from .spatial_writer import (
    build_embedding_table,
    update_spatialdata_table,
    update_spatialdata_shapes,
    build_tile_shapes,
    save_as_spatialdata,
    save_tissue_spatialdata,
)

# Import embedding extraction utilities
try:
    from .extractEmbedding import extractEmbeddingCamel
except Exception as _e:
    extractEmbeddingCamel = None  # type: ignore
    print("[embed import warning] extractEmbeddingCamel not available:", repr(_e), file=sys.stderr)

# Add optional segmentation import (guard missing deps like cellpose)
HAS_SEGMENT = False
try:
    from .segment import segmentImage
    HAS_SEGMENT = True
except Exception as _e:
    segmentImage = None  # type: ignore
    print("[segment import warning] segmentImage not available:", repr(_e), file=sys.stderr)

# Add optional spatial feature table import (guard missing deps)
HAS_SFT = False
try:
    from .spatialfeaturetable import spatialFeatureTable
    HAS_SFT = True
except Exception as _e:
    spatialFeatureTable = None  # type: ignore
    print("[spatial import warning] spatialFeatureTable not available:", repr(_e), file=sys.stderr)

# Add optional token→cell mapper import (guard missing deps)
HAS_TCM = False
try:
    from .token_cell_mapper import tokenCellMapper
    HAS_TCM = True
except Exception as _e:
    tokenCellMapper = None  # type: ignore
    print("[token-cell import warning] tokenCellMapper not available:", repr(_e), file=sys.stderr)


def _normalizeConfigKeys(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize common camelCase keys to snake_case for internal use."""
    mapping = {
        # General
        "input": "input",
        "output": "output",
        "modelId": "model_id",
        "outputType": "output_type",
        "inputSize": "input_size",
        "patchSize": "patch_size",
        "stride": "stride",
        "mean": "mean",
        "std": "std",
        "batchSize": "batch_size",
        "device": "device",
        "streamWrite": "stream_write",
        "targetMpp": "target_mpp",
        "sourceMpp": "source_mpp",
        "slideId": "slide_id",
        # Tissue detection
        "runTissueDetection": "run_tissue_detection",
        "tissueThumbMax": "tissue_thumb_max",
        "tissueToHsv": "tissue_to_hsv",
        "tissueFilterArtifacts": "tissue_filter_artifacts",
        "tissueBlurKsize": "tissue_blur_ksize",
        "tissueThreshold": "tissue_threshold",
        "tissueMorphKsize": "tissue_morph_ksize",
        "tissueMorphNIter": "tissue_morph_n_iter",
        "tissueMinTissueArea": "tissue_min_tissue_area",
        "tissueMinHoleArea": "tissue_min_hole_area",
        "tissueDetectHoles": "tissue_detect_holes",
        "tissueRefineLevel": "tissue_refine_level",
        "onEmptyTissue": "on_empty_tissue",
        "showImage": "show_image",
        # Segmentation
        "runSegmentation": "run_segmentation",
        "segmentationOutputPath": "segmentation_output_path",
        "segmentationPatchSize": "segmentation_patch_size",
        "segmentationStride": "segmentation_stride",
        "segmentationPretrainedModel": "segmentation_pretrained_model",
        "segmentationDiameter": "segmentation_diameter",
        "segmentationTissueThreshold": "segmentation_tissue_threshold",
        "segmentationBatchSize": "segmentation_batch_size",
        "segmentationProgressInterval": "segmentation_progress_interval",
        "segmentationVerbose": "segmentation_verbose",
        "segmentationTileOverlap": "segmentation_tile_overlap",
        "segmentationMergeOverlaps": "segmentation_merge_overlaps",
        "segmentationCellprobThreshold": "segmentation_cellprob_threshold",
        "segmentationFlowThreshold": "segmentation_flow_threshold",
        "segmentationNormalize": "segmentation_normalize",
        "segmentationModelType": "segmentation_model_type",
        "segmentationMaxPatches": "segmentation_max_patches",
        # Spatial Feature Table
        "runSpatialFeatureTable": "run_spatial_feature_table",
        "spatialInputPath": "sft_input_path",
        "spatialMaskPath": "sft_mask_path",
        "spatialMarkersCsvPath": "sft_markers_csv_path",
        "spatialStat": "sft_stat",
        "spatialOutputPath": "sft_output_path",
        "spatialOutputDtype": "sft_output_dtype",
        "spatialVerbose": "sft_verbose",
        "spatialAddSegmentationShapes": "sft_add_segmentation_shapes",
        "spatialRunTokenCellMapper": "sft_run_token_cell_mapper",
    }
    out = dict(cfg)
    for k, v in list(cfg.items()):
        if k in mapping:
            out[mapping[k]] = v
    return out


# Load YAML config
def load_config(path: str) -> Dict[str, Any]:
    """Read a YAML pipeline config, normalize its keys, validate it, and fill defaults.

    Args:
        path: Filesystem path of the YAML file. It must contain a top-level mapping.

    Returns:
        The normalized config dict (camelCase keys also copied to snake_case),
        with optional keys such as ``target_mpp`` and ``stream_write`` defaulted.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        ValueError: the YAML is not a mapping, a required key is missing, or
            ``input_size`` / ``mean`` / ``std`` have the wrong length.
    """
    import yaml as _yaml

    config_file = Path(path)
    if not config_file.exists():
        raise FileNotFoundError(f"Config not found: {config_file}")
    with open(config_file, "r", encoding="utf-8") as handle:
        loaded = _yaml.safe_load(handle)
    if not isinstance(loaded, dict):
        raise ValueError("Config must be a YAML mapping (dict).")

    # Accept notebook/CLI camelCase configs by adding their snake_case aliases.
    cfg = _normalizeConfigKeys(loaded)

    required_keys = (
        "input", "output", "patch_size", "stride", "batch_size", "model_id",
        "input_size", "mean", "std", "output_type", "device",
    )
    # Report the first missing key, in the order listed above.
    missing = next((key for key in required_keys if key not in cfg), None)
    if missing is not None:
        raise ValueError(f"Missing required config key: {missing}")

    size = cfg["input_size"]
    if not isinstance(size, (list, tuple)) or len(size) != 2:
        raise ValueError("input_size must be [H, W].")
    for stat_key in ("mean", "std"):
        stat = cfg[stat_key]
        if not isinstance(stat, (list, tuple)) or len(stat) != 3:
            raise ValueError(f"{stat_key} must be length‑3.")

    # `existing_spatial` is no longer used; the output path drives update-or-create.
    optional_defaults = (
        ("target_mpp", None),
        ("source_mpp", None),
        ("timm_kwargs", {}),
        ("stream_write", True),
        ("export_coordinate_frame", "level0"),
        ("run_tissue_detection", False),
        ("run_embedding_extraction", True),
        ("filter_to_tissue", True),
        ("tissue_thumb_max", 2048),
    )
    for key, value in optional_defaults:
        cfg.setdefault(key, value)
    return cfg


def _mpp_scale(eff_mpp: Any, src_mpp: Any) -> Optional[float]:
    """Return ``eff_mpp / src_mpp``, or None if either is missing or the ratio cannot be computed.

    Level-0 geometry is divided by this ratio to express it in the virtual
    (effective-mpp) pixel frame used for export.
    """
    if eff_mpp is None or src_mpp is None:
        return None
    try:
        return float(eff_mpp) / float(src_mpp)
    except (TypeError, ValueError, OverflowError, ZeroDivisionError):
        return None


def _tissue_mask_kwargs(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Collect the keyword arguments for ``tissue_mod.tissueMask`` from a normalized config."""
    threshold = cfg.get("tissue_threshold")
    return {
        "toHsv": bool(cfg.get("tissue_to_hsv", False)),
        "filterArtifacts": bool(cfg.get("tissue_filter_artifacts", True)),
        "blurKsize": int(cfg.get("tissue_blur_ksize", 17)),
        "threshold": int(threshold) if threshold is not None else None,
        "morphKsize": int(cfg.get("tissue_morph_ksize", 7)),
        "morphNIter": int(cfg.get("tissue_morph_n_iter", 3)),
    }


def _polygon_kwargs(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Collect the keyword arguments for ``BinaryMask.to_polygons`` from a normalized config."""
    return {
        "min_area": float(cfg.get("tissue_min_tissue_area", 0.0)),
        "min_hole_area": float(cfg.get("tissue_min_hole_area", 0.0)),
        "detect_holes": bool(cfg.get("tissue_detect_holes", True)),
    }


def _materialize(arr: Any) -> Any:
    """Call ``arr.compute()`` when available (presumably lazy/dask arrays); otherwise return ``arr``.

    A failing ``compute()`` falls back to returning ``arr`` unchanged.
    """
    try:
        return arr.compute() if hasattr(arr, "compute") else arr
    except Exception:
        return arr


def _embedding_meta(meta: Dict[str, Any], embeds_np: Any) -> Dict[str, Any]:
    """Copy ``meta`` and add ``n_patches`` / ``embedding_dim`` taken from ``embeds_np.shape``.

    Objects without a ``shape`` get ``n_patches=0`` and ``embedding_dim=None``.
    """
    shape = getattr(embeds_np, "shape", None)
    meta_local = dict(meta)
    meta_local["n_patches"] = int(shape[0]) if shape is not None else 0
    meta_local["embedding_dim"] = int(shape[1]) if shape is not None else None
    return meta_local


def _save_or_update_spatialdata(
    coords: Any,
    embeds_np: Any,
    output_path: str,
    meta: Dict[str, Any],
    cfg: Dict[str, Any],
    token_dim: Any = None,
) -> None:
    """Write one embedding table and its tile/token shapes into the SpatialData store at ``output_path``.

    The table is built with ``build_embedding_table`` and written with
    ``update_spatialdata_table``. Shapes are written under region "tokens" when
    ``cfg["output_type"]`` is "token", otherwise under "tiles". Table errors
    propagate; shape errors are only reported as warnings.

    NOTE(review): this relies on ``update_spatialdata_table`` creating the store
    when ``output_path`` does not exist yet (the multi-output path already relied
    on this). Confirm in spatial_writer.
    """
    out_type = str(cfg.get("output_type")).lower()
    adata, table_name = build_embedding_table(coords, embeds_np, meta, cfg, token_dim=token_dim)
    update_spatialdata_table(adata, output_path, table_name)
    print(f"Appended table '{table_name}' to SpatialData at: {output_path}")
    try:
        tile_polys = build_tile_shapes(coords, embeds_np, meta, cfg, token_dim=token_dim)
        shape_region = "tokens" if out_type == "token" else "tiles"
        update_spatialdata_shapes(tile_polys, output_path, meta, cfg, region_name=shape_region)
        print(f"Updated SpatialData shapes at: {output_path} (region '{shape_region}')")
    except Exception as _e_shapes:
        print(f"[warning] Could not update {out_type} shapes on append: {_e_shapes}")


def _detect_tissue(cfg: Dict[str, Any], source_mpp: Optional[float], eff_mpp: Any) -> Optional[List[Any]]:
    """Detect tissue polygons for ``cfg["input"]`` and apply the ``on_empty_tissue`` policy.

    WSIs are thresholded on a thumbnail (at most ``tissue_thumb_max`` px per side).
    The polygons are scaled to level 0 and then divided by the effective/source
    mpp ratio when both are known. Regular images, including OME-TIFFs, are
    thresholded at full resolution.

    Returns:
        A non-empty list of geometries. When no tissue is found and the policy
        is "proceed", returns a single full-image box. Returns None when nothing
        can be produced.

    Raises:
        RuntimeError: no tissue was found and ``on_empty_tissue`` is "error" or
            "skip", or ``read_he_image`` is unavailable for a regular image.
        Exception: errors from the imaging/geometry helpers propagate; the
            caller decides whether to abort.
    """
    from shapely.affinity import scale as _scale

    input_path = Path(str(cfg["input"]))
    _read_wsi = None
    try:
        from .utils import is_wsi as _is_wsi, read_wsi as _read_wsi
        is_input_wsi = _is_wsi(str(input_path))
        # OME-TIFFs take the regular-image path to match the standalone generateTissueMask.
        if str(input_path).lower().endswith((".ome.tif", ".ome.tiff")):
            is_input_wsi = False
            print("[tissue] OME-TIFF detected; using regular-image tissue path")
    except Exception:
        is_input_wsi = False
    print(f"[tissue] Running detection (WSI={is_input_wsi})")

    slide = None
    src_mpp_from_read = None
    tissue_polys: Optional[List[Any]] = None
    if is_input_wsi:
        try:
            slide, src_mpp_from_read = _read_wsi(str(input_path))
        except Exception:
            slide, src_mpp_from_read = None, None
        if slide is not None:
            W0, H0 = slide.level_dimensions[0]
            thumb_max = int(cfg.get("tissue_thumb_max") or 2048)
            thumb = slide.get_thumbnail((thumb_max, thumb_max)).convert("RGB")
            # Primary: Lazyslide-style artifact filter via tissue.py
            mask_thumb = tissue_mod.tissueMask(image=np.array(thumb), **_tissue_mask_kwargs(cfg))
            tgdf_thumb = BinaryMask(mask_thumb).to_polygons(**_polygon_kwargs(cfg))
            n_thumb = len(tgdf_thumb.geometry) if hasattr(tgdf_thumb, "geometry") else "unknown"
            print(f"[tissue] WSI mask method: tissue.py artifact-filter; polygons={n_thumb}")
            # Thumbnail pixels → level-0 pixels.
            xfact_l0 = float(W0) / float(thumb.width)
            yfact_l0 = float(H0) / float(thumb.height)
            polys = [_scale(g, xfact=xfact_l0, yfact=yfact_l0, origin=(0, 0)) for g in tgdf_thumb.geometry]
            # Level-0 → virtual frame, so writers can up-scale consistently.
            src_mpp_use = source_mpp if source_mpp is not None else src_mpp_from_read
            mpp_scale = _mpp_scale(eff_mpp, src_mpp_use)
            if mpp_scale is not None:
                inv = 1.0 / float(mpp_scale)
                polys = [_scale(g, xfact=inv, yfact=inv, origin=(0, 0)) for g in polys]
            tissue_polys = polys
            print(f"[tissue] Detected {len(tissue_polys)} polygons")
    else:
        if _read_he_image is None:
            raise RuntimeError("read_he_image unavailable; cannot run tissue detection.")
        arr = np.array(_read_he_image(str(input_path)))
        mask = tissue_mod.tissueMask(image=arr, **_tissue_mask_kwargs(cfg))
        tgdf = BinaryMask(mask).to_polygons(**_polygon_kwargs(cfg))
        tissue_polys = list(tgdf.geometry)
        print(f"[tissue] Detected {len(tissue_polys)} polygons")

    if tissue_polys:
        return tissue_polys

    on_empty = str(cfg.get("on_empty_tissue", "proceed")).lower()
    if on_empty != "proceed":
        print("[tissue] No polygons detected; on_empty_tissue policy prevents fallback")
        if on_empty == "error":
            raise RuntimeError("No tissue detected and on_empty_tissue=error; aborting.")
        if on_empty == "skip":
            raise RuntimeError("No tissue detected and on_empty_tissue=skip (deprecated); use 'proceed' or 'error'.")
        return None

    print("[tissue] No polygons detected; proceeding with full-image fallback")
    try:
        if is_input_wsi and slide is not None:
            W0, H0 = slide.level_dimensions[0]
            src_mpp_use = source_mpp if source_mpp is not None else src_mpp_from_read
            mpp_scale = _mpp_scale(eff_mpp, src_mpp_use)
            if mpp_scale is not None:
                width_v, height_v = float(W0) / float(mpp_scale), float(H0) / float(mpp_scale)
            else:
                width_v, height_v = float(W0), float(H0)
            return [sg.box(0.0, 0.0, width_v, height_v)]
        if _read_he_image is None:
            return None
        w0, h0 = _read_he_image(str(input_path)).size
        return [sg.box(0.0, 0.0, float(w0), float(h0))]
    except Exception:
        # Best effort: any failure building the fallback box leaves tissue_polys as None.
        return None


def _run_segmentation_stage(cfg: Dict[str, Any], output_path: str) -> None:
    """Run ``segmentImage`` on ``cfg["input"]`` and record the mask path in ``cfg``.

    The mask path defaults to ``<input stem>_segMask.tiff`` next to the input and
    is stored back into ``cfg["segmentation_output_path"]`` on success. A non-empty
    ``output_path`` is passed as ``tissueZarr`` for tissue gating (region "tissue").
    Failures are reported on stderr and swallowed.
    """
    if not HAS_SEGMENT or segmentImage is None:
        print("[segment] run_segmentation=true but segmentImage not available; install cellpose.", file=sys.stderr)
        return
    try:
        seg_out = cfg.get("segmentation_output_path")
        if not seg_out:
            p_in = Path(str(cfg.get("input")))
            seg_out = str(p_in.parent / f"{p_in.stem}_segMask.tiff")
        segmentImage(
            inputPath=str(cfg.get("input")),
            outputPath=str(seg_out),
            patchSize=int(cfg.get("segmentation_patch_size", cfg.get("patch_size", 512))),
            stride=int(cfg.get("segmentation_stride", cfg.get("stride", 512))),
            pretrainedModel=str(cfg.get("segmentation_pretrained_model", "cpsam")),
            diameter=cfg.get("segmentation_diameter"),
            tissueThreshold=float(cfg.get("segmentation_tissue_threshold", 240.0)),
            maxPatches=cfg.get("segmentation_max_patches"),
            batchSize=int(cfg.get("segmentation_batch_size", cfg.get("batch_size", 8))),
            progressInterval=int(cfg.get("segmentation_progress_interval", 100)),
            verbose=bool(cfg.get("segmentation_verbose", False)),
            tileOverlap=float(cfg.get("segmentation_tile_overlap", 0.1)),
            mergeOverlaps=bool(cfg.get("segmentation_merge_overlaps", False)),
            cellprobThreshold=cfg.get("segmentation_cellprob_threshold"),
            flowThreshold=cfg.get("segmentation_flow_threshold"),
            normalize=bool(cfg.get("segmentation_normalize", True)),
            modelType=str(cfg.get("segmentation_model_type", "cyto3")),
            tissueZarr=output_path or None,
            tissueMaskPath=None,
            regionName="tissue",
        )
        print(f"[segment] Saved segmentation mask to: {seg_out}")
        cfg["segmentation_output_path"] = seg_out
    except Exception as _e:
        print("[segment] Segmentation failed:", repr(_e), file=sys.stderr)
        traceback.print_exc()


def _run_spatial_feature_table_stage(cfg: Dict[str, Any]) -> bool:
    """Build the spatial feature table and optionally attach the token→cell mapping.

    The mask path is ``sft_mask_path`` when given. Otherwise it is the first
    existing candidate among the segmentation output and ``<stem>_segMask.tiff``
    next to the H&E input and the SFT input; if none exists, the last candidate
    is used.

    Returns:
        True when ``spatialFeatureTable`` completed, else False. A token→cell
        mapper failure is reported but does not change the result.
    """
    if not HAS_SFT or spatialFeatureTable is None:
        print("[spatial] run_spatial_feature_table=true but spatialFeatureTable not available; install dependencies.", file=sys.stderr)
        return False
    try:
        in_path = str(cfg.get("sft_input_path") or cfg.get("input"))
        mask_path = cfg.get("sft_mask_path")
        if not mask_path:
            candidates = []
            seg_out = cfg.get("segmentation_output_path")
            if seg_out:
                candidates.append(str(seg_out))
            for base in (Path(str(cfg.get("input"))), Path(in_path)):
                candidates.append(str(base.parent / f"{base.stem}_segMask.tiff"))
            mask_path = next((c for c in candidates if c and Path(c).exists()), candidates[-1])
        # Warn only; spatialFeatureTable still runs and reports its own error.
        try:
            if not mask_path or not Path(str(mask_path)).exists():
                print(f"[spatial] Mask not found at '{mask_path}'. Set 'spatialMaskPath' to the segmentation mask path.", file=sys.stderr)
        except (OSError, ValueError):
            pass
        # Prefer an explicit sft_output_path, fall back to the main output.
        out_sp = cfg.get("sft_output_path") or cfg.get("output")
        zarr_dir = spatialFeatureTable(
            inputPath=in_path,
            maskPath=str(mask_path),
            markersCsvPath=cfg.get("sft_markers_csv_path"),
            stat=str(cfg.get("sft_stat", "mean")),
            outputPath=str(out_sp) if out_sp else None,
            outputDtype=str(cfg.get("sft_output_dtype", "float16")),
            verbose=bool(cfg.get("sft_verbose", False)),
            addSegmentationShapes=bool(cfg.get("sft_add_segmentation_shapes", True)),
        )
        print(f"Saved SpatialData table: {zarr_dir}")
    except Exception as _e:
        print("[spatial] spatialFeatureTable failed:", repr(_e), file=sys.stderr)
        traceback.print_exc()
        return False

    if bool(cfg.get("sft_run_token_cell_mapper", True)):
        if not HAS_TCM or tokenCellMapper is None:
            print("[token-cell] tokenCellMapper not available; install dependencies.", file=sys.stderr)
        else:
            try:
                tokenCellMapper(
                    sdata=str(zarr_dir),
                    tokenShapesKey="tokens",
                    cellShapesKey="segmentation_mask",
                    outColumn="cell_index",
                    outPath=str(zarr_dir),
                )
                print("Updated SpatialData with token→cell mapping.")
            except Exception as _e:
                print("[token-cell] tokenCellMapper failed:", repr(_e), file=sys.stderr)
                traceback.print_exc()
    return True


def runEmbeddingPipeline(cfgCamel: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Run the embedding pipeline given a (potentially camelCase) config dict.

    Stages, each gated by config flags:
      1. Tissue detection (``run_tissue_detection``), producing polygons that gate extraction.
      2. Embedding extraction (``run_embedding_extraction``) via ``extractEmbeddingCamel``.
         Each requested output type is written to ``output`` as a table plus tile/token shapes.
      3. Tissue polygons written to ``output`` (region "tissue").
      4. Segmentation (``run_segmentation``).
      5. Spatial feature table and token→cell mapping (``run_spatial_feature_table``).

    Args:
        cfgCamel: Config mapping. camelCase keys are normalized into a copy, so
            the caller's mapping is not modified.

    Returns:
        Dict with ``coords``, ``embeddings`` and ``token_dim`` (None when extraction
        was skipped or failed), plus ``tissue_polys``.

    Raises:
        ValueError: a required key is missing.
        Exception: tissue-detection errors when ``on_empty_tissue == "error"``, and
            errors from writing embedding tables.
    """
    cfg = _normalizeConfigKeys(cfgCamel)
    required = (
        "input", "output", "patch_size", "stride", "batch_size",
        "model_id", "input_size", "mean", "std", "output_type", "device",
    )
    for k in required:
        if k not in cfg:
            raise ValueError(f"Missing required config key: {k}")

    # Built per call so the mutable timm_kwargs default is never shared across runs.
    defaults: Dict[str, Any] = {
        "target_mpp": None,
        "source_mpp": None,
        "timm_kwargs": {},
        "stream_write": True,
        "export_coordinate_frame": "level0",
        "run_tissue_detection": False,
        "run_embedding_extraction": True,
        "filter_to_tissue": True,
        "tissue_thumb_max": 2048,
        "tissue_s_thresh": 20,
        "tissue_v_thresh": 250,
        "tissue_to_hsv": False,
        "tissue_filter_artifacts": True,
        "tissue_blur_ksize": 17,
        "tissue_threshold": 7,
        "tissue_morph_n_iter": 3,
        "tissue_morph_ksize": 7,
        "tissue_min_tissue_area": 1e-3,
        "tissue_min_hole_area": 1e-5,
        "tissue_detect_holes": True,
        "tissue_refine_level": None,
        "on_empty_tissue": "proceed",
        "show_image": False,
    }
    for key, value in defaults.items():
        cfg.setdefault(key, value)

    # Probe the input: for a WSI, read_wsi also reports the source mpp.
    in_obj = cfg.get("input")
    source_mpp = None
    is_wsi = False
    try:
        from .utils import is_wsi as _is_wsi, read_wsi as _read_wsi
        if isinstance(in_obj, str) and _is_wsi(in_obj):
            is_wsi = True
            try:
                _slide, source_mpp = _read_wsi(str(in_obj))
            except Exception:
                source_mpp = None
    except Exception:
        pass  # utils unavailable or is_wsi failed: treat the input as a regular image
    # An explicit source_mpp in the config overrides the probed value.
    if cfg.get("source_mpp") is not None:
        try:
            source_mpp = float(cfg["source_mpp"])
        except (TypeError, ValueError, OverflowError):
            print(f"[mpp] Invalid source_mpp override: {cfg.get('source_mpp')}", file=sys.stderr)

    # Effective mpp for WSIs is target_mpp when set, else the source mpp. The key
    # is always present after the defaults above, so an explicit None check is
    # required; cfg.get(key, fallback) would never fall back.
    target_mpp = cfg.get("target_mpp")
    if is_wsi:
        eff_mpp = target_mpp if target_mpp is not None else source_mpp
    else:
        eff_mpp = None
    coord_frame = "target-mpp virtual pixels" if is_wsi else "pixels at effective mpp"
    meta = {
        "input_path": str(cfg["input"]),
        "model_name": str(cfg["model_id"]),
        "config_path": None,
        "patch_size": int(cfg["patch_size"]),
        "stride": int(cfg["stride"]),
        "effective_mpp": float(eff_mpp) if eff_mpp is not None else None,
        "source_mpp": float(source_mpp) if source_mpp is not None else None,
        "device": str(cfg.get("device", "auto")),
        "pyramid_level": 0,
        "level_downsample": 1.0,
        "coordinate_frame": coord_frame,
        "slide_id": str(cfg.get("slide_id") or Path(str(cfg["input"])).stem),
    }

    output_path = str(cfg.get("output") or "").strip()
    run_tissue = bool(cfg.get("run_tissue_detection", False))
    wrote_any = False

    # ----- Tissue detection first, with fallback to full image when requested -----
    tissue_polys = None
    if run_tissue and HAS_TISSUEPY:
        try:
            tissue_polys = _detect_tissue(cfg, source_mpp, eff_mpp)
        except Exception as _e_tissue:
            print(f"[tissue] Detection failed: {_e_tissue}", file=sys.stderr)
            # If user requested error-on-empty-tissue, propagate failure to abort
            if str(cfg.get("on_empty_tissue", "proceed")).lower() == "error":
                raise
            tissue_polys = None
    elif run_tissue:
        print("[tissue] Skipping detection: tissue module unavailable")

    # ----- Embedding extraction -----
    coords, embeds, token_dim = None, None, None
    result: Any = {}
    run_embed = bool(cfg.get("run_embedding_extraction", True))
    if not run_embed:
        print("[embed] Embedding extraction disabled by config.")
    elif extractEmbeddingCamel is None:
        print("[embed] extractEmbeddingCamel not available; ensure timm is installed.", file=sys.stderr)
    else:
        try:
            result = extractEmbeddingCamel(
                inputObj=cfg["input"],
                modelName=cfg["model_id"],
                inputSize=tuple(cfg["input_size"]),
                mean=list(cfg["mean"]),
                std=list(cfg["std"]),
                patchSize=int(cfg["patch_size"]),
                stride=int(cfg["stride"]),
                outputType=cfg["output_type"],
                batchSize=int(cfg["batch_size"]),
                device=str(cfg.get("device", "auto")),
                streamWrite=bool(cfg.get("stream_write", True)),
                targetMpp=cfg.get("target_mpp", None),
                sourceMpp=cfg.get("source_mpp", None),
                tissuePolys=tissue_polys,
                filterToTissue=bool(cfg.get("filter_to_tissue", True)),
                timmKwargs=dict(cfg.get("timm_kwargs", {})),
                slideId=str(cfg.get("slide_id")) if cfg.get("slide_id") else None,
            )
            coords = result.get("coords")
            embeds = result.get("embeddings")
            token_dim = result.get("token_dim")
        except Exception as _e:
            print("[embed] Embedding extraction failed:", repr(_e), file=sys.stderr)
            traceback.print_exc()

    # ----- Write one table + shapes per produced output type -----
    if run_embed and coords is not None:
        ot = cfg.get("output_type")
        req_types = [str(ot).lower()] if isinstance(ot, str) else [str(x).lower() for x in ot]
        if embeds is not None:
            # Single-output result: "embeddings" holds the first requested type.
            pending = [(req_types[0], embeds)]
        else:
            # Multi-output result: one array per type under embeddings_patch / embeddings_token.
            pending = [
                (t, result.get("embeddings_patch" if t == "patch" else "embeddings_token"))
                for t in req_types
            ]
        for t, emb in pending:
            if emb is None:
                continue
            embeds_np = _materialize(emb)
            meta_local = _embedding_meta(meta, embeds_np)
            cfg_t = dict(cfg)
            cfg_t["output_type"] = t
            if output_path:
                _save_or_update_spatialdata(
                    coords, embeds_np, output_path, meta_local, cfg_t,
                    token_dim=token_dim if t == "token" else None,
                )
                wrote_any = True

    # ----- Tissue shapes -----
    if run_tissue and tissue_polys is not None:
        # Derive geometries from tissue_polys itself so the full-image fallback box
        # is written for WSI inputs as well as for regular images.
        try:
            if hasattr(tissue_polys, "geometry"):
                tissue_geoms = list(tissue_polys.geometry)
            else:
                tissue_geoms = list(tissue_polys)
        except Exception:
            tissue_geoms = None
        if tissue_geoms and output_path:
            if wrote_any:
                update_spatialdata_shapes(tissue_geoms, output_path, meta, cfg)
                print(f"Updated SpatialData shapes at: {output_path} (region 'tissue')")
            else:
                save_tissue_spatialdata(tissue_geoms, output_path, meta, cfg)
                print(f"Saved SpatialData (tissue-only) to: {output_path}")
                wrote_any = True

    # ----- Optional downstream stages -----
    if bool(cfg.get("run_segmentation", False)):
        _run_segmentation_stage(cfg, output_path)

    if bool(cfg.get("run_spatial_feature_table", False)):
        if _run_spatial_feature_table_stage(cfg):
            wrote_any = True

    if not wrote_any:
        print("No output written: neither embeddings nor tissue shapes were produced.")
    return {
        "coords": coords,
        "embeddings": embeds,
        "token_dim": token_dim,
        "tissue_polys": tissue_polys,
    }

def embedImageFromDict(cfgCamel: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Run the full pipeline from an in-memory config mapping (notebook entry point).

    Keys may be camelCase or snake_case; runEmbeddingPipeline normalizes them
    and validates the required ones.
    """
    pipeline_result = runEmbeddingPipeline(cfgCamel)
    return pipeline_result


def embedImageFromYaml(
    config_path: str,
    image_path: Optional[Union[str, "Image.Image"]] = None,
) -> Optional[Dict[str, Any]]:
    """Load a YAML config and run the embedding pipeline (notebook convenience wrapper).

    Args:
        config_path: Path to the YAML configuration file; read and validated by load_config.
        image_path: When not None, replaces ``cfg["input"]`` before the run. May be
            a path string or a PIL image. NOTE(review): confirm the extractor
            accepts in-memory image objects.

    Returns:
        The result mapping produced by runEmbeddingPipeline (``coords``,
        ``embeddings``, ``token_dim``, ``tissue_polys``).

    Raises:
        FileNotFoundError: ``config_path`` does not exist.
        ValueError: the config is not a mapping or fails validation.
    """
    cfg = load_config(config_path)
    if image_path is not None:
        cfg["input"] = image_path
    return runEmbeddingPipeline(cfg)

def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point: run the pipeline from a YAML config file.

    Run this as a module (``python -m <package>.<module>``), because the file
    uses package-relative imports.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status (0 on success). Pipeline exceptions propagate.
    """
    parser = argparse.ArgumentParser(
        description="Extract ViT patch/token embeddings from H&E images into SpatialData.",
    )
    parser.add_argument("--config", "-c", required=True, help="Path to the YAML config file.")
    parser.add_argument(
        "--input", "-i", default=None,
        help="Optional image path overriding the config's 'input' key.",
    )
    args = parser.parse_args(argv)
    embedImageFromYaml(args.config, image_path=args.input)
    return 0


if __name__ == "__main__":
    sys.exit(main())
