# token_cell_alignment.py
from __future__ import annotations
import argparse
import math
import numpy as np
import pandas as pd
import geopandas as gpd
from shapely.strtree import STRtree
from shapely.geometry.base import BaseGeometry
from typing import Optional, Tuple
from tqdm import tqdm

def _repair_invalid(geoms: pd.Series) -> pd.Series:
    """
    Fix invalid geometries with a zero-width buffer, touching only the
    entries reported as invalid so valid shapes are left untouched.

    The series is modified in place and the same object is returned.
    """
    needs_fix = ~geoms.is_valid
    if not needs_fix.any():
        # Nothing to repair: hand the input back unchanged.
        return geoms
    repaired = geoms.loc[needs_fix].buffer(0)
    geoms.loc[needs_fix] = repaired
    return geoms

def _legacy_candidate_pairs(tree, batch_geoms, cell_geoms, id_map, wkb_cache):
    """Collect (token, cell) candidate pairs one token at a time (Shapely 1.x).

    Returns two int arrays of equal length: positions into ``batch_geoms`` and
    positions into ``cell_geoms``. ``wkb_cache`` is a dict shared across calls;
    it is filled lazily the first time an identity lookup misses.
    """
    q_pos = []
    c_pos = []
    for local_pos, tok in enumerate(batch_geoms):
        if tok is None:
            continue
        try:
            candidates = tree.query(tok)
        except Exception:
            # Deliberate best-effort: a token the tree cannot query is treated
            # as having no candidates (it ends up as background).
            continue
        if candidates is None or len(candidates) == 0:
            continue

        if isinstance(candidates, np.ndarray) and candidates.dtype.kind in "iu":
            # Index-returning trees: positions can be used directly.
            q_pos.extend([local_pos] * len(candidates))
            c_pos.extend(int(c) for c in candidates)
            continue

        # Geometry-returning trees: translate geometries back to positions,
        # first by object identity, then by WKB content.
        for cand in candidates:
            pos = id_map.get(id(cand))
            if pos is None:
                if not wkb_cache:
                    wkb_cache.update(
                        {g.wkb: i for i, g in enumerate(cell_geoms) if g is not None}
                    )
                pos = wkb_cache.get(cand.wkb)
            if pos is None:
                continue
            q_pos.append(local_pos)
            c_pos.append(pos)
    return np.asarray(q_pos, dtype=int), np.asarray(c_pos, dtype=int)


