"""
spatial_writer.py — Modular writers for embeddings and tissue shapes
-------------------------------------------------------------------
Provides reusable functions to create and save SpatialData datasets and
fallback AnnData tables. Handles patch-level and token-level outputs,
coordinate frame export to WSI level-0 pixels, and optional tissue shapes.

Functions:
- save_as_spatialdata(coords_xy, embeddings, out_path, meta, cfg, token_dim=None, tissue_polys=None)
- save_as_h5ad_table(coords_xy, embeddings, out_path, meta, cfg, token_dim=None)
- save_tissue_spatialdata(tissue_polys, out_path, meta, cfg)
- build_embedding_table(coords_xy, embeddings, meta, cfg, token_dim=None) -> (adata, table_name)
- write_spatialdata_table(adata, out_path, table_name)
- update_spatialdata_table(adata, zarr_path, table_name)
- update_spatialdata_shapes(tissue_polys, zarr_path, meta, cfg, region_name="tissue")
- build_tile_shapes(coords_xy, embeddings, meta, cfg, token_dim=None) -> List[Any]
"""

from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

# Optional deps and flags
HAS_ANNDATA = False
try:
    import anndata as ad  # type: ignore
    HAS_ANNDATA = True
except Exception:
    HAS_ANNDATA = False

HAS_SPATIALDATA = False
try:
    import spatialdata as sdata  # type: ignore
    from spatialdata import SpatialData  # type: ignore
    HAS_SPATIALDATA = True
except Exception:
    HAS_SPATIALDATA = False

HAS_DASK = False
try:
    import dask.array as da  # type: ignore
    HAS_DASK = True
except Exception:
    HAS_DASK = False

HAS_SHAPELY = False
try:
    import shapely.geometry as sg  # type: ignore
    import geopandas as gpd  # type: ignore
    from spatialdata.models import ShapesModel, TableModel  # type: ignore
    HAS_SHAPELY = True
except Exception:
    HAS_SHAPELY = False


def build_embedding_table(
    coords_xy: np.ndarray,
    embeddings: np.ndarray,
    meta: Dict[str, Any],
    cfg: Dict[str, Any],
    token_dim: Optional[int] = None,
):
    """Construct an AnnData table for embeddings and return it with a table name.

    Handles patch- vs token-level layout, coordinate frame export, dtype
    casting, and attaches tile-spec, slide-property and provenance metadata.
    The table name is ``"<model_key>_tokens"`` when
    ``cfg["output_type"] == "token"``, otherwise ``"<model_key>_tiles"``.

    Args:
        coords_xy: Per-tile top-left coordinates, one ``(x, y)`` row per tile.
        embeddings: One row per tile. In token mode each row is a flattened
            ``tokens_total * token_dim`` vector.
        meta: Slide/run metadata. ``meta["coordinate_frame"]`` is set in place
            to describe the frame of the exported coordinates.
        cfg: Writer config; reads ``output_type``, ``patch_size``, ``stride``,
            ``export_coordinate_frame`` and ``output_dtype``. Not modified.
        token_dim: Per-token feature width; enables the token layout when
            ``output_type == "token"``.

    Returns:
        ``(adata, table_name)``.

    Raises:
        RuntimeError: if anndata is not installed.
        ValueError: in token mode, if the embedding width is not a positive
            multiple of ``token_dim``.
    """
    if not HAS_ANNDATA:
        raise RuntimeError("AnnData not available inside build_embedding_table.")

    output_type = str(cfg.get("output_type", "patch"))
    token_mode = output_type == "token" and token_dim is not None

    # derive model key and table name (tokens vs tiles)
    model_key = _model_key_for_table(str(meta.get("model_name", "model")))
    suffix = "tokens" if output_type == "token" else "tiles"
    table_name = f"{model_key}_{suffix}"

    patch_px: int = int(cfg["patch_size"])
    n_tiles = int(coords_xy.shape[0])

    # ----- token vs patch layout -----
    if token_mode:
        tdim = int(token_dim)
        flat_dim = int(embeddings.shape[1])
        if tdim <= 0 or flat_dim % tdim != 0:
            raise ValueError(
                f"Embedding width {flat_dim} is not a positive multiple of token_dim={tdim}."
            )
        tokens_total = flat_dim // tdim
        tokens_side = int(np.floor(np.sqrt(tokens_total)))
        tokens_use = tokens_side * tokens_side
        # Leading tokens that do not fit the largest square grid are dropped
        # (presumably class/register tokens — confirm against the model).
        drop_count = tokens_total - tokens_use

        embeds3d = embeddings.reshape(n_tiles, tokens_total, tdim)[:, drop_count:, :]
        embeds2d = embeds3d.reshape(n_tiles * tokens_use, tdim)

        token_size = patch_px / tokens_side if tokens_side > 0 else patch_px
        coords_list: List[Tuple[float, float]] = []
        obs_labels: List[str] = []
        for tile_idx, (x, y) in enumerate(coords_xy):
            for row in range(tokens_side):
                for col in range(tokens_side):
                    coords_list.append((x + col * token_size, y + row * token_size))
                    obs_labels.append(f"tile_{tile_idx}_tok_{row}_{col}")
        # reshape keeps a (0, 2) layout when there are no tokens at all
        coords_out = np.asarray(coords_list, dtype=np.float32).reshape(-1, 2)
    else:
        tdim = None
        embeds2d = embeddings
        coords_out = coords_xy.astype(np.float32, copy=False)
        obs_labels = [f"tile_{i}" for i in range(n_tiles)]
        tokens_side = 1
        tokens_use = 1
        drop_count = 0

    # ----- coordinate frame export -----
    export_frame = str(cfg.get("export_coordinate_frame", "level0")).lower()
    eff_mpp = meta.get("effective_mpp", None)
    source_mpp = meta.get("source_mpp", None)
    mpp_scale: Optional[float] = None
    if eff_mpp is not None and source_mpp is not None:
        try:
            # Convert virtual pixel distances at eff_mpp to level-0 pixels: d0 = d * (eff_mpp / source_mpp)
            mpp_scale = float(eff_mpp) / float(source_mpp)
        except Exception:
            mpp_scale = None
    if export_frame == "level0" and mpp_scale is not None:
        coords_out = (coords_out * float(mpp_scale)).astype(np.float32, copy=False)
        meta["coordinate_frame"] = "level-0 pixels"
    else:
        meta["coordinate_frame"] = meta.get("coordinate_frame", "target-mpp virtual pixels")

    # ----- dtype casting -----
    qmeta: Dict[str, Any] = {}
    desired = str(cfg.get("output_dtype", "float16")).lower()
    # dtype label recorded in the table; kept local so the caller's cfg is untouched
    stored_dtype = cfg.get("output_dtype", "float16")
    if desired in ("float16", "float32"):
        target = np.float16 if desired == "float16" else np.float32
        try:
            embeds2d = embeds2d.astype(target)
        except Exception:
            embeds2d, qmeta = _apply_output_dtype(np.asarray(embeds2d), cfg)
    elif desired == "uint8":
        # uint8 quantization is not applied on this path; store float16 and say
        # so in the table. Mutating cfg here would silently change later writers.
        try:
            embeds2d = embeds2d.astype(np.float16)
        except Exception:
            embeds2d, qmeta = _apply_output_dtype(np.asarray(embeds2d), {**cfg, "output_dtype": "float16"})
        stored_dtype = "float16"

    # ----- AnnData assembly -----
    adata = ad.AnnData(X=embeds2d)
    adata.obsm["spatial"] = coords_out
    instance_ids = np.arange(adata.n_obs, dtype=np.int32)
    adata.obs_names = instance_ids.astype(str)
    adata.obs["obs_label"] = np.array(obs_labels, dtype=object)
    # slide_id: explicit meta value, else the stem of input_path
    try:
        slide_id = str(meta.get("slide_id") or Path(str(meta.get("input_path", ""))).stem)
    except Exception:
        slide_id = str(meta.get("slide_id", ""))
    adata.obs["slide_id"] = np.array([slide_id] * adata.n_obs, dtype=object)
    if token_mode:
        adata.obs["tile_id"] = np.repeat(np.arange(n_tiles, dtype=np.int32), tokens_use)
        adata.obs["token_row"] = np.tile(np.repeat(np.arange(tokens_side, dtype=np.int32), tokens_side), n_tiles)
        adata.obs["token_col"] = np.tile(np.arange(tokens_side, dtype=np.int32), n_tiles * tokens_side)
        adata.uns["token_dim"] = int(tdim)
        adata.uns["tokens_side"] = int(tokens_side)
        adata.uns["drop_tokens"] = int(drop_count)
    else:
        adata.obs["tile_id"] = np.arange(n_tiles, dtype=np.int32)

    adata.uns["output_type"] = output_type
    adata.uns["output_dtype"] = stored_dtype
    if qmeta:
        adata.uns["quantization"] = _sanitize_meta_for_storage(qmeta)
    tile_spec = _build_tilespec_meta(
        meta.get("effective_mpp", 1.0),
        int(cfg["patch_size"]),
        int(cfg["stride"]),
        meta.get("pyramid_level", 0),
        meta.get("level_downsample", 1.0),
    )
    adata.uns["tile_spec"] = _sanitize_meta_for_storage(tile_spec)
    slide_props = _build_slideprops_meta(
        meta.get("source_mpp", None),
        meta.get("input_path", ""),
        meta.get("image_size_resampled", None),
    )
    adata.uns["slide_properties"] = _sanitize_meta_for_storage(slide_props)
    provenance = dict(meta)
    provenance.pop("image_size_resampled", None)
    adata.uns["provenance"] = _sanitize_meta_for_storage(provenance)

    return adata, table_name


