# hoptimus_finetune.py
# Fine-tunes a timm ViT backbone (e.g., H-optimus-1) for token-level classification on a 14x14 grid,
# using tile/token polygons from SpatialData shapes and reading tile pixels from an external WSI
# (or regular image), resized to the configured input size (default 224x224).
# Run via: python -c "import hoptimus_finetune as m; m.run_from_config('config.yaml')"

import os, json, math, warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Any

import numpy as np
import pandas as pd
import geopandas as gpd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.transforms import functional as TF
import torchvision.transforms as T

# pip install pyyaml spatialdata[io] zarr xarray geopandas shapely openslide-python timm safetensors
import yaml
from spatialdata import SpatialData
import timm
import openslide
from safetensors.torch import load_file as safetensors_load
# Allow running as a script while importing package modules
import sys
PKG_SRC = Path(__file__).resolve().parents[1]
if str(PKG_SRC) not in sys.path:
    sys.path.insert(0, str(PKG_SRC))
try:
    import histotuner.image_patcher as image_patcher
except Exception:
    image_patcher = None
# Optional: progress bar
HAS_TQDM = False
try:
    from tqdm.auto import tqdm
    HAS_TQDM = True
except Exception:
    HAS_TQDM = False

# AMP support
HAS_AMP = False
try:
    from torch import amp
    HAS_AMP = True
except Exception:
    HAS_AMP = False

from contextlib import nullcontext


# Default label value for tokens that must not contribute to the loss or metrics
# (it is passed to CrossEntropyLoss as ignore_index).
IGNORE_INDEX_DEFAULT = -100

# Compute device used for model and batches: the CUDA device when one is
# available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


# ========== Config ==========
@dataclass
class OptimConfig:
    """Optimization settings (YAML section ``optim``).

    Training has an optional stage 1 that trains only the head (run when
    ``model.freeze_stage1`` is true), followed by stage 2, which unfreezes the
    backbone and trains it with ``lr_backbone`` and the head with ``lr_head``.
    """
    batch_size: int = 48           # batch size of both the train and validation loaders
    epochs_head: int = 2           # stage-1 epochs; also the schedule length of the stage-1 schedulers
    epochs_full: int = 8           # stage-2 (full fine-tuning) length; used as T_max of the stage-2 cosine schedule
    lr_head: float = 1e-3          # head learning rate (stage 1 and stage 2)
    lr_backbone: float = 1e-5      # backbone learning rate in stage 2
    weight_decay: float = 0.05     # AdamW weight decay (both stages)
    val_fraction: float = 0.15     # fraction of tiles held out for validation (at least one tile)
    seed: int = 17                 # seeds torch and numpy, including the train/val shuffle
    grad_clip: float = 1.0         # max gradient norm passed to clip_grad_norm_
    num_workers: int = 0           # DataLoader workers (made persistent when > 0)
    stage1_sched: str = "none"  # options: none, onecycle (stepped per batch), cosine (stepped per epoch, eta_min = 0.1 * lr_head)
    stage1_max_lr: Optional[float] = None  # OneCycle peak LR; falls back to lr_head when unset
    early_stopping_patience: int = 0  # 0 disables early stopping in Stage 2

@dataclass
class DataConfig:
    """Data sources and preprocessing (YAML section ``data``).

    Provide either ``slides`` (a list of per-slide dicts) or the single-slide
    pair ``sdata_path`` + ``wsi_path``. Each per-slide dict needs
    ``sdata_path`` and ``wsi_path`` and may override ``slide_id`` (default: the
    stem of ``sdata_path``), ``tiles_key``, ``tokens_key``, ``label_col``,
    ``target_mpp`` and ``source_mpp``.
    """
    sdata_path: Optional[str] = None   # SpatialData store (single-slide shorthand)
    wsi_path: Optional[str] = None     # slide/image the tile pixels are read from (single-slide shorthand)
    tiles_key: str = "tiles"     # name of the shapes element holding tile polygons (needs 'patch_id')
    tokens_key: str = "tokens"   # name of the shapes element holding token rows ('patch_id', 'token_id', label_col)
    label_col: str = "clusters"  # column of the tokens table holding the class label
    min_fg_tokens: int = 0       # drop tiles with fewer than k tokens (in the 14x14 crop) that are neither ignore_index nor background_id
    background_id: int = 0       # training class id treated as background (excluded from the fg count and from macro-F1)
    ignore_index: int = IGNORE_INDEX_DEFAULT  # label of unlabelled/unknown tokens; ignored by the loss and metrics
    mean: List[float] = field(default_factory=lambda: [0.5,0.5,0.5])  # per-channel mean for TF.normalize
    std:  List[float] = field(default_factory=lambda: [0.5,0.5,0.5])  # per-channel std for TF.normalize
    augment: bool = True         # enable random geometric + brightness/contrast augmentation
    input_size: List[int] = field(default_factory=lambda: [224, 224])  # [H, W] every tile is resized to
    target_mpp: Optional[float] = None  # desired µm/pixel; WSI tiles are re-read around their center at this scale
    source_mpp: Optional[float] = None  # override for the slide's native µm/pixel (otherwise detected from the input)
    # Multi-image support: optional list of per-slide configs
    slides: Optional[List[Dict[str, Any]]] = None

@dataclass
class ModelConfig:
    """Backbone and head settings (YAML section ``model``)."""
    model_id: str = "vit_base_patch16_224"  # timm model name passed to timm.create_model
    timm_kwargs: Optional[Dict[str, Any]] = None  # extra keyword arguments for timm.create_model
    checkpoint_path: Optional[str] = None  # optional weights (.safetensors or torch.load-able), loaded non-strictly into the backbone
    freeze_stage1: bool = True  # if true, run stage 1 (head-only, backbone frozen) before full fine-tuning
    head_hidden_dim: Optional[int] = None  # hidden width of the MLP head; falls back to the backbone embed dim
    head_dropout: float = 0.1  # dropout probability between the head's two linear layers

@dataclass
class SaveConfig:
    """Output and experiment-logging settings (YAML section ``save``)."""
    output_path: str  # required; NOTE(review): consumed by code past this view — confirm what is written there
    wandb_enabled: bool = False  # log per-epoch metrics to Weights & Biases (falls back silently if init fails)
    wandb_project: Optional[str] = None  # W&B project; "histotuner" is used when unset
    wandb_run_name: Optional[str] = None  # W&B run name

@dataclass
class Config:
    """Top-level run configuration: one dataclass per YAML section (built by load_config)."""
    data: DataConfig
    model: ModelConfig
    optim: OptimConfig
    save: SaveConfig


