"""
cell_phenotype_mapper.py — Map per-cell phenotype from a SpatialData table
onto token shapes based on the 'cell_index' column.

Objective
- Add a column on `sdata.shapes['tokens']` where each token receives the
  value from `sdata.tables[table_key].obs[obs_column]` corresponding to the
  token's `cell_index`.
- Also optionally add a column with the matched `obs` index value for
  transparency and debugging.

Assumptions
- `tokens` shapes contain a `cell_index` column, as written by the
  token→cell mapper. It typically stores positional indices (0..n_cells-1),
  NaN for ambiguous, and possibly a sentinel (e.g., 0) for background.
- The table stored under `table_key` is an AnnData (or AnnData-like object)
  with a pandas `obs` DataFrame.

Usage (Notebook)
>>> from histotuner.cell_phenotype_mapper import cellPhenotypeMapper
>>> cellPhenotypeMapper(
...     sdata="path/to/spatial.zarr",
...     tableKey="image_mIF_cells",
...     obsColumn="phenotype",
...     tokenShapesKey="tokens",
...     outColumn="cell_phenotype",
...     backgroundLabel=None,
...     outPath=None,
... )

CLI
python -m histotuner.cell_phenotype_mapper --sdataPath path/to.zarr \
  --tableKey image_mIF_cells --obsColumn phenotype --outColumn cell_phenotype
"""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd

# Reuse the project-standard SpatialData writer from token_cell_mapper when available
try:
    from .token_cell_mapper import save_spatialdata_to_zarr as _save_sd  # type: ignore
except Exception:
    try:
        from histotuner.token_cell_mapper import save_spatialdata_to_zarr as _save_sd  # type: ignore
    except Exception:
        _save_sd = None

# Prefer centralized writer from spatial_writer when available
_write_sd = None
try:
    from .spatial_writer import write_spatialdata as _write_sd  # type: ignore
except Exception:
    try:
        from histotuner.spatial_writer import write_spatialdata as _write_sd  # type: ignore
    except Exception:
        _write_sd = None


def _warn(msg: str) -> None:
    print(f"[warn] {msg}")


def _info(msg: str) -> None:
    print(f"[info] {msg}")


def _resolve_table_as_anndata(tbl) -> "object":
    """Return AnnData or AnnData-like object exposing `.obs`.

    Handles TableModel wrappers by unwrapping common attributes.
    """
    try:
        import anndata as ad  # type: ignore
        if isinstance(tbl, ad.AnnData):
            return tbl
    except Exception:
        pass
    # Unwrap possible TableModel variants
    for attr in ("table", "adata", "AnnData", "data"):
        try:
            inner = getattr(tbl, attr)
            if inner is not None:
                try:
                    import anndata as ad  # type: ignore
                    if isinstance(inner, ad.AnnData):
                        return inner
                except Exception:
                    pass
                if hasattr(inner, "obs"):
                    return inner
        except Exception:
            pass
    # AnnData-like fallback
    if hasattr(tbl, "obs"):
        return tbl
    raise TypeError("Provided table is not AnnData, nor a TableModel/AnnData-like exposing 'obs'.")