def write_spatialdata_table(adata: "ad.AnnData", out_path: str, table_name: str) -> None:
    """Write a SpatialData dataset containing a single table.

    Creates a SpatialData with only ``tables={table_name: adata}`` and writes it
    to ``out_path``, replacing any file or directory already at that path.

    Raises:
        RuntimeError: if spatialdata is not installed, or the installed version
            lacks ``SpatialData.write``. In both cases ``out_path`` is left
            untouched.
    """
    if not HAS_SPATIALDATA:
        raise RuntimeError("SpatialData not available inside write_spatialdata_table.")
    sd = SpatialData(tables={table_name: adata})
    # Verify the writer exists *before* clearing the target, so an unsupported
    # spatialdata version does not delete existing output and then fail.
    if not hasattr(sd, "write"):
        raise RuntimeError("SpatialData.write(...) not available; please upgrade spatialdata to >=0.5.0.")
    outp = Path(out_path)
    # Pre-clear any existing path (file or directory, Zarr or not) to honor
    # overwrite semantics.
    try:
        from shutil import rmtree

        if outp.is_dir():
            rmtree(outp, ignore_errors=True)
        elif outp.exists():
            outp.unlink(missing_ok=True)
    except OSError:
        # Best effort: sd.write(..., overwrite=True) below reports real problems.
        pass
    sd.write(out_path, overwrite=True)


def update_spatialdata_table(adata: "ad.AnnData", zarr_path: str, table_name: str) -> None:
    """Update or add a table in an existing SpatialData zarr.

    Reads the existing dataset if possible, replaces or adds the table under
    ``table_name``, and writes the result back to ``zarr_path``. Other layers
    (shapes, images, labels, points) present in the dataset are carried over.
    If the existing store cannot be read, a fresh single-table dataset is
    written via ``write_spatialdata_table`` instead.

    The merged dataset is first written to a staging store next to
    ``zarr_path`` and only swapped in after that write succeeded. Elements read
    from the original store may still be lazily backed by its files, so
    deleting the original before writing could lose data.

    Raises:
        RuntimeError: if spatialdata is missing or lacks ``SpatialData.write``.
    """
    if not HAS_SPATIALDATA:
        raise RuntimeError("SpatialData not available inside update_spatialdata_table.")
    existing = None
    try:
        if hasattr(sdata, "read_zarr"):
            existing = sdata.read_zarr(zarr_path)  # type: ignore
        elif hasattr(sdata, "io") and hasattr(sdata.io, "read_zarr"):
            existing = sdata.io.read_zarr(zarr_path)  # type: ignore
    except Exception:
        # Missing or unreadable store: fall back to writing a fresh dataset.
        existing = None

    if existing is None:
        write_spatialdata_table(adata, zarr_path, table_name)
        return

    tables = dict(getattr(existing, "tables", {}))
    tables[table_name] = adata
    shapes = getattr(existing, "shapes", None)
    images = getattr(existing, "images", None)
    labels_dict = getattr(existing, "labels", None)
    points = getattr(existing, "points", None)
    sd = SpatialData(
        shapes=shapes if shapes else None,
        images=images if images else None,
        labels=labels_dict if labels_dict else None,
        points=points if points else None,
        tables=tables,
    )
    if not hasattr(sd, "write"):
        raise RuntimeError("SpatialData.write(...) not available; please upgrade spatialdata to >=0.5.0.")

    import shutil

    target = Path(zarr_path)
    staging = target.with_name(target.name + ".tmp.zarr")
    # Clear leftovers from an earlier interrupted update.
    if staging.is_dir():
        shutil.rmtree(staging)
    elif staging.exists():
        staging.unlink()
    try:
        sd.write(str(staging), overwrite=True)
    except Exception:
        shutil.rmtree(staging, ignore_errors=True)
        raise
    # The staged copy is complete; only now replace the original store.
    if target.is_dir():
        shutil.rmtree(target)
    elif target.exists():
        target.unlink()
    shutil.move(str(staging), str(target))