def _pick_dominant_cells(
    token_idx: np.ndarray,
    cell_pos: np.ndarray,
    frac: np.ndarray,
    cell_labels: np.ndarray,
    *,
    min_overlap: float,
    dominance_gap: float,
    no_match_label: float,
    ambiguous_label: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Reduce candidate pairs to one label per token.

    ``token_idx``, ``cell_pos`` and ``frac`` are parallel arrays describing
    candidate (token, cell) pairs and the fraction of the token covered by the
    cell. Returns ``(tokens, labels)`` where ``tokens`` holds each distinct
    token index once and ``labels`` the decision for it.
    """
    # Sort by token, then by descending overlap, so the first row of every
    # token run is its best candidate and the next row (if any) the runner-up.
    order = np.lexsort((-frac, token_idx))
    tok = token_idx[order]
    frc = frac[order]
    cel = cell_pos[order]

    is_first = np.ones(tok.shape[0], dtype=bool)
    is_first[1:] = tok[1:] != tok[:-1]
    first = np.flatnonzero(is_first)
    top_frac = frc[first]

    # Runner-up exists only when the following row belongs to the same token.
    nxt = first + 1
    in_range = nxt < tok.shape[0]
    has_second = np.zeros(first.shape[0], dtype=bool)
    has_second[in_range] = ~is_first[nxt[in_range]]
    second_frac = np.zeros(first.shape[0], dtype=float)
    second_frac[has_second] = frc[nxt[has_second]]
    second_frac = np.where(np.isnan(second_frac), 0.0, second_frac)

    labels = np.full(first.shape[0], ambiguous_label, dtype=float)
    below_threshold = top_frac < min_overlap
    labels[below_threshold] = no_match_label
    winner = ~below_threshold & ((top_frac - second_frac) >= dominance_gap)
    labels[winner] = cell_labels[cel[first][winner]]
    return tok[first], labels


def token_to_cell_mapper(
    cells_gdf: gpd.GeoDataFrame,
    tokens_gdf: gpd.GeoDataFrame,
    *,
    min_overlap: float = 0.05,          # ≥5% token-area to be eligible
    dominance_gap: float = 0.10,        # top − second ≥10% absolute to assign
    no_match_label: int = 0,            # background
    ambiguous_label: float = np.nan,    # ambiguous
    batch_size: int = 100_000,          # tune based on RAM
    show_progress: bool = True,
) -> pd.Series:
    """
    Compute token→cell mapping by polygonal area overlap.

    For every token the overlap with each candidate cell is expressed as a
    fraction of the token's own area. The token is then labelled as:

    * ``no_match_label`` when it has zero area, no candidate cell, or its best
      overlap is below ``min_overlap``;
    * the best cell when its overlap exceeds the runner-up's by at least
      ``dominance_gap``;
    * ``ambiguous_label`` otherwise.

    Parameters
    ----------
    cells_gdf, tokens_gdf
        GeoDataFrames with a ``geometry`` column. If both carry a CRS the two
        must be equal.
    min_overlap
        Minimum fraction of the token area the best cell must cover.
    dominance_gap
        Minimum absolute difference between best and second-best fraction.
    no_match_label
        Value written for background tokens.
    ambiguous_label
        Value written for tokens without a clear winner.
    batch_size
        Number of tokens processed per iteration (must be >= 1).
    show_progress
        Show a tqdm progress bar over batches.

    Returns
    -------
    pd.Series
        Float series named ``cell_index``, indexed by ``tokens_gdf.index``.
        Assigned tokens hold the value of ``cells_gdf.index`` for the matched
        cell when that index is numeric; for a non-numeric index the cell's
        row position is stored instead, because the result is a float array.

    Raises
    ------
    ValueError
        On missing geometry column, non-shapely geometries, CRS mismatch, or
        ``batch_size < 1``.
    """
    import shapely  # already a dependency of this module (shapely.strtree)

    if batch_size < 1:
        raise ValueError("batch_size must be >= 1.")

    # Basic checks (explicit raises: asserts vanish under ``python -O``)
    for label, gdf in (("cells_gdf", cells_gdf), ("tokens_gdf", tokens_gdf)):
        if 'geometry' not in gdf:
            raise ValueError(f"{label} must be a GeoDataFrame with geometry.")
        # Only inspect the first geometry when there is one; empty frames are
        # handled by the early return below.
        if len(gdf) > 0 and not isinstance(gdf.geometry.iloc[0], BaseGeometry):
            raise ValueError(f"{label}.geometry must contain shapely geometries.")

    # Ensure consistent CRS (SpatialData uses registered coordinate systems; here we assume both are in same one)
    if (cells_gdf.crs is not None) and (tokens_gdf.crs is not None):
        if cells_gdf.crs != tokens_gdf.crs:
            raise ValueError(f"CRS mismatch: cells={cells_gdf.crs} vs tokens={tokens_gdf.crs}")

    # Everything starts as background; only tokens with candidate cells are
    # overwritten below. Zero-area and candidate-less tokens therefore stay
    # background without further work.
    out = np.full(len(tokens_gdf), no_match_label, dtype=float)
    if len(tokens_gdf) == 0 or len(cells_gdf) == 0:
        return pd.Series(out, index=tokens_gdf.index, name="cell_index")

    # Repair invalid geometries on copies so the caller's frames are untouched
    cells = _repair_invalid(cells_gdf.geometry.copy())
    tokens = _repair_invalid(tokens_gdf.geometry.copy())

    token_area = np.asarray(tokens.area, dtype=float)
    todo_indices = np.flatnonzero(~(token_area <= 0))

    # Value reported for each cell position: the index label when it fits in a
    # float array, otherwise the position itself.
    cell_index = cells_gdf.index
    if pd.api.types.is_numeric_dtype(cell_index.dtype) and not pd.api.types.is_bool_dtype(cell_index.dtype):
        cell_labels = cell_index.to_numpy(dtype=float)
    else:
        cell_labels = np.arange(len(cells_gdf), dtype=float)

    # Materialise geometries once so the tree, the identity map and the
    # intersection step all see the very same objects.
    cell_geoms = np.asarray(cells.values, dtype=object)
    token_geoms = np.asarray(tokens.values, dtype=object)

    # Shapely >= 2 exposes vectorised top-level functions and an STRtree.query
    # that accepts arrays and returns a (2, n) index array.
    vectorized = hasattr(shapely, "intersection") and hasattr(shapely, "area")
    tree = STRtree(cell_geoms if vectorized else list(cell_geoms))
    id_map = None if vectorized else {id(g): i for i, g in enumerate(cell_geoms)}
    wkb_cache: dict = {}

    starts = range(0, todo_indices.size, batch_size)
    for start in tqdm(starts, desc="Mapping tokens→cells", disable=not show_progress):
        batch_local = todo_indices[start:start + batch_size]   # positions in tokens_gdf
        batch_geoms = token_geoms[batch_local]

        # Bounding-box candidate lookup
        if vectorized:
            query_ix, tree_ix = tree.query(batch_geoms)
        else:
            query_ix, tree_ix = _legacy_candidate_pairs(
                tree, batch_geoms, cell_geoms, id_map, wkb_cache
            )
        if query_ix.size == 0:
            continue  # whole batch stays background

        # Exact intersection area for every candidate pair
        tok_candidates = batch_geoms[query_ix]
        cell_candidates = cell_geoms[tree_ix]
        if vectorized:
            inter_areas = shapely.area(shapely.intersection(tok_candidates, cell_candidates))
        else:
            inter_areas = np.array(
                [
                    t.intersection(c).area if (t is not None and c is not None) else 0.0
                    for t, c in zip(tok_candidates, cell_candidates)
                ],
                dtype=float,
            )

        abs_ix = batch_local[query_ix]  # absolute token positions
        frac = inter_areas / token_area[abs_ix]

        decided_tokens, labels = _pick_dominant_cells(
            abs_ix,
            tree_ix,
            frac,
            cell_labels,
            min_overlap=min_overlap,
            dominance_gap=dominance_gap,
            no_match_label=no_match_label,
            ambiguous_label=ambiguous_label,
        )
        out[decided_tokens] = labels

    return pd.Series(out, index=tokens_gdf.index, name="cell_index")


# ---------------------------
# SpatialData integration API
# ---------------------------

def attach_token_cell_alignment(
    sdata,
    *,
    token_shapes_key: str = "tokens",
    cell_shapes_key: str = "segmentation_mask",
    out_column: str = "cell_index",
    min_overlap: float = 0.05,
    dominance_gap: float = 0.10,
    batch_size: int = 100_000,
    show_progress: bool = True,
) -> pd.Series:
    """
    Compute the token→cell mapping and attach it as a new column on the
    tokens Shapes element of a SpatialData object.

    Parameters
    ----------
    sdata
        Object exposing a ``shapes`` mapping that contains both keys below.
    token_shapes_key, cell_shapes_key
        Keys in ``sdata.shapes`` for the token and cell GeoDataFrames.
    out_column
        Name of the column written on the tokens GeoDataFrame.
    min_overlap, dominance_gap, batch_size, show_progress
        Forwarded unchanged to ``token_to_cell_mapper``.

    Returns
    -------
    pd.Series
        The mapping, indexed like the tokens GeoDataFrame.

    Notes on persistence:
    - In-memory: after calling this, sdata holds the column (on a copy of the
      tokens GeoDataFrame that replaces the original entry).
    - To persist to Zarr, use ``save_spatialdata_to_zarr(sdata, path)``.
    """
    tokens_gdf: gpd.GeoDataFrame = sdata.shapes[token_shapes_key]
    cells_gdf: gpd.GeoDataFrame = sdata.shapes[cell_shapes_key]

    mapping = token_to_cell_mapper(
        cells_gdf=cells_gdf,
        tokens_gdf=tokens_gdf,
        min_overlap=min_overlap,
        dominance_gap=dominance_gap,
        batch_size=batch_size,
        show_progress=show_progress,
    )

    updated_tokens = tokens_gdf.copy()
    # The mapping is built row-by-row over tokens_gdf and carries its index,
    # so it is already aligned. Assign positionally: reindexing by label would
    # raise if the token index contains duplicate labels.
    updated_tokens[out_column] = mapping.to_numpy()
    sdata.shapes[token_shapes_key] = updated_tokens
    return mapping


def save_spatialdata_to_zarr(sdata_obj, path: str, overwrite: bool = True) -> None:
    """Write a SpatialData object to ``path``.

    A project-level ``write_spatialdata`` helper is used when it can be
    imported (package-relative first, then ``histotuner.spatial_writer``).
    Otherwise the object's own ``write(path, overwrite=...)`` method is
    called. A ``RuntimeError`` is raised when neither is available.
    """

    def _locate_writer():
        # Any failure while importing a candidate simply moves on to the next.
        try:
            from .spatial_writer import write_spatialdata  # type: ignore
            return write_spatialdata
        except Exception:
            pass
        try:
            from histotuner.spatial_writer import write_spatialdata  # type: ignore
            return write_spatialdata
        except Exception:
            return None

    writer = _locate_writer()
    if writer is None:
        if not hasattr(sdata_obj, "write"):
            raise RuntimeError("SpatialData.write(...) not available; please upgrade spatialdata to >=0.5.0.")
        sdata_obj.write(path, overwrite=overwrite)
        return

    writer(sdata_obj, path, overwrite=overwrite)


def attach_token_cell_alignment_to_zarr(
    zarr_path: str,
    *,
    token_shapes_key: str = "tokens",
    cell_shapes_key: str = "segmentation_mask",
    out_column: str = "cell_index",
    min_overlap: float = 0.05,
    dominance_gap: float = 0.10,
    batch_size: int = 100_000,
    show_progress: bool = True,
):
    """Notebook-friendly wrapper: load SpatialData Zarr, attach token→cell mapping, and persist.

    Parameters
    - zarr_path: Path to SpatialData Zarr root
    - token_shapes_key: Shapes key for tokens (default 'tokens')
    - cell_shapes_key: Shapes key for cells (default 'segmentation_mask')
    - out_column: Column name to write on tokens (default 'cell_index')
    - min_overlap: Minimum fractional overlap to be eligible (default 0.05)
    - dominance_gap: Required gap between top and second overlap (default 0.10)
    - batch_size: Token batch size for processing (default 100000)
    - show_progress: Show tqdm progress (default True)

    Returns
    - pd.Series mapping (aligned to tokens index) and persists the updated
      SpatialData back to ``zarr_path`` with overwrite.

    Raises
    - RuntimeError: spatialdata is not importable, exposes no ``read_zarr``,
      or reading ``zarr_path`` fails (the underlying error is chained).
    - KeyError: one of the shapes keys is missing.
    """
    try:
        import spatialdata as sdata_mod  # type: ignore
    except Exception as e:
        raise RuntimeError("spatialdata is required to read/write SpatialData Zarr.") from e

    # Locate the reader: top-level function first, then the ``io`` submodule.
    reader = getattr(sdata_mod, "read_zarr", None)
    if reader is None:
        reader = getattr(getattr(sdata_mod, "io", None), "read_zarr", None)
    if reader is None:
        raise RuntimeError("No SpatialData reader found in installed spatialdata.")

    try:
        sd = reader(zarr_path)
    except Exception as e:
        # Keep the original failure attached so the real cause is visible.
        raise RuntimeError(f"Failed to read SpatialData from: {zarr_path}") from e
    if sd is None:
        raise RuntimeError(f"Failed to read SpatialData from: {zarr_path}")

    # Validate shapes keys exist
    if token_shapes_key not in sd.shapes:
        raise KeyError(f"Tokens shapes key '{token_shapes_key}' not found in SpatialData.")
    if cell_shapes_key not in sd.shapes:
        raise KeyError(f"Cells shapes key '{cell_shapes_key}' not found in SpatialData.")

    # Attach mapping in-memory
    mapping = attach_token_cell_alignment(
        sd,
        token_shapes_key=token_shapes_key,
        cell_shapes_key=cell_shapes_key,
        out_column=out_column,
        min_overlap=min_overlap,
        dominance_gap=dominance_gap,
        batch_size=batch_size,
        show_progress=show_progress,
    )

    # Persist back to the same Zarr path
    try:
        save_spatialdata_to_zarr(sd, zarr_path, overwrite=True)
    except Exception as e:
        # Final fallback: try the object's own writer; if that fails as well,
        # report the first error, which describes the preferred path.
        try:
            sd.write(zarr_path, overwrite=True)
        except Exception:
            raise e
    return mapping


"""
Notebook/CLI-friendly unified API (camelCase)

tokenCellMapper operates on:
- A SpatialData object in-memory (sdata is SpatialData)
- A path to a SpatialData Zarr on disk (sdata is str)

Persistence semantics:
- If outPath is None → returns modified SpatialData (in-memory).
- If outPath is provided → writes to outPath with overwrite, returns the outPath.
"""

def tokenCellMapper(
    sdata: Optional[object] = None,
    *,
    tokenShapesKey: str = "tokens",
    cellShapesKey: str = "segmentation_mask",
    outColumn: str = "cell_index",
    minOverlap: float = 0.05,
    dominanceGap: float = 0.10,
    noMatchLabel: int = 0,
    ambiguousLabel: float = np.nan,
    batchSize: int = 100_000,
    verbose: bool = False,
    outPath: Optional[str] = None,
):
    """Attach a token→cell mapping column to a SpatialData object or Zarr store.

    Parameters
    - sdata: a SpatialData object, or a path (``str`` or ``os.PathLike``) to a
      SpatialData Zarr root. Required.
    - tokenShapesKey / cellShapesKey: keys in ``.shapes`` for tokens and cells.
    - outColumn: column name written on the tokens GeoDataFrame.
    - minOverlap, dominanceGap, noMatchLabel, ambiguousLabel, batchSize:
      forwarded to ``token_to_cell_mapper``.
    - verbose: show a progress bar.
    - outPath: when given, the result is written there (overwrite).

    Returns
    - the modified SpatialData object when ``outPath`` is None, otherwise
      ``outPath``.

    Raises
    - ValueError: ``sdata`` is None.
    - RuntimeError: a path was given but spatialdata cannot be imported or
      offers no ``read_zarr``.
    - KeyError: one of the shapes keys is missing.
    """
    import os

    # SpatialData object or path required
    if sdata is None:
        raise ValueError("Provide sdata as a SpatialData object or a Zarr path string.")

    # Resolve SpatialData object. Path-like objects (e.g. pathlib.Path) are
    # treated exactly like string paths.
    if isinstance(sdata, (str, os.PathLike)):
        try:
            import spatialdata as sdata_mod  # type: ignore
        except Exception as e:
            raise RuntimeError("spatialdata is required to read/write SpatialData Zarr.") from e
        source = os.fspath(sdata)
        # robust read across versions
        if hasattr(sdata_mod, "read_zarr"):
            sd_obj = sdata_mod.read_zarr(source)  # type: ignore
        elif hasattr(sdata_mod, "io") and hasattr(sdata_mod.io, "read_zarr"):
            sd_obj = sdata_mod.io.read_zarr(source)  # type: ignore
        else:
            raise RuntimeError("No SpatialData reader found in installed spatialdata.")
    else:
        # Assume already a SpatialData object
        sd_obj = sdata

    # Validate shapes keys exist
    if tokenShapesKey not in sd_obj.shapes:
        raise KeyError(f"Tokens shapes key '{tokenShapesKey}' not found in SpatialData.")
    if cellShapesKey not in sd_obj.shapes:
        raise KeyError(f"Cells shapes key '{cellShapesKey}' not found in SpatialData.")

    tokens_gdf: gpd.GeoDataFrame = sd_obj.shapes[tokenShapesKey]
    cells_gdf: gpd.GeoDataFrame = sd_obj.shapes[cellShapesKey]

    # Compute mapping
    mapping = token_to_cell_mapper(
        cells_gdf=cells_gdf,
        tokens_gdf=tokens_gdf,
        min_overlap=minOverlap,
        dominance_gap=dominanceGap,
        no_match_label=noMatchLabel,
        ambiguous_label=ambiguousLabel,
        batch_size=batchSize,
        show_progress=verbose,
    )

    # Attach mapping to SpatialData tokens shapes. The mapping is already
    # row-aligned with tokens_gdf, so assign positionally; reindexing by label
    # would raise on duplicate index labels.
    updated_tokens = tokens_gdf.copy()
    updated_tokens[outColumn] = mapping.to_numpy()
    sd_obj.shapes[tokenShapesKey] = updated_tokens

    # Handle output
    if outPath is None:
        # Return modified SpatialData in memory
        return sd_obj
    # Persist to outPath (overwrite if exists)
    save_spatialdata_to_zarr(sd_obj, outPath, overwrite=True)
    return outPath


def _print_mapping_summary(mapping: pd.Series) -> None:
    """Print one line with counts of assigned, background (== 0) and ambiguous (NaN) tokens."""
    n_total = int(mapping.shape[0])
    n_ambiguous = int(mapping.isna().sum())
    # NaN is replaced by a non-zero value first so it is never counted as background.
    n_background = int((mapping.fillna(1) == 0).sum())
    n_assigned = n_total - n_background - n_ambiguous
    print(
        f"[info] Token→cell mapping summary: total={n_total} assigned={n_assigned} "
        f"background={n_background} ambiguous={n_ambiguous}"
    )


def main() -> None:
    """Command-line entry point: run the token→cell mapping on a SpatialData Zarr.

    Parses the options, delegates to ``tokenCellMapper`` and reports the
    outcome: the output path when one was written, otherwise a count summary
    of the in-memory mapping.
    """
    parser = argparse.ArgumentParser(description="Run token→cell mapping on a SpatialData Zarr.")
    # Options are registered in this order so the generated help stays stable.
    option_table = (
        ("--sdataPath", dict(required=True, help="Path to input SpatialData Zarr root")),
        ("--outPath", dict(default=None, help="Optional output Zarr path; if provided, overwrites there")),
        ("--tokenShapesKey", dict(default="tokens", help="Shapes key for tokens (default 'tokens')")),
        ("--cellShapesKey", dict(default="segmentation_mask", help="Shapes key for cells (default 'segmentation_mask')")),
        ("--outColumn", dict(default="cell_index", help="Column name written on tokens (default 'cell_index')")),
        ("--minOverlap", dict(type=float, default=0.05, help="Minimum fractional overlap to be eligible (default 0.05)")),
        ("--dominanceGap", dict(type=float, default=0.10, help="Required gap between top and second overlap (default 0.10)")),
        ("--batchSize", dict(type=int, default=100_000, help="Token batch size (default 100000)")),
        ("--verbose", dict(dest="verbose", action="store_true", help="Enable progress output")),
        ("--quiet", dict(dest="verbose", action="store_false", help="Disable progress output")),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    parser.set_defaults(verbose=False)
    cli = parser.parse_args()

    token_key = str(cli.tokenShapesKey)
    column = str(cli.outColumn)
    destination = cli.outPath if cli.outPath is None else str(cli.outPath)

    outcome = tokenCellMapper(
        sdata=str(cli.sdataPath),
        tokenShapesKey=token_key,
        cellShapesKey=str(cli.cellShapesKey),
        outColumn=column,
        minOverlap=float(cli.minOverlap),
        dominanceGap=float(cli.dominanceGap),
        batchSize=int(cli.batchSize),
        verbose=bool(cli.verbose),
        outPath=destination,
    )

    # A string outcome is the path that was written; anything else is the
    # in-memory SpatialData object, for which the summary is printed.
    if not isinstance(outcome, str):
        _print_mapping_summary(outcome.shapes[token_key][column])
        return
    print(f"[info] Mapping attached and written to: {outcome}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