def _read_spatialdata(zarr_path: str):
    """Read SpatialData from a path, robust across spatialdata versions.

    Attempts, in order:
    - spatialdata.read_zarr
    - spatialdata.read (generic reader in some versions)
    - spatialdata.io.read_zarr
    - spatialdata_io.read_zarr (external package)
    - SpatialData.read_zarr / SpatialData.read (class-level variants)
    """
    from pathlib import Path
    p = Path(zarr_path)
    if not p.exists():
        raise FileNotFoundError(f"SpatialData path does not exist: {zarr_path}")

    last_err = None
    sd = None
    try:
        import spatialdata as sdata_mod  # type: ignore
    except Exception as e:
        last_err = e
        sdata_mod = None  # type: ignore

    # 1) Module-level read_zarr
    try:
        if sdata_mod is not None and hasattr(sdata_mod, "read_zarr"):
            sd = sdata_mod.read_zarr(zarr_path)  # type: ignore
    except Exception as e:
        last_err = e
        sd = None

    # 2) Module-level generic read
    if sd is None:
        try:
            if sdata_mod is not None and hasattr(sdata_mod, "read"):
                sd = sdata_mod.read(zarr_path)  # type: ignore
        except Exception as e:
            last_err = e
            sd = None

    # 3) io.read_zarr
    if sd is None:
        try:
            if sdata_mod is not None and hasattr(sdata_mod, "io") and hasattr(sdata_mod.io, "read_zarr"):
                sd = sdata_mod.io.read_zarr(zarr_path)  # type: ignore
        except Exception as e:
            last_err = e
            sd = None

    # 4) spatialdata_io.read_zarr (external package)
    if sd is None:
        try:
            import spatialdata_io as sdio  # type: ignore
            if hasattr(sdio, "read_zarr"):
                sd = sdio.read_zarr(zarr_path)  # type: ignore
        except Exception as e:
            last_err = e
            sd = None

    # 5) Class-level fallbacks
    if sd is None:
        try:
            if sdata_mod is not None:
                SpatialData = getattr(sdata_mod, "SpatialData", None)
            else:
                SpatialData = None
            if SpatialData is not None:
                if hasattr(SpatialData, "read_zarr"):
                    try:
                        sd = SpatialData.read_zarr(zarr_path)
                    except TypeError:
                        # Some variants use `read(path)`
                        sd = SpatialData.read(zarr_path) if hasattr(SpatialData, "read") else None
                elif hasattr(SpatialData, "read"):
                    sd = SpatialData.read(zarr_path)
        except Exception as e:
            last_err = e
            sd = None

    if sd is None:
        if last_err is not None:
            raise RuntimeError(
                f"Failed to read SpatialData from: {zarr_path}. Last error: {last_err}"
            )
        raise RuntimeError(f"Failed to read SpatialData from: {zarr_path}")

    return sd


def _save_spatialdata(sd_obj, path: str, overwrite: bool = True) -> None:
    """Write SpatialData using the recent API only."""
    # Prefer centralized writer when available
    _write_sd = None
    try:
        from .spatial_writer import write_spatialdata as _write_sd  # type: ignore
    except Exception:
        try:
            from histotuner.spatial_writer import write_spatialdata as _write_sd  # type: ignore
        except Exception:
            _write_sd = None
    if _write_sd is not None:
        _write_sd(sd_obj, path, overwrite=overwrite)
        return

    if hasattr(sd_obj, "write"):
        sd_obj.write(path, overwrite=overwrite)
        return
    raise RuntimeError("SpatialData.write(...) not available; please upgrade spatialdata to >=0.5.0.")