# ========== Utils ==========
def load_config(path: str) -> Config:
    """Read a YAML config file into a :class:`Config`.

    The document must be a mapping with the sections ``data``, ``model``,
    ``optim`` and ``save``. A missing or empty (``null``) ``data``, ``model``
    or ``optim`` section falls back to that dataclass's defaults; ``save`` must
    at least provide ``output_path``.

    Args:
        path: Path to a UTF-8 encoded YAML file.

    Returns:
        The populated :class:`Config`.

    Raises:
        ValueError: if the document, or one of its sections, is not a mapping.
        TypeError: if a section contains unknown keys or ``save.output_path``
            is missing (raised by the dataclass constructors).
    """
    # Explicit UTF-8: config files may contain non-ASCII text (e.g. "µm"),
    # which the platform default encoding (cp1252 on Windows) cannot decode.
    with open(path, "r", encoding="utf-8") as f:
        doc = yaml.safe_load(f)
    if doc is None:  # empty file
        doc = {}
    if not isinstance(doc, dict):
        raise ValueError(f"Config file '{path}' must contain a YAML mapping; got {type(doc).__name__}.")

    def _section(name: str) -> Dict[str, Any]:
        """Return section ``name`` as a dict ({} when absent or null)."""
        value = doc.get(name)
        if value is None:
            return {}
        if not isinstance(value, dict):
            raise ValueError(f"Config section '{name}' must be a mapping; got {type(value).__name__}.")
        return value

    return Config(
        data=DataConfig(**_section("data")),
        model=ModelConfig(**_section("model")),
        optim=OptimConfig(**_section("optim")),
        save=SaveConfig(**_section("save")),
    )

def infer_classes(series: pd.Series):
    """Derive the class vocabulary of a label column.

    Missing values are dropped first. Numeric columns are converted with
    ``int`` and the distinct values are mapped, in ascending order, to
    training ids ``0..N-1`` (a column already holding ``0..N-1`` therefore maps
    to itself). Any other dtype is converted to ``str`` and the distinct
    strings are mapped to ``0..N-1`` in sorted (lexicographic) order.

    Returns:
        ``(n_classes, id2train, train2id, raw_labels)`` where ``raw_labels`` is
        the sorted list of distinct raw labels. An empty column yields
        ``(0, {}, {}, [])``.
    """
    present = series.dropna()
    if present.empty:
        return 0, {}, {}, []

    if not pd.api.types.is_numeric_dtype(present.dtype):
        names = sorted({str(x) for x in present.astype(str).unique().tolist()})
        forward = dict(zip(names, range(len(names))))
        backward = {idx: name for name, idx in forward.items()}
        return len(names), forward, backward, names

    try:
        raw_ids = sorted({int(x) for x in present.unique()})
    except Exception:
        coerced = pd.to_numeric(present, errors="coerce").dropna().astype(int)
        raw_ids = sorted(set(coerced.unique().tolist()))
    # For an already-contiguous 0..N-1 vocabulary the position equals the value,
    # so a single position-based mapping covers both cases.
    forward = {raw: pos for pos, raw in enumerate(raw_ids)}
    backward = {pos: raw for raw, pos in forward.items()}
    return len(raw_ids), forward, backward, raw_ids

def token_id_to_rowcol(token_id: int, grid_size: int = 16):
    """Convert a row-major flat token index into ``(row, col)`` on a ``grid_size``-wide grid."""
    row, col = divmod(token_id, grid_size)
    return int(row), int(col)

def center_crop_16_to_14(grid16: np.ndarray):
    """Return the inner 14x14 view of a 16x16 grid (drops a one-cell border).

    Raises:
        ValueError: if ``grid16`` is not exactly 16x16.
    """
    expected = (16, 16)
    if grid16.shape == expected:
        return grid16[1:-1, 1:-1]
    raise ValueError(f"Expected (16,16) grid; got {grid16.shape}")

def set_requires_grad(m: nn.Module, flag: bool):
    """Enable (``flag=True``) or disable gradient tracking for every parameter of ``m``."""
    for param in m.parameters():
        param.requires_grad_(flag)


# ========== Backbone wrappers ==========
class TimmpatchTokens(nn.Module):
    """Wrap a ViT so that ``forward`` returns only its patch tokens ``[B, N, D]``.

    Three ``forward_features`` output formats are handled:

    * a dict with ``"x_norm_patchtokens"`` — returned unchanged;
    * a dict with ``"x_norm"`` — the prefix tokens are stripped;
    * a ``[B, T, D]`` tensor — the prefix tokens are stripped.

    Prefix tokens (class token plus any register/distillation tokens) are
    counted with the wrapped model's ``num_prefix_tokens`` attribute when it
    exists (timm ViTs expose it; e.g. models with 4 register tokens have 5),
    otherwise a single class token is assumed. Stripping only one token from a
    register-token model would leave the registers mixed in with the patch
    tokens and shift the spatial grid.
    """

    # The downstream head works on a 14x14 token grid.
    MIN_PATCH_TOKENS = 196

    def __init__(self, vit_model: nn.Module):
        super().__init__()
        self.vit = vit_model

    def _num_prefix_tokens(self) -> int:
        """Number of non-patch tokens leading the sequence (defaults to 1)."""
        n_prefix = getattr(self.vit, "num_prefix_tokens", None)
        return 1 if n_prefix is None else int(n_prefix)

    def forward(self, x):
        y = self.vit.forward_features(x)
        if isinstance(y, dict) and "x_norm_patchtokens" in y:
            return y["x_norm_patchtokens"]               # [B,N,D]
        n_prefix = self._num_prefix_tokens()
        if isinstance(y, dict) and "x_norm" in y:
            return y["x_norm"][:, n_prefix:, :]          # [B,N,D]
        if torch.is_tensor(y) and y.ndim == 3 and y.size(1) - n_prefix >= self.MIN_PATCH_TOKENS:
            return y[:, n_prefix:, :]                    # [B,N,D]
        raise RuntimeError("Backbone must expose patch tokens [B,1+196,D].")

def load_timm_backbone(model_name: str, timm_kwargs: Optional[Dict[str, Any]] = None):
    """Instantiate a timm model and wrap it to emit patch tokens.

    Args:
        model_name: Name passed to ``timm.create_model``.
        timm_kwargs: Optional extra keyword arguments for ``timm.create_model``.

    Returns:
        ``(token_wrapper, embed_dim, raw_model)`` — the :class:`TimmpatchTokens`
        wrapper, the model's ``num_features``, and the unwrapped timm model
        (useful for loading checkpoints with the model's own key names).
    """
    raw_model = timm.create_model(model_name, **(timm_kwargs or {}))
    token_wrapper = TimmpatchTokens(raw_model)
    return token_wrapper, raw_model.num_features, raw_model