def _model_key_for_table(name: str) -> str:
    """Return a simple lowercase key derived from a model name.

    Keeps only alphanumeric characters, ``_`` and ``-``. Falls back to
    ``"model"`` whenever nothing usable remains (empty, whitespace-only, or
    all-punctuation names), so derived table names never degenerate to
    something like ``"_tiles"``.
    """
    cleaned = "".join(c for c in str(name).lower().strip() if c.isalnum() or c in ("_", "-"))
    return cleaned or "model"


def _build_tilespec_meta(eff_mpp: float, patch_size: int, stride: int, level: int, downsample: float) -> dict:
    """Assemble the ``tile_spec`` metadata dict.

    ``eff_mpp``, ``level`` and ``downsample`` are tolerant of missing or odd
    values (e.g. from non-WSI inputs): ``None`` or anything that fails
    conversion falls back to ``1.0``, ``0`` and ``1.0`` respectively.
    ``patch_size`` and ``stride`` go through ``int`` and conversion errors
    propagate.
    """

    def _coerce(value, cast, fallback):
        # None short-circuits to the fallback; so does any conversion failure.
        if value is None:
            return fallback
        try:
            return cast(value)
        except Exception:
            return fallback

    mpp_value = _coerce(eff_mpp, float, 1.0)
    level_value = _coerce(level, int, 0)
    downsample_value = _coerce(downsample, float, 1.0)
    spec = {"effective_mpp": mpp_value}
    spec["patch_size"] = int(patch_size)
    spec["stride"] = int(stride)
    spec["ops_level"] = level_value
    spec["level_downsample"] = downsample_value
    return spec


def _build_slideprops_meta(source_mpp, input_path, image_size_resampled=None) -> dict:
    """Assemble the ``slide_properties`` metadata dict.

    ``source_mpp`` is stored as a float, or ``None`` when it is missing or
    cannot be converted. This matches how ``_build_tilespec_meta`` tolerates
    odd metadata instead of aborting the whole write. ``input_path`` is stored
    as its string form and ``image_size_resampled`` unchanged.
    """
    try:
        mpp = None if source_mpp is None else float(source_mpp)
    except (TypeError, ValueError, OverflowError):
        mpp = None
    return {
        "source_mpp": mpp,
        "input_path": str(input_path),
        "image_size_resampled": image_size_resampled,
    }