def attach_cell_phenotype(
    sdata,
    *,
    tableKey: str,
    obsColumn: str,
    tokenShapesKey: str = "tokens",
    cellIndexColumn: str = "cell_index",
    outColumn: Optional[str] = None,
    writeObsIndexColumn: bool = False,
    obsIndexOutColumn: str = "cell_obs_index",
    backgroundLabel: Optional[int] = 0,
):
    """Attach per-cell phenotype from a table's `obs[obsColumn]` onto token shapes.

    Each token's `cellIndexColumn` value is interpreted as a row position in
    `obs`. Tokens whose value is missing, out of range, or equal to
    `backgroundLabel` receive no phenotype; background tokens receive the
    label itself instead. `sdata.shapes[tokenShapesKey]` is replaced by a
    copy carrying the new column(s).

    Parameters
    - sdata: SpatialData object
    - tableKey: name of the table in `sdata.tables` (e.g., 'image_mIF_cells')
    - obsColumn: column name from `adata.obs` to map (e.g., 'phenotype');
      matched exactly first, then case-insensitively after stripping whitespace
    - tokenShapesKey: shapes key for tokens (default 'tokens')
    - cellIndexColumn: column in tokens shapes that stores positional indices (default 'cell_index')
    - outColumn: target column name written on tokens; defaults to the resolved obs column
    - writeObsIndexColumn: also write the matched `obs` row position (as float,
      NaN when unmatched) on tokens (default False)
    - obsIndexOutColumn: column name for the matched `obs` row positions
    - backgroundLabel: sentinel in `cell_index` that is treated as 'no match'
      and carried over into the output column (default 0). Pass None to
      disable. NOTE(review): with the default 0, tokens pointing at obs row 0
      are treated as background — confirm this matches the token→cell mapper.

    Returns
    - pd.Series of mapped phenotype values aligned to tokens index. Numeric
      obs columns yield floats; other columns yield an object Series of
      strings with missing entries left as NaN/NA.

    Raises
    - KeyError: missing shapes key, table key, cell index column or obs column.
    """
    # Validate shapes and table
    if tokenShapesKey not in sdata.shapes:
        raise KeyError(f"Tokens shapes key '{tokenShapesKey}' not found in SpatialData.")
    if tableKey not in sdata.tables:
        raise KeyError(f"Table key '{tableKey}' not found in SpatialData tables.")

    tokens_gdf = sdata.shapes[tokenShapesKey]
    adata = _resolve_table_as_anndata(sdata.tables[tableKey])

    if cellIndexColumn not in tokens_gdf.columns:
        raise KeyError(f"Tokens missing '{cellIndexColumn}' column; run token→cell mapper first.")

    # Resolve obs column name with normalization fallback
    obs_cols_list = [str(c) for c in adata.obs.columns]
    if str(obsColumn) in obs_cols_list:
        obs_col_use = str(obsColumn)
    else:
        norm = str(obsColumn).strip().lower()
        col_map = {str(c).strip().lower(): str(c) for c in adata.obs.columns}
        obs_col_use = col_map.get(norm, None)
        if obs_col_use is None:
            raise KeyError(
                f"Table '{tableKey}' missing obs column '{obsColumn}'. Available: {obs_cols_list}"
            )

    obs_df = adata.obs
    obs_values = obs_df[obs_col_use]
    n_tokens = tokens_gdf.shape[0]

    # Token→cell positional indices; missing entries are never mapped.
    cell_idx = tokens_gdf[cellIndexColumn].to_numpy()
    valid_mask = ~pd.isna(cell_idx)
    valid_positions = np.flatnonzero(valid_mask)
    pos_idx = cell_idx[valid_mask].astype(np.int64, copy=False)

    # Background tokens are carried over rather than looked up in obs.
    if backgroundLabel is None:
        bg_local = np.zeros(pos_idx.shape, dtype=bool)
    else:
        bg_local = pos_idx == int(backgroundLabel)

    # Drop out-of-range indices and background from the lookup.
    within = (pos_idx >= 0) & (pos_idx < obs_df.shape[0]) & (~bg_local)
    use_positions = valid_positions[within]
    use_idx = pos_idx[within]
    bg_positions = valid_positions[bg_local]

    # Output buffer aligned to the tokens index, pre-filled with NaN.
    out_name = str(outColumn or obs_col_use)
    is_numeric = pd.api.types.is_numeric_dtype(obs_values.dtype)
    if is_numeric:
        result_values = np.full(n_tokens, np.nan, dtype=float)
    else:
        result_values = np.empty(n_tokens, dtype=object)
        result_values[:] = np.nan

    selected = obs_values.iloc[use_idx]
    try:
        vals = selected.to_numpy()
    except AttributeError:
        # Very old pandas without Series.to_numpy
        vals = selected.values

    if is_numeric:
        try:
            vals = pd.to_numeric(vals, errors="coerce").astype(float)
        except Exception:
            pass
    else:
        # Ensure non-numeric phenotype values are consistently strings
        text = pd.Series(vals)
        try:
            vals = text.astype(pd.StringDtype()).values
        except Exception:
            # Stringify only real values so missing entries stay missing
            # instead of becoming the literal text "nan".
            vals = text.where(text.isna(), text.astype(str)).values
    result_values[use_positions] = vals

    # Carry over background label (e.g., 0)
    if backgroundLabel is not None and bg_positions.size > 0:
        result_values[bg_positions] = (
            float(backgroundLabel) if is_numeric else str(backgroundLabel)
        )

    result = pd.Series(result_values, index=tokens_gdf.index, name=out_name)
    # Ensure missing are np.nan, not None
    try:
        result = result.replace({None: np.nan})
    except Exception:
        pass

    # Persist onto SpatialData shapes
    updated_tokens = tokens_gdf.copy()
    if is_numeric:
        updated_tokens[out_name] = result.values
    else:
        # Coerce final column to string dtype for non-numeric phenotype to satisfy Arrow
        column = pd.Series(result.values, index=updated_tokens.index)
        try:
            updated_tokens[out_name] = column.astype(pd.StringDtype())
        except Exception:
            # Same missing-preserving fallback as above.
            updated_tokens[out_name] = column.where(column.isna(), column.astype(str))

    if writeObsIndexColumn:
        # Store numeric positions to guarantee float dtype and np.nan for missing values
        idx_values = np.full(n_tokens, np.nan, dtype=float)
        idx_values[use_positions] = use_idx.astype(float)
        updated_tokens[str(obsIndexOutColumn)] = idx_values

    sdata.shapes[tokenShapesKey] = updated_tokens

    return result