# ========== Dataset ==========
class TileTokenDataset(Dataset):
    """Single-slide dataset of ``(image, token_labels, patch_id)`` samples.

    Each sample is one tile polygon of ``tiles_gdf`` (indexed by ``patch_id``),
    read from ``wsi_path`` and resized to ``input_size``, paired with a
    flattened 14x14 grid of training-class ids (``[196]`` long tensor). The
    grid is built from the ``tokens_gdf`` rows of that tile (``token_id`` in a
    row-major 16x16 layout, label in ``label_col``) and center-cropped from
    16x16 to 14x14. Tokens whose label is missing or not present in
    ``id2train_map`` stay at ``ignore_index``. Tiles with fewer than
    ``min_fg_tokens`` tokens that are neither ``ignore_index`` nor
    ``background_id`` are dropped at construction time.

    With ``augment=True`` every sample gets a random horizontal flip, vertical
    flip and rotation by a multiple of 90 degrees, applied identically to the
    image and its label grid (so token labels stay aligned with the pixels),
    followed by +/-10% brightness and contrast jitter.

    Raises:
        ValueError: if no tile survives the ``min_fg_tokens`` filter.
    """

    def __init__(
        self,
        tiles_gdf: gpd.GeoDataFrame,
        tokens_gdf: gpd.GeoDataFrame,
        wsi_path: str,
        label_col: str,
        id2train_map: Dict[Any, int],
        mean: List[float],
        std: List[float],
        ignore_index: int,
        background_id: int,
        min_fg_tokens: int,
        augment: bool,
        input_size: List[int],
        target_mpp: Optional[float],
        source_mpp: Optional[float],
    ):
        self.tiles = tiles_gdf.copy().set_index("patch_id")
        self.tokens = tokens_gdf.copy()
        self.label_col = label_col
        self.id2train = id2train_map
        self.mean, self.std = mean, std
        self.ignore_index = ignore_index
        self.background_id = background_id
        self.min_fg_tokens = min_fg_tokens
        self.augment = augment
        self.input_size = input_size
        self.target_mpp = target_mpp
        self.source_mpp = source_mpp

        # Geometric augmentation must move the label grid together with the
        # image, which an image-only torchvision pipeline cannot do; it is
        # applied in _joint_geometric_aug. Attribute kept for compatibility.
        self.aug_transform = None

        # Store path and lazy-open per worker to avoid pickling issues on Windows
        self.wsi_path = wsi_path
        self._slide = None
        self._source_mpp_for_run: Optional[float] = None

        self.grid_map: Dict[Any, np.ndarray] = {}
        keep = []
        for tid, grp in self.tokens.groupby("patch_id", sort=False):
            G14 = self._build_grid(grp)
            fg = int(np.sum((G14 != self.ignore_index) & (G14 != self.background_id)))
            if fg >= self.min_fg_tokens:
                self.grid_map[tid] = G14
                keep.append(tid)

        if not keep:
            raise ValueError("No tiles left after min_fg_tokens filtering; lower data.min_fg_tokens.")

        self.tile_ids = keep

    def _label_to_train_id(self, lbl) -> Optional[int]:
        """Map a raw label to its training id, or ``None`` if missing/unknown.

        Integer-like labels are looked up as ``int`` first (numeric label
        columns); otherwise, or when that key is absent, ``str(lbl)`` is used
        (string/categorical columns, including digit strings such as ``"3"``
        whose vocabulary is keyed by strings).
        """
        if pd.isna(lbl):
            return None
        try:
            key = int(lbl)
        except (TypeError, ValueError, OverflowError):
            key = None
        if key is not None and key in self.id2train:
            return self.id2train[key]
        return self.id2train.get(str(lbl))

    def _build_grid(self, grp) -> np.ndarray:
        """Return the 14x14 training-id grid for one tile's token rows."""
        G = np.full((16, 16), self.ignore_index, dtype=np.int16)
        # Column-wise iteration instead of iterrows (which builds a Series per row).
        token_ids = grp["token_id"].to_numpy()
        labels = grp[self.label_col].to_numpy()
        for tokid, lbl in zip(token_ids, labels):
            tr = self._label_to_train_id(lbl)
            if tr is None:
                continue
            rr, cc = token_id_to_rowcol(int(tokid), 16)
            G[rr, cc] = tr
        return center_crop_16_to_14(G)

    @staticmethod
    def _joint_geometric_aug(img: torch.Tensor, y2d: torch.Tensor):
        """Randomly flip / rotate (by multiples of 90 degrees) image and labels together.

        ``img`` is a ``[C, H, W]`` tensor and ``y2d`` the label grid whose rows and
        columns follow the image rows and columns (the loss already relies on
        this alignment). 90/270-degree rotations are only drawn when both are
        square, so output shapes never change.
        """
        if torch.rand(1).item() < 0.5:  # horizontal flip: mirror columns
            img, y2d = torch.flip(img, dims=[2]), torch.flip(y2d, dims=[1])
        if torch.rand(1).item() < 0.5:  # vertical flip: mirror rows
            img, y2d = torch.flip(img, dims=[1]), torch.flip(y2d, dims=[0])
        square = img.shape[1] == img.shape[2] and y2d.shape[0] == y2d.shape[1]
        if square:
            k = int(torch.randint(0, 4, (1,)).item())
        else:
            k = 2 * int(torch.randint(0, 2, (1,)).item())
        if k:
            # Same rotation direction for both: from the row axis towards the column axis.
            img = torch.rot90(img, k, dims=(1, 2))
            y2d = torch.rot90(y2d, k, dims=(0, 1))
        return img, y2d

    def _get_slide(self):
        if self._slide is None:
            # Unified reader: supports OpenSlide WSIs and regular images
            try:
                from histotuner.utils import open_he_input as _open_he_input
                inp = _open_he_input(self.wsi_path)
            except Exception as e:
                raise RuntimeError(f"Failed to open input '{self.wsi_path}': {e}") from e
            self._slide = inp
            # Config override wins; else use detected average MPP (WSI only)
            src_mpp = self.source_mpp if self.source_mpp is not None else inp.source_mpp
            self._source_mpp_for_run = src_mpp
            if self._slide.kind == "wsi":
                if self.target_mpp is not None and self._source_mpp_for_run is not None:
                    scale = float(self._source_mpp_for_run) / float(self.target_mpp)
                    print(f"Rescaling WSI → source_mpp={self._source_mpp_for_run:.3f}µm, target_mpp={self.target_mpp:.3f}µm, scale={scale:.3f}")
                elif self.target_mpp is not None and self._source_mpp_for_run is None:
                    warnings.warn("Target MPP provided but source MPP missing; defaulting to no rescale.")
        return self._slide

    def __len__(self):
        return len(self.tile_ids)

    def _read_tile_resized(self, tile_poly):
        """Read the pixels of ``tile_poly`` and return a ``[C, H, W]`` tensor of size ``input_size``."""
        minx, miny, maxx, maxy = tile_poly.bounds
        inp = self._get_slide()
        H, W = int(self.input_size[0]), int(self.input_size[1])
        from histotuner.utils import read_he_region as _read_region
        if inp.kind == "wsi" and self.target_mpp is not None and self._source_mpp_for_run is not None:
            # Rescale read size to target_mpp for WSIs, centered on the tile
            scale = float(self._source_mpp_for_run) / float(self.target_mpp)
            read_w = max(1, int(round(W * scale)))
            read_h = max(1, int(round(H * scale)))
            cx = 0.5 * (minx + maxx)
            cy = 0.5 * (miny + maxy)
            x0 = int(round(cx - read_w * 0.5))
            y0 = int(round(cy - read_h * 0.5))
            region = _read_region(inp, x0, y0, read_w, read_h)
        else:
            # Use polygon bounds; for images, this crops the PIL image
            x0, y0 = int(minx), int(miny)
            w, h = int(maxx - minx), int(maxy - miny)
            region = _read_region(inp, x0, y0, w, h)
        region_resized = T.Resize((H, W))(region)
        return TF.to_tensor(region_resized)

    def __getitem__(self, i):
        tid = self.tile_ids[i]
        tile_poly = self.tiles.loc[tid, "geometry"]
        img = self._read_tile_resized(tile_poly)
        y2d = torch.from_numpy(self.grid_map[tid]).long()  # [14,14], copied by .long()
        if self.augment:
            img, y2d = self._joint_geometric_aug(img, y2d)
            img = TF.adjust_brightness(img, 1.0 + float(torch.empty(1).uniform_(-0.1, 0.1)))
            img = TF.adjust_contrast(img,  1.0 + float(torch.empty(1).uniform_(-0.1, 0.1)))
        img = TF.normalize(img, self.mean, self.std)
        y = y2d.reshape(-1)  # [196]
        return img, y, tid