def _sanitize_meta_for_storage(meta: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``meta`` with values converted to plain Python types.

    NumPy scalars become their Python equivalents (via ``.item()``) and
    ``pathlib.Path`` objects become strings. Nested dicts, lists and tuples are
    sanitized recursively into fresh dicts/lists/tuples, so the result never
    aliases the caller's nested containers. Every other value (including
    ``None`` and NumPy arrays) is passed through unchanged.
    """
    return {k: _sanitize_value_for_storage(v) for k, v in meta.items()}


def _sanitize_value_for_storage(value: Any) -> Any:
    """Recursively convert one metadata value for storage (see ``_sanitize_meta_for_storage``)."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, dict):
        return {k: _sanitize_value_for_storage(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_sanitize_value_for_storage(v) for v in value]
    if isinstance(value, tuple):
        return tuple(_sanitize_value_for_storage(v) for v in value)
    return value


def _apply_output_dtype(embeds2d: np.ndarray, cfg: Dict[str, Any]):
    """Cast or quantize an embedding matrix according to ``cfg["output_dtype"]``.

    - ``"uint8"``: per-feature (column-wise) min/max quantization to 0..255.
      Per-feature minima/maxima are returned in the metadata so values can be
      approximately restored as ``min + q / 255 * (max - min)``.
    - ``"float32"``: cast to float32.
    - anything else (default ``"float16"``): cast to float16.

    The dtype name is matched case-insensitively.

    Returns:
        ``(out, qmeta)``; ``qmeta`` is empty unless uint8 quantization ran.
    """
    desired = str(cfg.get("output_dtype", "float16")).lower()
    qmeta: Dict[str, Any] = {}
    if desired == "uint8":
        # Quantize in at least float32: in float16 the 1e-12 range floor below
        # underflows to 0 (constant features would divide 0/0), and small
        # integer dtypes would overflow on max - min.
        work_dtype = np.result_type(embeds2d.dtype, np.float32)
        work = embeds2d if embeds2d.dtype == work_dtype else embeds2d.astype(work_dtype)
        if work.shape[0] == 0:
            # min/max are undefined on zero rows; emit a zero-range quantization.
            mins = np.zeros(work.shape[1:], dtype=work_dtype)
            maxs = mins.copy()
        else:
            mins = work.min(axis=0)
            maxs = work.max(axis=0)
        ranges = np.maximum(maxs - mins, 1e-12)
        scaled = (work - mins) / ranges
        out = np.clip(np.round(scaled * 255.0), 0, 255).astype(np.uint8)
        qmeta = {
            "mode": "minmax_uint8",
            "feature_min": mins.astype(np.float32).tolist(),
            "feature_max": maxs.astype(np.float32).tolist(),
        }
    elif desired == "float32":
        out = embeds2d.astype(np.float32, copy=False)
    else:
        out = embeds2d.astype(np.float16, copy=False)
    return out, qmeta

def build_tile_shapes(
    coords_xy: np.ndarray,
    embeddings: np.ndarray,
    meta: Dict[str, Any],
    cfg: Dict[str, Any],
    token_dim: Optional[int] = None,
) -> List[Any]:
    """Return one square shapely polygon per tile, or per token in token mode.

    Patch mode (default): each tile becomes a ``patch_size`` x ``patch_size``
    square anchored at its top-left coordinate.

    Token mode (``cfg["output_type"] == "token"`` with ``token_dim`` given):
    the per-tile token count is ``embeddings.shape[1] // token_dim``. The
    largest square grid (``tokens_side x tokens_side``) that fits this count
    is laid out over each patch, row by row, with side
    ``patch_size / tokens_side``. Surplus tokens get no polygon. ``meta`` is
    not read.

    Polygons are in the same frame as ``coords_xy`` (typically target-mpp
    virtual pixels); no level-0 scaling happens here. For level-0 export the
    original design points to ``update_spatialdata_shapes(..., region_name='tiles')``.
    """
    if not HAS_SHAPELY:
        raise RuntimeError("Shapely/GeoPandas required for building tile/token shapes.")

    per_token = str(cfg.get("output_type", "patch")) == "token" and token_dim is not None
    patch_px: int = int(cfg["patch_size"])

    if per_token:
        grid = int(np.floor(np.sqrt(int(embeddings.shape[1]) // int(token_dim))))
        step = patch_px / grid if grid > 0 else patch_px
        # Top-left corner of every token cell, tile by tile, row-major inside a tile.
        anchors = np.asarray(
            [
                (float(x) + c * step, float(y) + r * step)
                for x, y in coords_xy
                for r in range(grid)
                for c in range(grid)
            ],
            dtype=np.float32,
        )
        side = float(step)
    else:
        anchors = coords_xy.astype(np.float32, copy=False)
        side = float(patch_px)

    squares: List[Any] = []
    for ax, ay in anchors:
        left, top = float(ax), float(ay)
        squares.append(
            sg.Polygon([(left, top), (left + side, top), (left + side, top + side), (left, top + side)])
        )
    return squares


def save_as_spatialdata(
    coords_xy: np.ndarray,
    embeddings: np.ndarray,
    out_path: str,
    meta: Dict[str, Any],
    cfg: Dict[str, Any],
    token_dim: Optional[int] = None,
    tissue_polys: Optional[List[Any]] = None,
):
    """Write embeddings (and, when possible, shapes) as a SpatialData store.

    Builds an AnnData table named ``"<model_key>_tiles"`` (patch mode) or
    ``"<model_key>_tokens"`` (``cfg["output_type"] == "token"``).

    Coordinates are exported to level-0 pixels when
    ``cfg["export_coordinate_frame"] == "level0"`` and both
    ``meta["effective_mpp"]`` and ``meta["source_mpp"]`` are usable. Embeddings
    are cast to ``cfg["output_dtype"]``; ``"uint8"`` is stored as float16 on
    this path.

    When shapely/geopandas are available, the following are added:

    - a square polygon per tile/token as a shapes layer annotated by the table;
    - an optional ``"tissue"`` shapes layer built from ``tissue_polys``, scaled
      with the same level-0 factor.

    Any existing ``out_path`` is replaced. Neither ``meta`` nor ``cfg`` is
    modified.

    Raises:
        RuntimeError: if spatialdata is missing or lacks ``SpatialData.write``
            (``out_path`` is then left untouched).
        ValueError: in token mode, if the embedding width is not a positive
            multiple of ``token_dim``.
    """
    if not HAS_SPATIALDATA:
        raise RuntimeError("SpatialData not available inside save_as_spatialdata.")

    # Local copy: coordinate_frame is recorded into meta below and the caller's
    # dict must not change. cfg is only read.
    meta = dict(meta)

    output_type = str(cfg.get("output_type", "patch"))
    token_mode = output_type == "token" and token_dim is not None
    # The table suffix and the shapes region share one name (tokens vs tiles).
    region_name = "tokens" if output_type == "token" else "tiles"
    model_key = _model_key_for_table(meta.get("model_name", "model"))
    table_name = f"{model_key}_{region_name}"

    patch_px: int = int(cfg["patch_size"])
    n_tiles = int(coords_xy.shape[0])

    # ----- build embeddings + coords + labels (token vs patch) -----
    if token_mode:
        tdim = int(token_dim)
        flat_dim = int(embeddings.shape[1])
        if tdim <= 0 or flat_dim % tdim != 0:
            raise ValueError(
                f"Embedding width {flat_dim} is not a positive multiple of token_dim={tdim}."
            )
        tokens_total = flat_dim // tdim
        tokens_side = int(np.floor(np.sqrt(tokens_total)))
        tokens_use = tokens_side * tokens_side
        # Leading tokens that do not fit the largest square grid are dropped.
        drop_count = tokens_total - tokens_use
        embeds3d = embeddings.reshape(n_tiles, tokens_total, tdim)[:, drop_count:, :]
        embeds2d = embeds3d.reshape(n_tiles * tokens_use, tdim)

        token_size = patch_px / tokens_side if tokens_side > 0 else patch_px
        coords_list: List[Tuple[float, float]] = []
        obs_labels: List[str] = []
        for tile_idx, (x, y) in enumerate(coords_xy):
            for row in range(tokens_side):
                for col in range(tokens_side):
                    coords_list.append((x + col * token_size, y + row * token_size))
                    obs_labels.append(f"tile_{tile_idx}_tok_{row}_{col}")
        # reshape keeps a (0, 2) layout when there are no tokens at all
        coords_out = np.asarray(coords_list, dtype=np.float32).reshape(-1, 2)
        side_len = token_size
    else:
        tdim = None
        embeds2d = embeddings
        coords_out = coords_xy.astype(np.float32, copy=False)
        obs_labels = [f"tile_{i}" for i in range(n_tiles)]
        tokens_side = 1
        tokens_use = 1
        drop_count = 0
        side_len = patch_px

    # ----- coordinate frame export (WSI level-0 vs target-mpp) -----
    # Virtual-pixel distances at effective_mpp map to level-0 pixels via
    # d0 = d * (effective_mpp / source_mpp). Computed once; shapes reuse it.
    export_frame = str(cfg.get("export_coordinate_frame", "level0")).lower()
    eff_mpp = meta.get("effective_mpp", None)
    source_mpp = meta.get("source_mpp", None)
    mpp_scale: Optional[float] = None
    if eff_mpp is not None and source_mpp is not None:
        try:
            mpp_scale = float(eff_mpp) / float(source_mpp)
        except (TypeError, ValueError, ArithmeticError):
            mpp_scale = None
    to_level0 = export_frame == "level0" and mpp_scale is not None
    if to_level0:
        coords_out = (coords_out * float(mpp_scale)).astype(np.float32, copy=False)
        meta["coordinate_frame"] = "level-0 pixels"
    else:
        meta["coordinate_frame"] = meta.get("coordinate_frame", "target-mpp virtual pixels")

    # ----- dtype casting -----
    qmeta: Dict[str, Any] = {}
    desired = str(cfg.get("output_dtype", "float16")).lower()
    # dtype label recorded in the table; kept local so the caller's cfg is untouched
    stored_dtype = cfg.get("output_dtype", "float16")
    if desired in ("float16", "float32"):
        target = np.float16 if desired == "float16" else np.float32
        try:
            embeds2d = embeds2d.astype(target)
        except Exception:
            embeds2d, qmeta = _apply_output_dtype(np.asarray(embeds2d), cfg)
    elif desired == "uint8":
        # Streaming token/patch with global quantization is complex; fall back
        # to float16 and record that in the table rather than mutating cfg.
        try:
            embeds2d = embeds2d.astype(np.float16)
        except Exception:
            embeds2d, qmeta = _apply_output_dtype(np.asarray(embeds2d), {**cfg, "output_dtype": "float16"})
        stored_dtype = "float16"

    # ----- chunk in-memory arrays with dask (~128 MiB row blocks) -----
    X_use = embeds2d
    if HAS_DASK and not hasattr(embeds2d, "chunks"):
        n_rows, n_features = embeds2d.shape
        if n_features > 0:
            bytes_per_elem = np.dtype(embeds2d.dtype).itemsize
            target_bytes = 128 * 1024 * 1024
            chunk_rows = min(max(1, target_bytes // (n_features * bytes_per_elem)), n_rows)
            X_use = da.from_array(embeds2d, chunks=(chunk_rows, n_features))

    # ----- AnnData table -----
    adata = ad.AnnData(X=X_use)
    adata.obsm["spatial"] = coords_out
    instance_ids = np.arange(adata.n_obs, dtype=np.int32)
    adata.obs_names = instance_ids.astype(str)
    adata.obs["obs_label"] = np.array(obs_labels, dtype=object)
    # slide_id: explicit meta value, else the stem of input_path; reused for shapes
    try:
        slide_id = str(meta.get("slide_id") or Path(str(meta.get("input_path", ""))).stem)
    except Exception:
        slide_id = str(meta.get("slide_id", ""))
    adata.obs["slide_id"] = np.array([slide_id] * adata.n_obs, dtype=object)
    if token_mode:
        adata.obs["tile_id"] = np.repeat(np.arange(n_tiles, dtype=np.int32), tokens_use)
        adata.obs["token_row"] = np.tile(np.repeat(np.arange(tokens_side, dtype=np.int32), tokens_side), n_tiles)
        adata.obs["token_col"] = np.tile(np.arange(tokens_side, dtype=np.int32), n_tiles * tokens_side)
        adata.uns["token_dim"] = int(tdim)
        adata.uns["tokens_side"] = int(tokens_side)
        adata.uns["drop_tokens"] = int(drop_count)
    else:
        adata.obs["tile_id"] = np.arange(n_tiles, dtype=np.int32)

    adata.uns["output_type"] = output_type
    adata.uns["output_dtype"] = stored_dtype
    if qmeta:
        adata.uns["quantization"] = _sanitize_meta_for_storage(qmeta)
    tile_spec = _build_tilespec_meta(
        meta.get("effective_mpp", 1.0),
        int(cfg["patch_size"]),
        int(cfg["stride"]),
        meta.get("pyramid_level", 0),
        meta.get("level_downsample", 1.0),
    )
    adata.uns["tile_spec"] = _sanitize_meta_for_storage(tile_spec)
    slide_props = _build_slideprops_meta(
        meta.get("source_mpp", None),
        meta.get("input_path", ""),
        meta.get("image_size_resampled", None),
    )
    adata.uns["slide_properties"] = _sanitize_meta_for_storage(slide_props)
    provenance = dict(meta)
    provenance.pop("image_size_resampled", None)
    adata.uns["provenance"] = _sanitize_meta_for_storage(provenance)

    # ----- SpatialData with shapes -----
    if HAS_SHAPELY:
        adata.obs["region"] = region_name
        adata.obs["instance_id"] = instance_ids
        # Shapes live in the same frame as the exported coordinates.
        shape_side = float(side_len) * float(mpp_scale) if to_level0 else side_len
        polys = [
            sg.Polygon(
                [
                    (x0, y0),
                    (x0 + shape_side, y0),
                    (x0 + shape_side, y0 + shape_side),
                    (x0, y0 + shape_side),
                ]
            )
            for x0, y0 in ((float(a), float(b)) for a, b in coords_out)
        ]
        gdf = gpd.GeoDataFrame(
            {
                "geometry": polys,
                "region": [region_name] * len(polys),
                "instance_id": instance_ids,
            },
            geometry="geometry",
            index=adata.obs_names.copy(),
        )
        gdf["slide_id"] = slide_id

        # patch_id for tiles; patch_id/token_id for tokens
        idx = np.arange(len(polys), dtype=np.int32)
        if region_name == "tiles":
            gdf["patch_id"] = idx
        else:
            tokens_per_tile = int(tokens_use)
            gdf["patch_id"] = np.clip(np.floor_divide(idx, tokens_per_tile), 0, max(0, n_tiles - 1))
            gdf["token_id"] = np.mod(idx, tokens_per_tile)

        shapes_dict = {region_name: ShapesModel.parse(gdf)}
        table_layer = TableModel.parse(
            adata,
            region=region_name,
            region_key="region",
            instance_key="instance_id",
        )
        # Optional tissue polygons as a separate shapes layer, in the same frame.
        if tissue_polys:
            tissue_geoms = tissue_polys
            if to_level0:
                try:
                    from shapely.affinity import scale as _scale

                    tissue_geoms = [
                        _scale(p, xfact=float(mpp_scale), yfact=float(mpp_scale), origin=(0, 0))
                        for p in tissue_polys
                    ]
                except Exception:
                    # NOTE(review): falls back to unscaled geometry, which then
                    # disagrees with the level-0 tile coordinates.
                    tissue_geoms = tissue_polys
            tgdf = gpd.GeoDataFrame(
                {
                    "geometry": tissue_geoms,
                    "region": ["tissue"] * len(tissue_geoms),
                    "instance_id": np.arange(len(tissue_geoms), dtype=np.int32),
                },
                geometry="geometry",
                index=[f"tissue_{i}" for i in range(len(tissue_polys))],
            )
            tgdf["slide_id"] = slide_id
            shapes_dict["tissue"] = ShapesModel.parse(tgdf)
        sd = SpatialData(shapes=shapes_dict, tables={table_name: table_layer})
    else:
        sd = SpatialData(tables={table_name: adata})

    # ----- write (verify the writer before clearing existing output) -----
    if not hasattr(sd, "write"):
        raise RuntimeError("SpatialData.write(...) not available; please upgrade spatialdata to >=0.5.0.")
    outp = Path(out_path)
    try:
        from shutil import rmtree

        if outp.is_dir():
            rmtree(outp, ignore_errors=True)
        elif outp.exists():
            outp.unlink(missing_ok=True)
    except OSError:
        # Best effort: sd.write(..., overwrite=True) below reports real problems.
        pass
    sd.write(out_path, overwrite=True)


def save_as_h5ad_table(
    coords_xy: np.ndarray,
    embeddings: np.ndarray,
    out_path: str,
    meta: Dict[str, Any],
    cfg: Dict[str, Any],
    token_dim: Optional[int] = None,
):
    """Write embeddings as a standalone AnnData ``.h5ad`` file (fallback writer).

    Produces one observation per tile (patch mode) or per token (token mode)
    with ``obsm["spatial"]`` holding top-left coordinates.

    Coordinates follow the same export rule as the SpatialData writers:
    level-0 pixels when ``cfg["export_coordinate_frame"] == "level0"`` and
    both ``meta["effective_mpp"]`` and ``meta["source_mpp"]`` are usable. The
    chosen frame is recorded as ``uns["provenance"]["coordinate_frame"]``.

    Embeddings go through ``_apply_output_dtype``, so ``"uint8"`` is really
    quantized here, with per-feature min/max in ``uns["quantization"]``.

    The output path is forced to a ``.h5ad`` suffix and any existing file there
    is replaced. Neither ``meta`` nor ``cfg`` is modified.

    Raises:
        RuntimeError: if anndata is not installed.
        ValueError: in token mode, if the embedding width is not a positive
            multiple of ``token_dim``.
    """
    if not HAS_ANNDATA:
        raise RuntimeError("AnnData not available; cannot write fallback .h5ad.")

    # Local copy: coordinate_frame is recorded below without touching the caller's dict.
    meta = dict(meta)
    output_type = str(cfg.get("output_type", "patch"))
    token_mode = output_type == "token" and token_dim is not None
    patch_px: int = int(cfg["patch_size"])
    n_tiles = int(coords_xy.shape[0])

    if token_mode:
        tdim = int(token_dim)
        flat_dim = int(embeddings.shape[1])
        if tdim <= 0 or flat_dim % tdim != 0:
            raise ValueError(
                f"Embedding width {flat_dim} is not a positive multiple of token_dim={tdim}."
            )
        tokens_total = flat_dim // tdim
        tokens_side = int(np.floor(np.sqrt(tokens_total)))
        tokens_use = tokens_side * tokens_side
        # Leading tokens that do not fit the largest square grid are dropped.
        drop_count = tokens_total - tokens_use
        embeds3d = embeddings.reshape(n_tiles, tokens_total, tdim)[:, drop_count:, :]
        embeds2d = embeds3d.reshape(n_tiles * tokens_use, tdim)

        token_size = patch_px / tokens_side if tokens_side > 0 else patch_px
        coords_list: List[Tuple[float, float]] = []
        obs_names: List[str] = []
        for tile_idx, (x, y) in enumerate(coords_xy):
            for row in range(tokens_side):
                for col in range(tokens_side):
                    coords_list.append((x + col * token_size, y + row * token_size))
                    obs_names.append(f"tile_{tile_idx}_tok_{row}_{col}")
        # reshape keeps a (0, 2) layout when there are no tokens at all
        coords_out = np.asarray(coords_list, dtype=np.float32).reshape(-1, 2)
    else:
        tdim = None
        embeds2d = embeddings
        coords_out = coords_xy.astype(np.float32, copy=False)
        obs_names = [f"tile_{i}" for i in range(n_tiles)]
        tokens_side = 1
        tokens_use = 1
        drop_count = 0

    # ----- coordinate frame export (same rule as the SpatialData writers) -----
    export_frame = str(cfg.get("export_coordinate_frame", "level0")).lower()
    eff_mpp = meta.get("effective_mpp", None)
    source_mpp = meta.get("source_mpp", None)
    mpp_scale: Optional[float] = None
    if eff_mpp is not None and source_mpp is not None:
        try:
            # d0 = d * (eff_mpp / source_mpp)
            mpp_scale = float(eff_mpp) / float(source_mpp)
        except (TypeError, ValueError, ArithmeticError):
            mpp_scale = None
    if export_frame == "level0" and mpp_scale is not None:
        coords_out = (coords_out * float(mpp_scale)).astype(np.float32, copy=False)
        meta["coordinate_frame"] = "level-0 pixels"
    else:
        meta["coordinate_frame"] = meta.get("coordinate_frame", "target-mpp virtual pixels")

    embeds2d, qmeta = _apply_output_dtype(embeds2d, cfg)

    adata = ad.AnnData(X=embeds2d)
    adata.obs_names = obs_names
    adata.obsm["spatial"] = coords_out
    # slide_id: explicit meta value, else the stem of input_path
    try:
        slide_id = str(meta.get("slide_id") or Path(str(meta.get("input_path", ""))).stem)
    except Exception:
        slide_id = str(meta.get("slide_id", ""))
    adata.obs["slide_id"] = np.array([slide_id] * adata.n_obs, dtype=object)
    if token_mode:
        adata.obs["tile_id"] = np.repeat(np.arange(n_tiles), tokens_use).astype(np.int32)
        adata.obs["token_row"] = np.tile(np.repeat(np.arange(tokens_side, dtype=np.int32), tokens_side), n_tiles)
        adata.obs["token_col"] = np.tile(np.arange(tokens_side, dtype=np.int32), n_tiles * tokens_side)
        adata.uns["token_dim"] = int(tdim)
        adata.uns["tokens_side"] = int(tokens_side)
        adata.uns["drop_tokens"] = int(drop_count)
    else:
        adata.obs["tile_id"] = np.arange(n_tiles, dtype=np.int32)

    adata.uns["output_type"] = output_type
    adata.uns["output_dtype"] = cfg.get("output_dtype", "float16")
    if qmeta:
        adata.uns["quantization"] = _sanitize_meta_for_storage(qmeta)
    tile_spec = _build_tilespec_meta(
        meta.get("effective_mpp", 1.0),
        int(cfg["patch_size"]),
        int(cfg["stride"]),
        meta.get("pyramid_level", 0),
        meta.get("level_downsample", 1.0),
    )
    adata.uns["tile_spec"] = _sanitize_meta_for_storage(tile_spec)
    slide_props = _build_slideprops_meta(
        meta.get("source_mpp", None),
        meta.get("input_path", ""),
        meta.get("image_size_resampled", None),
    )
    adata.uns["slide_properties"] = _sanitize_meta_for_storage(slide_props)
    provenance = dict(meta)
    provenance.pop("image_size_resampled", None)
    adata.uns["provenance"] = _sanitize_meta_for_storage(provenance)

    p = Path(out_path)
    if p.suffix.lower() != ".h5ad":
        p = p.with_suffix(".h5ad")
    try:
        p.unlink(missing_ok=True)
    except OSError:
        # Best effort: if the stale file cannot be removed, write_h5ad below
        # reports any resulting error.
        pass
    adata.write_h5ad(p, compression="gzip")
    print(f"[fallback] Saved AnnData table to: {p}")


def save_tissue_spatialdata(tissue_polys: List[Any], out_path: str, meta: Dict[str, Any], cfg: Dict[str, Any]):
    """Write tissue polygons as a standalone SpatialData Zarr store.

    Builds a GeoDataFrame with one row per polygon (index ``tissue_<i>``, columns
    ``region`` = ``"tissue"``, ``instance_id`` (int32) and ``slide_id``) and stores
    it as the ``"tissue"`` shapes element. Anything already at ``out_path`` is
    removed first, and a small ``provenance.json`` is then written inside the store
    for overlay tools. Geometries are written exactly as given; no coordinate-frame
    scaling is applied here.

    Args:
        tissue_polys: Geometries to store, one row each.
        out_path: Destination Zarr path; replaced if it exists.
        meta: Slide metadata. ``slide_id`` (falling back to the stem of
            ``input_path``) and the provenance keys listed below are read from it.
        cfg: Accepted for signature parity with the other writers; not read here.

    Raises:
        RuntimeError: If spatialdata or shapely/geopandas are unavailable, or if
            ``SpatialData.write`` is missing.
    """
    import json
    import os
    from shutil import rmtree

    if not HAS_SPATIALDATA:
        raise RuntimeError("SpatialData not available inside save_tissue_spatialdata.")
    if not HAS_SHAPELY:
        raise RuntimeError("Shapely/GeoPandas required for shapes writing.")
    t_region = "tissue"
    t_index = [f"tissue_{i}" for i in range(len(tissue_polys))]
    tgdf = gpd.GeoDataFrame(
        {
            "geometry": tissue_polys,
            "region": [t_region] * len(tissue_polys),
            "instance_id": np.arange(len(tissue_polys), dtype=np.int32),
        },
        geometry="geometry",
        index=t_index,
    )
    # Add slide_id to tissue shapes gdf (explicit slide_id, else input file stem).
    try:
        _slide_id = str(meta.get("slide_id") or Path(str(meta.get("input_path", ""))).stem)
    except Exception:
        _slide_id = str(meta.get("slide_id", ""))
    tgdf["slide_id"] = _slide_id
    shapes_layer = ShapesModel.parse(tgdf)
    sd = SpatialData(shapes={t_region: shapes_layer})

    # Pre-clear the destination so a stale file or store never survives (best effort).
    outp = Path(out_path)
    try:
        if outp.is_dir():
            rmtree(outp, ignore_errors=True)
        elif outp.exists():
            outp.unlink()
    except OSError:
        pass
    if hasattr(sd, "write"):
        sd.write(out_path, overwrite=True)
    else:
        raise RuntimeError("SpatialData.write(...) not available; please upgrade spatialdata to >=0.5.0.")

    def _json_default(value: Any) -> Any:
        """Convert path-like and numpy values that json cannot encode natively."""
        if isinstance(value, os.PathLike):
            return os.fspath(value)
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    # Lightweight provenance JSON for overlay tools (best effort).
    prov = {
        "input_path": meta.get("input_path", ""),
        "source_mpp": meta.get("source_mpp", None),
        "effective_mpp": meta.get("effective_mpp", None),
        "image_size_resampled": meta.get("image_size_resampled", None),
        "coordinate_frame": meta.get("coordinate_frame", ""),
        "model_name": meta.get("model_name", ""),
        "config_path": meta.get("config_path", ""),
    }
    try:
        # Serialize fully before touching the file so an unserializable value can
        # never leave a truncated provenance.json behind.
        text = json.dumps(prov, indent=2, default=_json_default)
        (outp / "provenance.json").write_text(text, encoding="utf-8")
    except (TypeError, ValueError, OSError):
        pass


def update_spatialdata_shapes(
    tissue_polys: List[Any],
    zarr_path: str,
    meta: Dict[str, Any],
    cfg: Dict[str, Any],
    region_name: str = "tissue",
):
    """Add or replace a shapes layer in an existing SpatialData zarr.

    Reads the SpatialData dataset when possible, updates `shapes[region_name]`
    with the provided polygons, and writes back to the same path. Preserves other
    layers (tables, images, labels, points) when present.

    The updated dataset is first written to a sibling temporary store and then
    swapped into place. Elements returned by ``read_zarr`` may still be lazily
    backed by the original store, so deleting it before the write could destroy
    the very data being preserved. If the write fails, the original store is left
    untouched.

    Args:
        tissue_polys: Geometries for the layer, one row each.
        zarr_path: Path of the SpatialData Zarr store to update or create.
        meta: Slide metadata. ``effective_mpp``/``source_mpp`` drive optional
            level-0 scaling, ``slide_id``/``input_path`` fill the ``slide_id``
            column, and ``n_patches`` is used to derive token ``patch_id``s.
        cfg: ``export_coordinate_frame`` (default ``"level0"``) selects scaling.
        region_name: Shapes element name. ``"tiles"``/``"tokens"``
            (case-insensitive) get numeric string indices plus ``patch_id``
            (and ``token_id`` for tokens) columns.

    Raises:
        RuntimeError: If spatialdata or shapely/geopandas are unavailable, or if
            ``SpatialData.write`` is missing.
    """
    if not HAS_SPATIALDATA:
        raise RuntimeError("SpatialData not available inside update_spatialdata_shapes.")
    if not HAS_SHAPELY:
        raise RuntimeError("Shapely/GeoPandas required for shapes updating.")

    from shutil import rmtree

    def _remove_path(path: Path) -> None:
        """Best-effort removal of whatever file or directory tree sits at `path`."""
        try:
            if path.is_dir():
                rmtree(path, ignore_errors=True)
            elif path.exists():
                path.unlink()
        except OSError:
            pass

    region_key = str(region_name).lower()

    # Optional scaling to level-0 coordinate frame to match export behavior:
    # geometries are multiplied by effective_mpp / source_mpp about the origin.
    export_frame = str(cfg.get("export_coordinate_frame", "level0")).lower()
    eff_mpp = meta.get("effective_mpp", None)
    source_mpp = meta.get("source_mpp", None)
    mpp_scale: Optional[float] = None
    if eff_mpp is not None and source_mpp is not None:
        try:
            mpp_scale = float(eff_mpp) / float(source_mpp)
        except (TypeError, ValueError, ZeroDivisionError):
            mpp_scale = None
    tissue_geoms = tissue_polys
    if export_frame == "level0" and mpp_scale is not None:
        try:
            from shapely.affinity import scale as _scale
            tissue_geoms = [_scale(p, xfact=mpp_scale, yfact=mpp_scale, origin=(0, 0)) for p in tissue_polys]
        except Exception:
            # Best effort: fall back to the unscaled geometries.
            tissue_geoms = tissue_polys

    # Ensure index alignment with embeddings/table rows for tile/token shapes.
    # Use numeric string indices 0..n-1 for regions 'tiles' and 'tokens';
    # keep prefixed indices for other regions (e.g., 'tissue_0', 'tissue_1', ...).
    if region_key in ("tiles", "tokens"):
        index = [str(i) for i in range(len(tissue_geoms))]
    else:
        index = [f"{region_name}_{i}" for i in range(len(tissue_geoms))]
    tgdf = gpd.GeoDataFrame(
        {
            "geometry": tissue_geoms,
            "region": [region_name] * len(tissue_geoms),
            "instance_id": np.arange(len(tissue_geoms), dtype=np.int32),
        },
        geometry="geometry",
        index=index,
    )
    # Add slide_id to shapes update gdf (explicit slide_id, else input file stem).
    try:
        _slide_id = str(meta.get("slide_id") or Path(str(meta.get("input_path", ""))).stem)
    except Exception:
        _slide_id = str(meta.get("slide_id", ""))
    tgdf["slide_id"] = _slide_id

    # Attach stable patch_id to tiles and patch_id/token_id to tokens for alignment
    if region_key == "tiles":
        tgdf["patch_id"] = np.arange(len(tissue_geoms), dtype=np.int32)
    elif region_key == "tokens":
        n_tiles = int(meta.get("n_patches", 0))
        if n_tiles <= 0:
            n_tiles = 1
        total_tokens = int(len(tissue_geoms))
        # Tokens are assumed to be laid out tile-major (all tokens of tile 0 first).
        tokens_per_tile = max(1, total_tokens // n_tiles)
        idx = np.arange(total_tokens, dtype=np.int32)
        patch_id = np.floor_divide(idx, tokens_per_tile)
        patch_id = np.clip(patch_id, 0, max(0, n_tiles - 1))
        token_id = np.mod(idx, tokens_per_tile)
        tgdf["patch_id"] = patch_id
        tgdf["token_id"] = token_id
    shapes_layer = ShapesModel.parse(tgdf)

    # Read existing dataset if possible
    existing = None
    try:
        if hasattr(sdata, "read_zarr"):
            existing = sdata.read_zarr(zarr_path)  # type: ignore
        elif hasattr(sdata, "io") and hasattr(sdata.io, "read_zarr"):
            existing = sdata.io.read_zarr(zarr_path)  # type: ignore
    except Exception:
        existing = None

    if existing is None:
        # Nothing readable at zarr_path: create a new dataset with only this layer.
        # NOTE(review): a path that exists but fails to read is also replaced here.
        sd = SpatialData(shapes={region_name: shapes_layer})
        if not hasattr(sd, "write"):
            raise RuntimeError("SpatialData.write(...) not available; please upgrade spatialdata to >=0.5.0.")
        # Pre-clear any existing non-Zarr path to honor overwrite semantics
        _remove_path(Path(zarr_path))
        sd.write(zarr_path, overwrite=True)
        return

    # Update shapes in existing dataset
    shapes = dict(getattr(existing, "shapes", {}) or {})
    shapes[region_name] = shapes_layer
    sd = SpatialData(
        shapes=shapes if shapes else None,
        images=getattr(existing, "images", None) or None,
        labels=getattr(existing, "labels", None) or None,
        points=getattr(existing, "points", None) or None,
        tables=getattr(existing, "tables", None) or None,
    )
    if not hasattr(sd, "write"):
        raise RuntimeError("SpatialData.write(...) not available; please upgrade spatialdata to >=0.5.0.")

    # Write to a sibling temp store while the original is still intact, then swap.
    zp = Path(zarr_path)
    tmp_path = zp.with_name(f"{zp.stem}.__tmp__{zp.suffix}")
    bak_path = zp.with_name(f"{zp.stem}.__bak__{zp.suffix}")
    _remove_path(tmp_path)
    _remove_path(bak_path)
    sd.write(str(tmp_path), overwrite=True)

    had_original = zp.exists()
    if had_original:
        zp.rename(bak_path)
    try:
        tmp_path.rename(zp)
    except OSError:
        # Roll back so the caller never ends up with no store at zarr_path.
        if had_original and not zp.exists():
            bak_path.rename(zp)
        raise
    _remove_path(bak_path)


def write_spatialdata(sd_obj, out_path: str, overwrite: bool = True) -> None:
    """Write a SpatialData object to ``out_path`` using the recent API only.

    Steps:
      1. When ``overwrite`` is true, remove whatever file or directory already
         exists at ``out_path`` (best effort; failures are ignored).
      2. Cast any non-numeric ``phenotype`` column to a string dtype, in place,
         in every shapes GeoDataFrame and in every table's ``obs``.
      3. Call ``sd_obj.write(out_path, overwrite=overwrite)`` (spatialdata>=0.5.0).

    Args:
        sd_obj: SpatialData-like object exposing ``write`` and, optionally,
            ``shapes``/``tables`` mappings. spatialdata's containers are
            ``UserDict``-based, so they are matched as ``Mapping``, not ``dict``.
        out_path: Destination Zarr path.
        overwrite: Whether to replace an existing store at ``out_path``.

    Raises:
        RuntimeError: If spatialdata is unavailable or ``sd_obj`` has no ``write``.
    """
    if not HAS_SPATIALDATA:
        raise RuntimeError("SpatialData not available inside write_spatialdata.")

    from collections.abc import Mapping
    from shutil import rmtree

    outp = Path(out_path)
    if overwrite:
        # NOTE(review): if sd_obj has elements lazily backed by out_path (e.g. it was
        # read from this same store), deleting it here discards their source data
        # before write() materializes them; write to a different path in that case.
        try:
            if outp.is_dir():
                rmtree(outp, ignore_errors=True)
            elif outp.exists():
                outp.unlink()
        except OSError:
            # Non-fatal; writer should still attempt to overwrite
            pass

    try:
        import pandas as pd
    except ImportError:
        pd = None

    def _coerce_phenotype(frame: Any) -> None:
        """Cast a non-numeric ``phenotype`` column of ``frame`` to strings in place (best effort)."""
        try:
            if frame is None or "phenotype" not in frame.columns:
                return
            col = frame["phenotype"]
            if pd.api.types.is_numeric_dtype(col.dtype):
                return
            values = pd.Series(col.values, index=frame.index)
            try:
                frame["phenotype"] = values.astype(pd.StringDtype())
            except Exception:
                frame["phenotype"] = values.astype(str)
        except Exception:
            # Normalization is opportunistic; the writer surfaces real problems.
            pass

    # Defensive normalization of phenotype columns. Frames are mutated in place,
    # so nothing is re-assigned into sd_obj's element containers.
    if pd is not None:
        shapes = getattr(sd_obj, "shapes", None)
        if isinstance(shapes, Mapping):
            for gdf in list(shapes.values()):
                _coerce_phenotype(gdf)
        tables = getattr(sd_obj, "tables", None)
        if isinstance(tables, Mapping):
            for tbl in list(tables.values()):
                _coerce_phenotype(getattr(tbl, "obs", None))

    if not hasattr(sd_obj, "write"):
        raise RuntimeError("SpatialData.write(...) not available on this object; please upgrade spatialdata to >=0.5.0.")
    sd_obj.write(out_path, overwrite=overwrite)