def attach_cell_phenotype_to_zarr(
    zarr_path: str,
    *,
    tableKey: str,
    obsColumn: str,
    tokenShapesKey: str = "tokens",
    cellIndexColumn: str = "cell_index",
    outColumn: Optional[str] = None,
    writeObsIndexColumn: bool = False,
    obsIndexOutColumn: str = "cell_obs_index",
    backgroundLabel: Optional[int] = 0,
    outPath: Optional[str] = None,
    overwrite: bool = True,
):
    """Notebook-friendly wrapper: load Zarr, attach phenotype mapping, persist.

    Parameters
    - zarr_path: path of the SpatialData store to read.
    - tableKey, obsColumn, tokenShapesKey, cellIndexColumn, outColumn,
      writeObsIndexColumn, obsIndexOutColumn, backgroundLabel: forwarded
      unchanged to `attach_cell_phenotype`.
    - outPath: where to write the result; falls back to `zarr_path` when
      None or empty.
    - overwrite: passed to every writer that is tried (default True).

    Writers are tried in order: the centralized `spatial_writer` writer, the
    `token_cell_mapper` writer, then `_save_spatialdata`. A failing writer is
    reported via `_warn` before the next one is tried; an error from the last
    one propagates.

    Returns the mapped phenotype Series.
    """
    sd = _read_spatialdata(zarr_path)
    series = attach_cell_phenotype(
        sd,
        tableKey=tableKey,
        obsColumn=obsColumn,
        tokenShapesKey=tokenShapesKey,
        cellIndexColumn=cellIndexColumn,
        outColumn=outColumn,
        writeObsIndexColumn=writeObsIndexColumn,
        obsIndexOutColumn=obsIndexOutColumn,
        backgroundLabel=backgroundLabel,
    )

    # Persist
    target = outPath or zarr_path
    for writer in (_write_sd, _save_sd):
        if writer is None:
            continue
        try:
            # Honor the caller's `overwrite` instead of forcing True.
            writer(sd, target, overwrite=overwrite)
        except Exception as exc:
            writer_name = getattr(writer, "__name__", repr(writer))
            _warn(f"Writer '{writer_name}' failed for '{target}': {exc}; trying next writer.")
        else:
            return series

    _save_spatialdata(sd, target, overwrite=overwrite)
    return series