def collate(batch):
    """Stack ``(image, labels, tile_id)`` samples into batched tensors plus a list of ids."""
    images, targets, ids = zip(*batch)
    return torch.stack(images, dim=0), torch.stack(targets, dim=0), [*ids]

# New: multi-slide dataset that reads tiles from multiple WSIs
class MultiSlideTileTokenDataset(Dataset):
    """Multi-slide dataset of ``(image, token_labels, global_patch_id)`` samples.

    ``tiles_gdf`` must carry ``global_patch_id``, ``slide_id`` and ``geometry``;
    ``tokens_gdf`` must carry ``global_patch_id``, ``token_id`` (row-major
    16x16 layout) and ``label_col``. Pixels of each tile are read from
    ``wsi_map[slide_id]`` (opened lazily, once per slide) and resized to
    ``input_size``; per-slide ``target_mpp_map`` / ``source_mpp_map`` entries
    control WSI rescaling. Labels are returned as a flattened 14x14 grid of
    training-class ids (``[196]`` long tensor) center-cropped from 16x16.
    Tokens whose label is missing or not present in ``id2train_map`` stay at
    ``ignore_index``; tiles with fewer than ``min_fg_tokens`` tokens that are
    neither ``ignore_index`` nor ``background_id`` are dropped.

    With ``augment=True`` every sample gets a random horizontal flip, vertical
    flip and rotation by a multiple of 90 degrees, applied identically to the
    image and its label grid (so token labels stay aligned with the pixels),
    followed by +/-10% brightness and contrast jitter.

    Raises:
        ValueError: if no tile survives the ``min_fg_tokens`` filter.
    """

    def __init__(
        self,
        tiles_gdf: gpd.GeoDataFrame,
        tokens_gdf: gpd.GeoDataFrame,
        wsi_map: Dict[str, str],
        label_col: str,
        id2train_map: Dict[Any, int],
        mean: List[float],
        std: List[float],
        ignore_index: int,
        background_id: int,
        min_fg_tokens: int,
        augment: bool,
        input_size: List[int],
        target_mpp_map: Optional[Dict[str, float]] = None,
        source_mpp_map: Optional[Dict[str, float]] = None,
    ):
        # Expect columns: 'patch_id', 'global_patch_id', 'slide_id', 'geometry' in tiles
        # and 'patch_id', 'global_patch_id', 'slide_id', 'token_id', label_col in tokens
        self.tiles = tiles_gdf.copy().set_index("global_patch_id")
        self.tokens = tokens_gdf.copy()
        self.label_col = label_col
        self.id2train = id2train_map
        self.mean, self.std = mean, std
        self.ignore_index = ignore_index
        self.background_id = background_id
        self.min_fg_tokens = min_fg_tokens
        self.augment = augment
        self.input_size = input_size
        self.target_mpp_map = target_mpp_map or {}
        self.source_mpp_map = source_mpp_map or {}

        # Geometric augmentation must move the label grid together with the
        # image, which an image-only torchvision pipeline cannot do; it is
        # applied in _joint_geometric_aug. Attribute kept for compatibility.
        self.aug_transform = None

        # Per-slide readers and cached MPPs (opened lazily in _get_slide)
        self.wsi_map = dict(wsi_map)
        self._slides: Dict[str, Any] = {}
        self._source_mpp_for_run: Dict[str, Optional[float]] = {}

        # Build token grids per global_patch_id
        self.grid_map: Dict[Any, np.ndarray] = {}
        keep = []
        for gid, grp in self.tokens.groupby("global_patch_id", sort=False):
            G14 = self._build_grid(grp)
            fg = int(np.sum((G14 != self.ignore_index) & (G14 != self.background_id)))
            if fg >= self.min_fg_tokens:
                self.grid_map[gid] = G14
                keep.append(gid)

        if not keep:
            raise ValueError("No tiles left after min_fg_tokens filtering in multi-slide dataset.")

        self.tile_ids = keep

    def _label_to_train_id(self, lbl) -> Optional[int]:
        """Map a raw label to its training id, or ``None`` if missing/unknown.

        Integer-like labels are looked up as ``int`` first (numeric label
        columns); otherwise, or when that key is absent, ``str(lbl)`` is used
        (string/categorical columns, including digit strings such as ``"3"``
        whose vocabulary is keyed by strings).
        """
        if pd.isna(lbl):
            return None
        try:
            key = int(lbl)
        except (TypeError, ValueError, OverflowError):
            key = None
        if key is not None and key in self.id2train:
            return self.id2train[key]
        return self.id2train.get(str(lbl))

    def _build_grid(self, grp) -> np.ndarray:
        """Return the 14x14 training-id grid for one tile's token rows."""
        G = np.full((16, 16), self.ignore_index, dtype=np.int16)
        # Column-wise iteration instead of iterrows (which builds a Series per row).
        token_ids = grp["token_id"].to_numpy()
        labels = grp[self.label_col].to_numpy()
        for tokid, lbl in zip(token_ids, labels):
            tr = self._label_to_train_id(lbl)
            if tr is None:
                continue
            rr, cc = token_id_to_rowcol(int(tokid), 16)
            G[rr, cc] = tr
        return center_crop_16_to_14(G)

    @staticmethod
    def _joint_geometric_aug(img: torch.Tensor, y2d: torch.Tensor):
        """Randomly flip / rotate (by multiples of 90 degrees) image and labels together.

        ``img`` is a ``[C, H, W]`` tensor and ``y2d`` the label grid whose rows and
        columns follow the image rows and columns (the loss already relies on
        this alignment). 90/270-degree rotations are only drawn when both are
        square, so output shapes never change.
        """
        if torch.rand(1).item() < 0.5:  # horizontal flip: mirror columns
            img, y2d = torch.flip(img, dims=[2]), torch.flip(y2d, dims=[1])
        if torch.rand(1).item() < 0.5:  # vertical flip: mirror rows
            img, y2d = torch.flip(img, dims=[1]), torch.flip(y2d, dims=[0])
        square = img.shape[1] == img.shape[2] and y2d.shape[0] == y2d.shape[1]
        if square:
            k = int(torch.randint(0, 4, (1,)).item())
        else:
            k = 2 * int(torch.randint(0, 2, (1,)).item())
        if k:
            # Same rotation direction for both: from the row axis towards the column axis.
            img = torch.rot90(img, k, dims=(1, 2))
            y2d = torch.rot90(y2d, k, dims=(0, 1))
        return img, y2d

    def _get_slide(self, slide_id: str):
        if slide_id not in self._slides:
            wsi_path = self.wsi_map[slide_id]
            # Unified reader: supports OpenSlide WSIs and regular images
            try:
                from histotuner.utils import open_he_input as _open_he_input
                inp = _open_he_input(wsi_path)
            except Exception as e:
                raise RuntimeError(f"[{slide_id}] Failed to open input '{wsi_path}': {e}") from e
            self._slides[slide_id] = inp
            # Config override wins; else use detected average MPP (WSI only)
            src_mpp = self.source_mpp_map.get(slide_id, None)
            self._source_mpp_for_run[slide_id] = src_mpp if src_mpp is not None else inp.source_mpp
            tgt_mpp = self.target_mpp_map.get(slide_id, None)
            if inp.kind == "wsi":
                if tgt_mpp is not None and self._source_mpp_for_run[slide_id] is not None:
                    scale = float(self._source_mpp_for_run[slide_id]) / float(tgt_mpp)
                    print(f"[{slide_id}] Rescaling WSI → source_mpp={self._source_mpp_for_run[slide_id]:.3f}µm, target_mpp={tgt_mpp:.3f}µm, scale={scale:.3f}")
                elif tgt_mpp is not None and self._source_mpp_for_run[slide_id] is None:
                    warnings.warn(f"[{slide_id}] Target MPP provided but source MPP missing; defaulting to no rescale.")
        return self._slides[slide_id]

    def __len__(self):
        return len(self.tile_ids)

    def _read_tile_resized(self, tile_poly, slide_id: str):
        """Read the pixels of ``tile_poly`` from ``slide_id`` as a ``[C, H, W]`` tensor of size ``input_size``."""
        minx, miny, maxx, maxy = tile_poly.bounds
        inp = self._get_slide(slide_id)
        H, W = int(self.input_size[0]), int(self.input_size[1])
        from histotuner.utils import read_he_region as _read_region
        tgt_mpp = self.target_mpp_map.get(slide_id, None)
        src_mpp = self._source_mpp_for_run.get(slide_id, None)
        if inp.kind == "wsi" and tgt_mpp is not None and src_mpp is not None:
            # Rescale read size to target_mpp for WSIs, centered on the tile
            scale = float(src_mpp) / float(tgt_mpp)
            read_w = max(1, int(round(W * scale)))
            read_h = max(1, int(round(H * scale)))
            cx = 0.5 * (minx + maxx)
            cy = 0.5 * (miny + maxy)
            x0 = int(round(cx - read_w * 0.5))
            y0 = int(round(cy - read_h * 0.5))
            region = _read_region(inp, x0, y0, read_w, read_h)
        else:
            x0, y0 = int(minx), int(miny)
            w, h = int(maxx - minx), int(maxy - miny)
            region = _read_region(inp, x0, y0, w, h)
        region_resized = T.Resize((H, W))(region)
        return TF.to_tensor(region_resized)

    def __getitem__(self, i):
        gid = self.tile_ids[i]
        tile_row = self.tiles.loc[gid]
        slide_id = str(tile_row["slide_id"])  # tiles must carry slide_id
        tile_poly = tile_row["geometry"]
        img = self._read_tile_resized(tile_poly, slide_id)
        y2d = torch.from_numpy(self.grid_map[gid]).long()  # [14,14], copied by .long()
        if self.augment:
            img, y2d = self._joint_geometric_aug(img, y2d)
            img = TF.adjust_brightness(img, 1.0 + float(torch.empty(1).uniform_(-0.1, 0.1)))
            img = TF.adjust_contrast(img,  1.0 + float(torch.empty(1).uniform_(-0.1, 0.1)))
        img = TF.normalize(img, self.mean, self.std)
        y = y2d.reshape(-1)  # [196]
        return img, y, gid