def phenotypeCellMapper(
    sdata: Optional[object] = None,
    *,
    tokenShapesKey: str = "tokens",
    tableKey: str,
    obsColumn: str,
    outColumn: Optional[str] = None,
    cellIndexColumn: str = "cell_index",
    noMatchLabel: int = 0,
    writeObsIndexColumn: bool = False,
    obsIndexOutColumn: str = "cell_obs_index",
    verbose: bool = False,
    outPath: Optional[str] = None,
    overwrite: bool = True,
):
    """Notebook/CLI-friendly unified API (camelCase) matching tokenCellMapper style.

    Parameters
    - sdata: SpatialData object, or a path (`str` or `pathlib.Path`) from
      which SpatialData is read.
    - tokenShapesKey, tableKey, obsColumn, outColumn, cellIndexColumn,
      writeObsIndexColumn, obsIndexOutColumn: forwarded to
      `attach_cell_phenotype`.
    - noMatchLabel: forwarded as `backgroundLabel` (default 0).
    - verbose: when True, progress messages are printed via `_info`.
    - outPath: output path; when None nothing is written.
    - overwrite: passed to every writer that is tried (default True).

    Returns
    - The modified SpatialData object when `outPath` is None, otherwise the
      output path as a string.

    Raises
    - ValueError: `sdata` is None.
    - KeyError: `tokenShapesKey` or `tableKey` is missing.
    """
    if sdata is None:
        raise ValueError("Provide sdata as a SpatialData object or a Zarr path string.")

    if isinstance(sdata, (str, Path)):
        source = str(sdata)
        if verbose:
            _info(f"Reading SpatialData from: {source}")
        sd_obj = _read_spatialdata(source)
    else:
        sd_obj = sdata

    # Validate keys
    if tokenShapesKey not in sd_obj.shapes:
        raise KeyError(f"Tokens shapes key '{tokenShapesKey}' not found in SpatialData.")
    if tableKey not in sd_obj.tables:
        raise KeyError(f"Table key '{tableKey}' not found in SpatialData tables.")

    mapped = attach_cell_phenotype(
        sd_obj,
        tableKey=tableKey,
        obsColumn=obsColumn,
        tokenShapesKey=tokenShapesKey,
        cellIndexColumn=cellIndexColumn,
        outColumn=outColumn,
        writeObsIndexColumn=writeObsIndexColumn,
        obsIndexOutColumn=obsIndexOutColumn,
        backgroundLabel=noMatchLabel,
    )
    if verbose:
        total = int(mapped.shape[0])
        n_mapped = int(mapped.notna().sum())
        _info(f"Cell phenotype mapping complete: tokens={total} mapped={n_mapped} unmapped={total - n_mapped}")

    if outPath is None:
        return sd_obj

    target = str(outPath)
    # Preferred writers first; `_save_spatialdata` is the final fallback.
    for writer in (_write_sd, _save_sd):
        if writer is None:
            continue
        try:
            # Honor the caller's `overwrite` instead of forcing True.
            writer(sd_obj, target, overwrite=overwrite)
        except Exception as exc:
            writer_name = getattr(writer, "__name__", repr(writer))
            _warn(f"Writer '{writer_name}' failed for '{target}': {exc}; trying next writer.")
        else:
            return target

    _save_spatialdata(sd_obj, target, overwrite=overwrite)
    return target

# Backwards-compatible alias for earlier naming
def cellPhenotypeMapper(
    sdata: Optional[object] = None,
    *,
    tableKey: str,
    obsColumn: str,
    tokenShapesKey: str = "tokens",
    cellIndexColumn: str = "cell_index",
    outColumn: Optional[str] = None,
    writeObsIndexColumn: bool = True,
    obsIndexOutColumn: str = "cell_obs_index",
    backgroundLabel: Optional[int] = 0,
    outPath: Optional[str] = None,
    overwrite: bool = True,
):
    """Earlier-named entry point that delegates to `phenotypeCellMapper`.

    All arguments are forwarded under the same names, except that
    `backgroundLabel` is passed as `noMatchLabel`: None becomes 0 and any
    other value is converted with `int(...)`. `verbose` is not forwarded.
    Note that `writeObsIndexColumn` defaults to True here.

    Returns whatever `phenotypeCellMapper` returns.
    """
    if backgroundLabel is None:
        no_match_label = 0
    else:
        no_match_label = int(backgroundLabel)

    forwarded = {
        "sdata": sdata,
        "tokenShapesKey": tokenShapesKey,
        "tableKey": tableKey,
        "obsColumn": obsColumn,
        "outColumn": outColumn,
        "cellIndexColumn": cellIndexColumn,
        "noMatchLabel": no_match_label,
        "writeObsIndexColumn": writeObsIndexColumn,
        "obsIndexOutColumn": obsIndexOutColumn,
        "outPath": outPath,
        "overwrite": overwrite,
    }
    return phenotypeCellMapper(**forwarded)