# ========== Model head ==========
def center_crop_tokens(tokens: torch.Tensor, target_size: int = 14) -> torch.Tensor:
    """Center-crop a flattened square token grid to ``target_size x target_size``.

    Args:
        tokens: ``[B, T, D]`` tokens laid out row-major on a square grid.
        target_size: Side length of the output grid.

    Returns:
        ``[B, target_size * target_size, D]``. When ``T`` is not a perfect
        square, only the *last* ``isqrt(T) ** 2`` tokens form the grid — any
        surplus tokens in ViT outputs are leading prefix tokens (class /
        register), so trimming from the front keeps the real patch tokens.
        When the grid side is smaller than ``target_size`` the input is
        returned unchanged.
    """
    B, T, D = tokens.shape
    side = math.isqrt(T)
    if side < target_size:
        return tokens  # too few tokens to crop; return as-is
    n_grid = side * side
    # reshape (not view): the input may be a non-viewable slice of a larger tensor.
    grid = tokens[:, T - n_grid:, :].reshape(B, side, side, D)
    start = (side - target_size) // 2
    stop = start + target_size
    return grid[:, start:stop, start:stop, :].reshape(B, target_size * target_size, D)

class TokenClassifier(nn.Module):
    """Per-token classifier: backbone tokens -> 14x14 center crop -> MLP head.

    The head is ``Linear(embed_dim, hidden) -> ReLU -> Dropout -> Linear(hidden,
    n_classes)`` where ``hidden`` is ``head_hidden_dim`` or, when that is
    falsy, ``embed_dim``. ``forward`` returns per-token logits
    ``[B, 196, n_classes]`` when the backbone yields at least a 14x14 grid.
    """

    def __init__(self, backbone: nn.Module, embed_dim: int, n_classes: int,
                 head_hidden_dim: Optional[int] = None, head_dropout: float = 0.1):
        super().__init__()
        self.backbone = backbone
        hidden = head_hidden_dim or embed_dim
        # Layer order fixes the state_dict keys (head.0.*, head.3.*).
        layers = [
            nn.Linear(embed_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p=head_dropout),
            nn.Linear(hidden, n_classes),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        patch_tokens = self.backbone(x)                                   # [B,*,D]
        grid_tokens = center_crop_tokens(patch_tokens, target_size=14)    # [B,196,D]
        return self.head(grid_tokens)                                     # [B,196,C]


# ========== Training ==========
def _iter_label_arrays(loader):
    """Yield label arrays for every sample/batch of ``loader``.

    If ``loader.dataset`` (optionally wrapped in a single ``Subset``) exposes a
    ``grid_map`` dict and ``tile_ids`` sequence, the per-tile label grids are
    read from there directly, which avoids reading every image just to count
    labels. Otherwise the loader is iterated and the label element of each
    ``(x, y, ids)`` batch is yielded.
    """
    ds = getattr(loader, "dataset", None)
    indices = None
    if isinstance(ds, Subset):
        indices = ds.indices
        ds = ds.dataset
    grid_map = getattr(ds, "grid_map", None)
    tile_ids = getattr(ds, "tile_ids", None)
    if isinstance(grid_map, dict) and tile_ids is not None:
        if indices is None:
            indices = range(len(tile_ids))
        for i in indices:
            yield np.asarray(grid_map[tile_ids[int(i)]])
        return
    for _, y, _ in loader:
        yield y.detach().cpu().numpy()


def class_weights_from_loader(loader, n_classes, ignore_index):
    """Inverse-frequency class weights, normalized to mean 1.

    Counts how often each class id in ``range(n_classes)`` occurs among the
    labels of ``loader``; ``ignore_index`` and out-of-range values are not
    counted. Classes that never occur get a count of 1 (avoids division by
    zero). Counts do not depend on token order within a tile.

    Args:
        loader: Iterable of ``(x, y, ids)`` batches, typically a ``DataLoader``
            (see ``_iter_label_arrays`` for the no-image-read shortcut).
        n_classes: Number of training classes.
        ignore_index: Label value to skip.

    Returns:
        ``float32`` tensor of shape ``[n_classes]``.
    """
    counts = np.zeros(n_classes, dtype=np.int64)
    for labels in _iter_label_arrays(loader):
        flat = np.asarray(labels).reshape(-1)
        valid = flat[(flat >= 0) & (flat < n_classes) & (flat != ignore_index)]
        # One bincount per batch instead of one comparison pass per class.
        counts += np.bincount(valid.astype(np.int64), minlength=n_classes)[:n_classes]
    counts = np.maximum(counts, 1)
    inv = 1.0 / counts
    inv /= inv.mean()
    return torch.tensor(inv, dtype=torch.float32)

@torch.no_grad()
def eval_macro_f1(model, loader, n_classes, ignore_index, background_id, *, device=None):
    """Macro-averaged F1 over token predictions, excluding the background class.

    Puts ``model`` in eval mode (and leaves it there), predicts the argmax
    class of every token, and scores only tokens whose label is not
    ``ignore_index``. Per class ``c``, ``F1 = 2*TP / (2*TP + FP + FN)`` (0 when
    the denominator is 0); the result is the unweighted mean over all classes
    except ``background_id`` (or over the single class when ``n_classes == 1``).
    This matches scikit-learn's ``f1_score(..., labels=..., average="macro",
    zero_division=0)`` without requiring scikit-learn.

    Args:
        model: Module mapping a batch ``x`` to logits ``[B, T, n_classes]``.
        loader: Iterable of ``(x, y, ids)`` batches with ``y`` of shape ``[B, T]``.
        n_classes: Number of training classes.
        ignore_index: Label value to exclude from scoring.
        background_id: Class excluded from the average.
        device: Device for inference; defaults to the module-level ``DEVICE``.

    Returns:
        The macro F1 as a float; 0.0 when no token is scored.
    """
    dev = DEVICE if device is None else device
    model.eval()
    ys, ps = [], []
    for x, y, _ in loader:
        x = x.to(dev); y = y.to(dev)
        pred = model(x).argmax(-1)  # [B,T]
        mask = (y != ignore_index)
        if mask.any():
            ys.append(y[mask].cpu().numpy()); ps.append(pred[mask].cpu().numpy())
    if not ys:
        return 0.0
    y_true = np.concatenate(ys)
    y_pred = np.concatenate(ps)
    labels = [c for c in range(n_classes) if c != background_id] if n_classes > 1 else list(range(n_classes))
    scores = []
    for c in labels:
        is_true = y_true == c
        is_pred = y_pred == c
        tp = int(np.sum(is_true & is_pred))
        fp = int(np.sum(~is_true & is_pred))
        fn = int(np.sum(is_true & ~is_pred))
        denom = 2 * tp + fp + fn
        scores.append(2.0 * tp / denom if denom else 0.0)
    return float(np.mean(scores)) if scores else 0.0

def run_training(cfg: Config):
    torch.manual_seed(cfg.optim.seed); np.random.seed(cfg.optim.seed)

    # Optional Weights & Biases logging
    use_wandb = False
    if getattr(cfg.save, "wandb_enabled", False):
        try:
            import wandb
            use_wandb = True
            wandb.init(
                project=(cfg.save.wandb_project or "histotuner"),
                name=cfg.save.wandb_run_name,
                config={
                    "data": vars(cfg.data),
                    "model": vars(cfg.model),
                    "optim": vars(cfg.optim),
                },
            )
        except Exception as e:
            print(f"Warning: W&B init failed ({e}); proceeding without logging.")
            use_wandb = False

    # Normalize to slides-only flow
    slides_cfg = getattr(cfg.data, "slides", None)
    if not slides_cfg or len(slides_cfg) == 0:
        if cfg.data.sdata_path and cfg.data.wsi_path:
            slides_cfg = [{
                "sdata_path": cfg.data.sdata_path,
                "wsi_path": cfg.data.wsi_path,
            }]
        else:
            raise ValueError("Config must provide 'data.slides' entries or top-level 'sdata_path' and 'wsi_path'.")

    pooled_tiles = []
    pooled_tokens = []
    wsi_map: Dict[str, str] = {}
    tgt_mpp_map: Dict[str, float] = {}
    src_mpp_map: Dict[str, float] = {}

    for entry in slides_cfg:
        slide_id = str(entry.get("slide_id") or Path(entry.get("sdata_path")).stem)
        sdata_path = entry["sdata_path"]
        wsi_path = entry["wsi_path"]
        tiles_key = entry.get("tiles_key", cfg.data.tiles_key)
        tokens_key = entry.get("tokens_key", cfg.data.tokens_key)
        label_col = entry.get("label_col", cfg.data.label_col)
        t_mpp = entry.get("target_mpp", cfg.data.target_mpp)
        s_mpp = entry.get("source_mpp", cfg.data.source_mpp)
        sdata = SpatialData.read(sdata_path)
        tiles_i = sdata.shapes[tiles_key].copy()
        tokens_i = sdata.shapes[tokens_key].copy()
        # Ensure label column name matches global cfg.data.label_col
        if label_col != cfg.data.label_col:
            tokens_i = tokens_i.rename(columns={label_col: cfg.data.label_col})
        # Tag with slide_id and create a global_patch_id for uniqueness
        tiles_i["slide_id"] = slide_id
        tokens_i["slide_id"] = slide_id
        tiles_i["global_patch_id"] = tiles_i["patch_id"].astype(str).radd(f"{slide_id}__")
        tokens_i["global_patch_id"] = tokens_i["patch_id"].astype(str).radd(f"{slide_id}__")
        pooled_tiles.append(tiles_i)
        pooled_tokens.append(tokens_i)
        wsi_map[slide_id] = wsi_path
        if t_mpp is not None:
            tgt_mpp_map[slide_id] = float(t_mpp)
        if s_mpp is not None:
            src_mpp_map[slide_id] = float(s_mpp)

    tiles = pd.concat(pooled_tiles, ignore_index=True)
    tokens = pd.concat(pooled_tokens, ignore_index=True)

    # Validations
    for col in ["patch_id"]:
        assert col in tiles.columns, f"'{col}' missing in tiles"
        assert col in tokens.columns, f"'{col}' missing in tokens"
    assert "token_id" in tokens.columns, "'token_id' missing in tokens"
    assert cfg.data.label_col in tokens.columns, f"'{cfg.data.label_col}' missing in tokens"

    # Classes
    n_classes, id2train, train2id, raw = infer_classes(tokens[cfg.data.label_col])
    print(f"Detected {n_classes} classes from '{cfg.data.label_col}': {raw}")

    # Dataset / splits: always use MultiSlideTileTokenDataset
    ds = MultiSlideTileTokenDataset(
        tiles_gdf=tiles,
        tokens_gdf=tokens,
        wsi_map=wsi_map,
        label_col=cfg.data.label_col,
        id2train_map=id2train,
        mean=cfg.data.mean, std=cfg.data.std,
        ignore_index=cfg.data.ignore_index,
        background_id=cfg.data.background_id,
        min_fg_tokens=cfg.data.min_fg_tokens,
        augment=cfg.data.augment,
        input_size=cfg.data.input_size,
        target_mpp_map=tgt_mpp_map,
        source_mpp_map=src_mpp_map,
    )

    n = len(ds); idx = np.arange(n); np.random.shuffle(idx)
    n_val = max(1, int(round(cfg.optim.val_fraction * n)))
    val_idx, train_idx = idx[:n_val], idx[n_val:]
    dl_tr = DataLoader(Subset(ds, train_idx.tolist()), batch_size=cfg.optim.batch_size, shuffle=True, num_workers=cfg.optim.num_workers, pin_memory=True, collate_fn=collate, persistent_workers=(cfg.optim.num_workers>0))
    dl_va = DataLoader(Subset(ds, val_idx.tolist()), batch_size=cfg.optim.batch_size, shuffle=False, num_workers=cfg.optim.num_workers, pin_memory=True, collate_fn=collate, persistent_workers=(cfg.optim.num_workers>0))

    # Backbone (TIMM-only)
    backbone, embed_dim, raw_backbone = load_timm_backbone(cfg.model.model_id, cfg.model.timm_kwargs)
    if cfg.model.checkpoint_path:
        print(f"Loading checkpoint into timm backbone: {cfg.model.checkpoint_path}")
        sd = safetensors_load(cfg.model.checkpoint_path) if cfg.model.checkpoint_path.endswith(".safetensors") \
             else torch.load(cfg.model.checkpoint_path, map_location="cpu")
        if isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"]
        missing, unexpected = raw_backbone.load_state_dict(sd, strict=False)
        print(f"Loaded (non-strict): missing={len(missing)}, unexpected={len(unexpected)}")

    model = TokenClassifier(
        backbone, embed_dim, n_classes,
        head_hidden_dim=cfg.model.head_hidden_dim,
        head_dropout=cfg.model.head_dropout,
    ).to(DEVICE)

    # Loss
    class_w = class_weights_from_loader(dl_tr, n_classes, cfg.data.ignore_index).to(DEVICE)
    criterion = nn.CrossEntropyLoss(weight=class_w, ignore_index=cfg.data.ignore_index)

    # AMP setup
    use_amp = HAS_AMP and DEVICE == "cuda"
    scaler = amp.GradScaler('cuda', enabled=use_amp)

    # Stage 1: linear probe
    if cfg.model.freeze_stage1:
        set_requires_grad(model.backbone, False)
        set_requires_grad(model.head, True)
        opt = torch.optim.AdamW(model.parameters(), lr=cfg.optim.lr_head, weight_decay=cfg.optim.weight_decay)
        # Stage 1 scheduler
        sched_type = (cfg.optim.stage1_sched or "none").lower()
        sched = None
        if sched_type == "onecycle":
            max_lr = cfg.optim.stage1_max_lr or cfg.optim.lr_head
            try:
                sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=max_lr,
                                                            steps_per_epoch=len(dl_tr), epochs=cfg.optim.epochs_head)
            except Exception as e:
                print(f"Warning: OneCycleLR init failed ({e}); falling back to no scheduler.")
                sched = None
        elif sched_type == "cosine":
            sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.optim.epochs_head,
                                                               eta_min=cfg.optim.lr_head * 0.1)

        def run_epoch(loader, train_mode=True):
            model.train(train_mode)
            tot_loss, tot_seen, tot_ok = 0.0, 0, 0
            iterator = (tqdm(loader, desc="S1 train", total=len(loader)) if (HAS_TQDM and train_mode) else
                        tqdm(loader, desc="S1 val", total=len(loader)) if HAS_TQDM else loader)
            for x, y, _ in iterator:
                x = x.to(DEVICE); y = y.to(DEVICE)
                ctx = amp.autocast(device_type="cuda", dtype=torch.float16) if use_amp else nullcontext()
                with ctx:
                    logits = model(x)
                    loss = criterion(logits.view(-1, n_classes), y.view(-1))
                if train_mode:
                    opt.zero_grad(set_to_none=True)
                    if use_amp:
                        scaler.scale(loss).backward()
                        scaler.unscale_(opt)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.optim.grad_clip)
                        scaler.step(opt)
                        scaler.update()
                    else:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.optim.grad_clip)
                        opt.step()
                    # step OneCycle per batch
                    if sched is not None and sched_type == "onecycle":
                        try:
                            sched.step()
                        except Exception:
                            pass
                if HAS_TQDM:
                    try:
                        lr0 = float(opt.param_groups[0].get("lr", cfg.optim.lr_head))
                        iterator.set_postfix(loss=float(loss.item()), lr=lr0)
                    except Exception:
                        pass
                tot_loss += float(loss.item()) * x.size(0)
                with torch.no_grad():
                    mask = (y != cfg.data.ignore_index)
                    if mask.any():
                        pred = logits.argmax(-1)
                        tot_ok += int((pred[mask] == y[mask]).sum().item())
                        tot_seen += int(mask.sum().item())
            return tot_loss/max(1,len(loader.dataset)), (tot_ok/max(1,tot_seen if tot_seen>0 else 1))

        print("Stage 1: linear probe...")
        for e in range(cfg.optim.epochs_head):
            tr_loss, tr_acc = run_epoch(dl_tr, True)
            va_loss, va_acc = run_epoch(dl_va, False)
            # step cosine per epoch
            if sched is not None and sched_type == "cosine":
                try:
                    sched.step()
                except Exception:
                    pass
            f1 = eval_macro_f1(model, dl_va, n_classes, cfg.data.ignore_index, cfg.data.background_id)
            print(f"[S1 {e+1}/{cfg.optim.epochs_head}] tr_loss={tr_loss:.4f} acc={tr_acc:.3f} | "
                  f"va_loss={va_loss:.4f} acc={va_acc:.3f} | macroF1(!bg)={f1:.3f}")
            if use_wandb:
                try:
                    import wandb
                    lr = float(opt.param_groups[0]["lr"]) if len(opt.param_groups) > 0 else cfg.optim.lr_head
                    wandb.log({
                        "s1/train_loss": tr_loss,
                        "s1/train_acc": tr_acc,
                        "s1/val_loss": va_loss,
                        "s1/val_acc": va_acc,
                        "s1/macroF1": f1,
                        "s1/lr": lr,
                        "s1/epoch": e+1,
                    })
                except Exception:
                    pass

    # Stage 2: fine-tune with differential LR
    set_requires_grad(model.backbone, True)
    opt = torch.optim.AdamW([
        {"params": model.backbone.parameters(), "lr": cfg.optim.lr_backbone},
        {"params": model.head.parameters(),     "lr": cfg.optim.lr_head},
    ], weight_decay=cfg.optim.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.optim.epochs_full,
                                                       eta_min=min(cfg.optim.lr_head, cfg.optim.lr_backbone)*0.1)

    def run_epoch2(loader, train_mode=True):
        """Run one Stage-2 pass over ``loader``: train when ``train_mode`` is True, else evaluate.

        Closes over ``model``, ``criterion``, ``opt``, ``scaler``, ``use_amp``, ``n_classes``
        and ``cfg`` from the enclosing scope. In train mode every batch performs one optimizer
        step with gradient-norm clipping at ``cfg.optim.grad_clip``. When ``use_amp`` is set,
        the step goes through the grad scaler (unscale -> clip -> step -> update).

        Returns:
            (mean_loss, token_acc):
                ``mean_loss`` is the sum of per-batch losses weighted by ``x.size(0)``,
                divided by ``len(loader.dataset)``.
                ``token_acc`` is the argmax accuracy over labels that differ from
                ``cfg.data.ignore_index``, or 0.0 if no such label was seen.
        """
        model.train(train_mode)
        tot_loss, tot_seen, tot_ok = 0.0, 0, 0
        if HAS_TQDM:
            iterator = tqdm(loader, desc="S2 train" if train_mode else "S2 val", total=len(loader))
        else:
            iterator = loader
        for x, y, _ in iterator:
            x = x.to(DEVICE); y = y.to(DEVICE)
            ctx = amp.autocast(device_type="cuda", dtype=torch.float16) if use_amp else nullcontext()
            # Only record the autograd graph when we will backprop. The backbone is
            # unfrozen in Stage 2, so an eval pass with grad enabled would keep every
            # activation of the full ViT alive for nothing.
            with torch.set_grad_enabled(train_mode), ctx:
                logits = model(x)
                loss = criterion(logits.view(-1, n_classes), y.view(-1))
            if train_mode:
                opt.zero_grad(set_to_none=True)
                if use_amp:
                    scaler.scale(loss).backward()
                    # Unscale before clipping so the norm threshold applies to true gradients.
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.optim.grad_clip)
                    scaler.step(opt)
                    scaler.update()
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.optim.grad_clip)
                    opt.step()
            if HAS_TQDM:
                try:
                    iterator.set_postfix(loss=float(loss.item()))
                except Exception:
                    pass  # progress-bar cosmetics only; never interrupt training
            tot_loss += float(loss.item()) * x.size(0)
            with torch.no_grad():
                mask = (y != cfg.data.ignore_index)
                if mask.any():
                    pred = logits.argmax(-1)
                    tot_ok += int((pred[mask] == y[mask]).sum().item())
                    tot_seen += int(mask.sum().item())
        return tot_loss / max(1, len(loader.dataset)), tot_ok / max(1, tot_seen)

    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    print("Stage 2: fine-tuning...")
    best_f1 = 0.0
    epochs_without_improve = 0
    patience = cfg.optim.early_stopping_patience
    for e in range(cfg.optim.epochs_full):
        tr_loss, tr_acc = run_epoch2(dl_tr, True)
        va_loss, va_acc = run_epoch2(dl_va, False)
        f1 = eval_macro_f1(model, dl_va, n_classes, cfg.data.ignore_index, cfg.data.background_id)
        # Early stopping check on macroF1
        if f1 > best_f1:
            best_f1 = f1
            epochs_without_improve = 0
        else:
            epochs_without_improve += 1
        sched.step()
        print(f"[S2 {e+1}/{cfg.optim.epochs_full}] tr_loss={tr_loss:.4f} acc={tr_acc:.3f} | "
              f"va_loss={va_loss:.4f} acc={va_acc:.3f} | macroF1(!bg)={f1:.3f}")
        if use_wandb:
            try:
                import wandb
                lr_backbone = float(opt.param_groups[0]["lr"]) if len(opt.param_groups) > 0 else cfg.optim.lr_backbone
                lr_head = float(opt.param_groups[1]["lr"]) if len(opt.param_groups) > 1 else cfg.optim.lr_head
                wandb.log({
                    "s2/train_loss": tr_loss,
                    "s2/train_acc": tr_acc,
                    "s2/val_loss": va_loss,
                    "s2/val_acc": va_acc,
                    "s2/macroF1": f1,
                    "s2/lr_backbone": lr_backbone,
                    "s2/lr_head": lr_head,
                    "s2/epoch": e+1,
                })
            except Exception:
                pass
        if patience > 0 and epochs_without_improve >= patience:
            print(f"Early stopping triggered: no macroF1 improvement for {patience} epochs (best={best_f1:.3f}).")
            break

    # Save
    os.makedirs(cfg.save.output_path, exist_ok=True)
    ckpt_path = str(Path(cfg.save.output_path) / "finetuned_token_classifier.pt")
    torch.save({
        "state_dict": model.state_dict(),
        "backbone": cfg.model.model_id,
        "n_classes": n_classes,
        "label_column": cfg.data.label_col,
        "id2train": id2train,
        "train2id": train2id,
        "normalization": {"mean": cfg.data.mean, "std": cfg.data.std},
        "ignore_index": cfg.data.ignore_index,
        "background_id": cfg.data.background_id,
        "best_val_macroF1_wo_bg": best_f1,
    }, ckpt_path)

    with open(Path(cfg.save.output_path) / "label_mapping.json", "w") as f:
        json.dump({
            "raw_classes": raw,
            "id2train": id2train,
            "train2id": train2id
        }, f, indent=2)

    print(f"Saved: {ckpt_path}")
    print(f"Saved: {Path(cfg.save.output_path) / 'label_mapping.json'}")
    # Finish W&B run if active
    try:
        if use_wandb:
            import wandb
            wandb.finish()
    except Exception:
        pass