def main(argv: Optional[list] = None) -> None:
    """CLI entry point: map a table's obs column onto tokens and persist.

    Parameters
    - argv: argument list to parse; None (the default) lets argparse read
      `sys.argv[1:]`.

    `--backgroundLabel` accepts an integer or the literal `None`
    (case-insensitive) to disable the background sentinel. A short mapping
    summary is printed via `_info` once the store has been written.
    """

    def _int_or_none(text: str) -> Optional[int]:
        # With `type=int`, argparse rejects "None", so the sentinel could
        # never be disabled from the command line.
        if text.strip().lower() == "none":
            return None
        try:
            return int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"expected an integer or 'None', got {text!r}"
            ) from None

    ap = argparse.ArgumentParser(description="Map per-cell phenotype from SpatialData table onto tokens.")
    ap.add_argument("--sdataPath", required=True, help="Path to input SpatialData Zarr root")
    ap.add_argument("--tableKey", required=True, help="Table key in SpatialData tables (e.g. 'image_mIF_cells')")
    ap.add_argument("--obsColumn", required=True, help="Column in table.obs to map (e.g. 'phenotype')")
    ap.add_argument("--tokenShapesKey", default="tokens", help="Shapes key for tokens (default 'tokens')")
    ap.add_argument("--cellIndexColumn", default="cell_index", help="Column in tokens that stores positional cell indices")
    ap.add_argument("--outColumn", default=None, help="Column name written on tokens; defaults to obsColumn")
    ap.add_argument("--writeObsIndexColumn", action="store_true", help="Also write matched obs index values on tokens")
    ap.add_argument("--noWriteObsIndexColumn", dest="writeObsIndexColumn", action="store_false")
    ap.set_defaults(writeObsIndexColumn=False)
    ap.add_argument("--obsIndexOutColumn", default="cell_obs_index", help="Column name for matched obs index values")
    ap.add_argument(
        "--backgroundLabel",
        type=_int_or_none,
        default=0,
        help="Sentinel to treat as background/no match (default 0); pass 'None' to disable",
    )
    ap.add_argument("--outPath", default=None, help="Optional output Zarr path; if provided, overwrites there")
    ap.add_argument("--overwrite", dest="overwrite", action="store_true", help="Overwrite output path if it exists")
    ap.add_argument("--noOverwrite", dest="overwrite", action="store_false")
    ap.set_defaults(overwrite=True)

    args = ap.parse_args(argv)
    series = attach_cell_phenotype_to_zarr(
        zarr_path=str(args.sdataPath),
        tableKey=str(args.tableKey),
        obsColumn=str(args.obsColumn),
        tokenShapesKey=str(args.tokenShapesKey),
        cellIndexColumn=str(args.cellIndexColumn),
        # The literal string "None" on the command line means "not provided".
        outColumn=(None if args.outColumn in (None, "None") else str(args.outColumn)),
        writeObsIndexColumn=bool(args.writeObsIndexColumn),
        obsIndexOutColumn=str(args.obsIndexOutColumn),
        backgroundLabel=args.backgroundLabel,
        outPath=(None if args.outPath in (None, "None") else str(args.outPath)),
        overwrite=bool(args.overwrite),
    )
    # Brief summary
    total = int(series.shape[0])
    mapped = int(np.sum(~pd.isna(series)))
    _info(f"Cell phenotype mapping complete: tokens={total} mapped={mapped} unmapped={total - mapped}")

# Run the CLI only when this file is executed as a script (or via
# `python -m`), not when it is imported as a module.
if __name__ == "__main__":
    main()