# ========== Public entrypoint ==========
def run_from_config(config_path: str):
    """Parse the YAML file at ``config_path`` via ``load_config`` and hand the result to ``run_training``."""
    run_training(load_config(config_path))

# Notebook-friendly entrypoints
from typing import Union as _Union

def _config_from_dict(cfg_dict: Dict[str, Any]) -> Config:
    """Assemble a ``Config`` from a plain nested dict (intended for notebook use).

    The dict must provide the four top-level sections ``data``, ``model``, ``optim``
    and ``save``. Each section's mapping is unpacked as keyword arguments into
    ``DataConfig``, ``ModelConfig``, ``OptimConfig`` and ``SaveConfig`` respectively.

    Raises:
        ValueError: if any of the four required sections is absent.
    """
    required_sections = ("data", "model", "optim", "save")
    if any(section not in cfg_dict for section in required_sections):
        raise ValueError("Config dict must have 'data', 'model', 'optim', and 'save' sections.")

    data_cfg = DataConfig(**cfg_dict["data"])
    model_cfg = ModelConfig(**cfg_dict["model"])
    optim_cfg = OptimConfig(**cfg_dict["optim"])
    save_cfg = SaveConfig(**cfg_dict["save"])
    return Config(data=data_cfg, model=model_cfg, optim=optim_cfg, save=save_cfg)

def runTrainingPipeline(cfg: _Union[Config, Dict[str, Any]]) -> None:
    """Train from a ready-made ``Config`` or from a plain dict (Jupyter-friendly).

    A dict is converted with ``_config_from_dict`` first; any other type that is
    not a ``Config`` is rejected.

    Raises:
        TypeError: if ``cfg`` is neither a dict nor a ``Config``.
    """
    if isinstance(cfg, dict):
        resolved = _config_from_dict(cfg)
    elif isinstance(cfg, Config):
        resolved = cfg
    else:
        raise TypeError("cfg must be a Config or a dict.")
    run_training(resolved)

def runTrainingFromYaml(config_path: str) -> None:
    """Notebook-friendly alias: load the YAML config at ``config_path`` and train.

    Equivalent to ``run_from_config(config_path)``.
    """
    run_from_config(config_path)

# ========== CLI ==========
if __name__ == "__main__":
    # Script entry point: the only CLI option is the path to the YAML config.
    import argparse

    parser = argparse.ArgumentParser("H-optimus token-level fine-tuning")
    parser.add_argument("--config", required=True, help="Path to YAML config")
    cli_args = parser.parse_args()
    run_from_config(cli_args.config)
