#!/usr/bin/env python3
"""
crop_svs_with_tifffile.py

Center-crop an SVS (level-0) and re-save as an SVS-like pyramidal JPEG-tiled BigTIFF.
- Ensures data written are contiguous uint8 HxWx3 arrays (fixes JPEG encode errors).
- Preserves input resolution unit (INCH/CENTIMETER/NONE) and writes X/YResolution from MPP.
- No libvips required.

Usage:
  python crop_svs_with_tifffile.py --inputPath "path/to/slide.svs" --width 672 --height 672
    [--levels 6] [--tile 256] [--jpegQ 90] [--ext .svs] [--snapToTile]
"""

import argparse
import math
from pathlib import Path
import sys

def center_crop_bounds(W, H, cw, ch):
    """Return ``(x0, y0, cw, ch)`` for a ``cw`` x ``ch`` window centred in a ``W`` x ``H`` image.

    The origin is clamped so it is never negative. The requested size is
    returned unchanged, even when it is larger than the image; in that case
    the origin on that axis is 0.
    """
    def _origin(extent, size):
        # Centre the window, then pull it back inside the extent if it overshoots.
        start = max(0, extent // 2 - size // 2)
        if start + size > extent:
            start = max(0, extent - size)
        return start

    left = _origin(W, cw)
    top = _origin(H, ch)
    return int(left), int(top), int(cw), int(ch)

def snap_dim_to_tile(v, tile):
    """Round ``v`` up to the next multiple of ``tile``.

    A non-positive ``tile`` disables snapping and ``v`` is returned untouched.
    """
    if tile <= 0:
        return v
    tiles_needed = math.ceil(v / tile)
    return int(tiles_needed * tile)

def decode_tiff_resunit(props):
    """Translate the ``tiff.ResolutionUnit`` property into a canonical unit name.

    Accepts the numeric TIFF codes (``1``/``2``/``3``, as ``str`` or ``int``)
    as well as unit names in any letter case, e.g. the lower-case ``"inch"`` /
    ``"centimeter"`` spelling reported by OpenSlide.

    Args:
        props: Mapping-like object exposing ``.get``.

    Returns:
        ``"NONE"``, ``"INCH"`` or ``"CENTIMETER"``; ``None`` when the property
        is absent or not recognised.
    """
    raw = props.get("tiff.ResolutionUnit")
    if raw is None:
        return None
    # Normalise so that 2, "2", "inch", " Inch " and "INCH" are all recognised.
    token = str(raw).strip().upper()
    code_to_name = {"1": "NONE", "2": "INCH", "3": "CENTIMETER"}
    if token in code_to_name:
        return code_to_name[token]
    if token in {"INCH", "CENTIMETER", "NONE"}:
        return token
    return None

def compute_resolution_from_mpp(mpp_x, mpp_y, unit):
    """Convert microns-per-pixel into pixels-per-unit for TIFF X/YResolution.

    Returns ``(xres, yres)`` for ``unit`` equal to ``"INCH"`` or
    ``"CENTIMETER"``. Returns ``(None, None)`` when an input is missing,
    non-numeric, non-positive, or the unit is ``"NONE"`` / unrecognised.
    """
    unavailable = (None, None)
    if unit is None or unit == "NONE" or mpp_x is None or mpp_y is None:
        return unavailable
    try:
        um_per_px_x = float(mpp_x)
        um_per_px_y = float(mpp_y)
        if um_per_px_x <= 0 or um_per_px_y <= 0:
            return unavailable
        if unit == "CENTIMETER":
            microns_per_unit = 10000.0  # 1 cm = 10,000 µm
        elif unit == "INCH":
            microns_per_unit = 25400.0  # 1 in = 25,400 µm
        else:
            return unavailable
        return microns_per_unit / um_per_px_x, microns_per_unit / um_per_px_y
    except Exception:
        return unavailable

def parse_float_prop(props, key):
    """Look up ``key`` in ``props`` and return it as a float.

    Returns ``None`` when the key is missing, its value is ``None``, or
    anything goes wrong during lookup or conversion.
    """
    try:
        raw = props.get(key)
        if raw is None:
            return None
        return float(raw)
    except Exception:
        return None

def pil_to_contig_array(pil_img):
    """Convert an RGB image into a C-contiguous ``uint8`` array of shape H x W x 3.

    Args:
        pil_img: Anything ``numpy.asarray`` can convert (typically a PIL image).

    Returns:
        A C-contiguous ``numpy.uint8`` array with shape ``(H, W, 3)``.

    Raises:
        ValueError: If the converted array is not three-dimensional with
            exactly three channels. No mode conversion is attempted here.
    """
    import numpy as np
    # np.asarray copies only when required. np.array(..., copy=False) means
    # "never copy" on NumPy >= 2 and raises ValueError whenever a copy or
    # dtype conversion is unavoidable.
    arr = np.asarray(pil_img, dtype=np.uint8)
    if arr.ndim != 3 or arr.shape[2] != 3:
        raise ValueError("Expected an RGB image (H x W x 3).")
    return np.ascontiguousarray(arr)

def build_pyramid_arrays(pil_img, levels):
    """Build a list of pyramid levels from a PIL image, largest first.

    Each level is the previous one resized to half its width and height
    (floor division, never below 1 px) with PIL's BOX filter, and is stored
    via ``pil_to_contig_array``. At least one level is always produced;
    generation stops early once either side is smaller than 2 px.
    """
    from PIL import Image
    pyramid = []
    level_img = pil_img
    for _level in range(max(1, levels)):
        pyramid.append(pil_to_contig_array(level_img))
        width, height = level_img.size
        if min(width, height) < 2:
            break
        halved = (max(1, width // 2), max(1, height // 2))
        if halved == (width, height):
            break
        level_img = level_img.resize(halved, resample=Image.BOX)
    return pyramid

def build_pyramid_arrays_nd(arr, levels):
    """Build a list of pyramid levels from an ndarray, largest first.

    Each level halves the first two axes of the previous one by averaging
    non-overlapping 2x2 blocks (an odd trailing row/column is dropped). The
    result is cast back to the input dtype; integer dtypes are rounded with
    ``numpy.rint`` first. At least one level is always produced; generation
    stops early once either of the first two axes is shorter than 2.
    """
    import numpy as np

    def _halve(level):
        rows, cols = level.shape[:2]
        even_rows = (rows // 2) * 2
        even_cols = (cols // 2) * 2
        block = level[:even_rows, :even_cols]
        if block.ndim == 3:
            grouped = block.reshape(even_rows // 2, 2, even_cols // 2, 2, block.shape[2])
        else:
            grouped = block.reshape(even_rows // 2, 2, even_cols // 2, 2)
        averaged = grouped.mean(axis=(1, 3))
        # The mean is floating point; round before casting back to integers.
        if np.issubdtype(level.dtype, np.integer):
            averaged = np.rint(averaged)
        return averaged.astype(level.dtype)

    pyramid = []
    level = np.ascontiguousarray(arr)
    for _level in range(max(1, levels)):
        pyramid.append(level)
        rows, cols = level.shape[:2]
        if min(rows, cols) < 2:
            break
        if (max(1, rows // 2), max(1, cols // 2)) == (rows, cols):
            break
        level = _halve(level)
    return pyramid

def write_ifd(tif, arr, *, tile, q, is_level0, subifds_remaining, desc, xres, yres, resunit, datetime=None, software=None, metadata=None):
    """Write one pyramid level to ``tif`` via its ``write`` method.

    Args:
        tif: Writer object exposing ``write(data, **kwargs)``.
        arr: Image data for this level.
        tile: Edge length of the square tiles.
        q: JPEG quality; only used when JPEG compression is selected.
        is_level0: True for the full-resolution page, which also carries the
            sub-IFD count, description and datetime. Other pages are written
            with ``subfiletype=1``.
        subifds_remaining: Number of sub-IFDs announced on the level-0 page.
        desc: Description for the level-0 page.
        xres, yres: Numeric resolution; written only when both are truthy.
        resunit: Resolution unit; written only alongside a numeric resolution.
        datetime, software, metadata: Optional values; falsy values are
            passed as ``None``.

    Three-channel data is written as RGB, everything else as min-is-black.
    ``uint8`` data with one or three channels is JPEG-compressed, anything
    else uses zlib.
    """
    import numpy as np

    has_channel_axis = hasattr(arr, 'shape') and len(arr.shape) == 3
    channels = arr.shape[2] if has_channel_axis else 1
    uint8_data = hasattr(arr, 'dtype') and arr.dtype == np.uint8
    use_jpeg = uint8_data and channels in (1, 3)
    have_resolution = xres and yres

    write_kwargs = {
        "photometric": "rgb" if channels == 3 else "minisblack",
        "tile": (tile, tile),
        "compression": "jpeg" if use_jpeg else "zlib",
        "metadata": metadata or None,
        "resolution": (xres, yres) if have_resolution else None,
        "resolutionunit": resunit if (have_resolution and resunit) else None,
        "software": software or None,
    }
    if is_level0:
        write_kwargs["subifds"] = subifds_remaining
        write_kwargs["description"] = desc
        write_kwargs["datetime"] = datetime or None
    else:
        write_kwargs["subfiletype"] = 1

    if not use_jpeg:
        tif.write(arr, **write_kwargs)
        return

    try:
        tif.write(arr, compressionargs={"level": int(q)}, **write_kwargs)
    except TypeError:
        # NOTE(review): retry with an alternative quality keyword; presumably
        # for writer/codec versions that reject "level" - confirm.
        tif.write(arr, compressionargs={"jpegquality": int(q)}, **write_kwargs)

def cropImage(inputPath: str, width: int, height: int, levels: int = 6, tile: int = 256, jpegQ: int = 90, ext: str | None = None, snapToTile: bool = False):
    """Center-crop an SVS/TIFF/OME-TIFF and save it as a tiled pyramidal BigTIFF.

    Notebook-friendly API; parameters use camelCase for consistency.

    The output is written next to the input as
    ``<stem>_<cropWidth>x<cropHeight><ext>``.

    Args:
        inputPath: Path to an existing image file.
        width: Requested crop width in level-0 pixels; must be positive.
        height: Requested crop height in level-0 pixels; must be positive.
        levels: Maximum number of pyramid levels to write (at least one).
        tile: Square tile edge used for writing and for ``snapToTile``.
        jpegQ: JPEG quality forwarded to the page writer.
        ext: Output extension; defaults to the input file's suffix.
        snapToTile: Round the crop size up to multiples of ``tile``.

    Returns:
        ``(outPath, metaDict)`` where ``metaDict`` reports the levels written,
        tile size, JPEG quality and the resolution values that were used.

    Raises:
        FileNotFoundError: If ``inputPath`` does not exist.
        ValueError: If ``width`` or ``height`` is not positive.
    """
    in_path = Path(inputPath).resolve()
    if not in_path.exists():
        raise FileNotFoundError("input must be an existing image file")
    if width <= 0 or height <= 0:
        raise ValueError("width and height must be positive")

    # Try unified H&E input opener to handle both WSI and regular TIFF/OME-TIFF
    from histotuner.utils import open_he_input as _open_he
    he_input = _open_he(str(in_path))

    # Compute crop bounds
    W0, H0 = he_input.size
    x0, y0, cw, ch = center_crop_bounds(W0, H0, width, height)

    if snapToTile and tile > 0:
        cw = snap_dim_to_tile(cw, tile)
        ch = snap_dim_to_tile(ch, tile)
        x0, y0, cw, ch = center_crop_bounds(W0, H0, cw, ch)

    # Resolution metadata
    resunit_in = None
    mpp_x = None
    mpp_y = None
    xres = None
    yres = None

    # Preserve original description, DateTime, Software if available
    orig_desc, orig_datetime, orig_software = None, None, None

    # Exactly one of these is filled in below: `region` (PIL image) or
    # `region_arr` (ndarray). Initialising both avoids probing locals().
    region = None
    region_arr = None

    # Helper to safely convert TIFF rationals to float
    def _ratio_to_float(val):
        try:
            # tifffile Ratio
            num = getattr(val, 'numerator', None)
            den = getattr(val, 'denominator', None)
            if num is not None and den is not None:
                return float(num) / float(den) if float(den) != 0 else None
            # tuple/list
            if isinstance(val, (tuple, list)) and len(val) >= 2:
                return float(val[0]) / float(val[1]) if float(val[1]) != 0 else None
            return float(val)
        except Exception:
            return None

    # Helper to convert OME units to microns
    # NOTE(review): nothing in this function calls _to_um, and mpp_x/mpp_y are
    # only assigned in the "wsi" branch - OME physical sizes are never read.
    def _to_um(val, unit):
        try:
            u = (unit or "").lower()
            if u in ("µm", "um", "micrometer", "micron", "micrometre"):
                return float(val)
            if u in ("nm", "nanometer", "nanometre"):
                return float(val) / 1000.0
            if u in ("mm", "millimeter", "millimetre"):
                return float(val) * 1000.0
            if u in ("cm", "centimeter", "centimetre"):
                return float(val) * 10000.0
            if u in ("m", "meter", "metre"):
                return float(val) * 1e6
            return float(val)
        except Exception:
            return None

    # Helper to make text safe for ASCII-only TIFF tags
    def _sanitize_ascii(s):
        try:
            if s is None:
                return None
            s = str(s).replace("µ", "um")
            return s.encode("ascii", "ignore").decode("ascii")
        except Exception:
            return None

    if he_input.kind == "wsi":
        # Whole-slide via OpenSlide
        slide = he_input.slide_handle
        # Read resolution unit and mpp from properties
        resunit_in = decode_tiff_resunit(slide.properties)
        mpp_x = parse_float_prop(slide.properties, "openslide.mpp-x")
        mpp_y = parse_float_prop(slide.properties, "openslide.mpp-y")
        xres, yres = compute_resolution_from_mpp(mpp_x, mpp_y, resunit_in)
        # Fallback to numeric X/Y resolution in TIFF properties if available
        xr_prop = parse_float_prop(slide.properties, "tiff.XResolution")
        yr_prop = parse_float_prop(slide.properties, "tiff.YResolution")
        if (xres is None or yres is None) and xr_prop and yr_prop and xr_prop > 0 and yr_prop > 0:
            xres, yres = xr_prop, yr_prop
        # If still missing numeric resolution but MPP exists and unit was missing/NONE, default to INCH and derive DPI
        if (xres is None or yres is None) and (mpp_x is not None and mpp_y is not None):
            if resunit_in is None or resunit_in == "NONE":
                fallback_unit = "INCH"
                fx, fy = compute_resolution_from_mpp(mpp_x, mpp_y, fallback_unit)
                if fx is not None and fy is not None:
                    xres, yres = fx, fy
                    resunit_in = fallback_unit
        # Read region
        region = slide.read_region((x0, y0), 0, (cw, ch)).convert("RGB")
        # Try to preserve description/DateTime/Software from original TIFF if readable
        try:
            from tifffile import TiffFile
            with TiffFile(str(in_path)) as tf:
                page0 = tf.pages[0]
                orig_desc = getattr(page0, "description", None)
                orig_datetime = getattr(page0, "datetime", None)
                orig_software = getattr(page0, "software", None)
                if orig_datetime is None:
                    try:
                        orig_datetime = page0.tags.get("DateTime").value
                    except Exception:
                        pass
                if orig_software is None:
                    try:
                        orig_software = page0.tags.get("Software").value
                    except Exception:
                        pass
        except Exception:
            pass
    else:
        # Regular image via TIFF/PIL; prefer tifffile ndarray to preserve channels and axes
        try:
            from tifffile import TiffFile
            import numpy as np
            with TiffFile(str(in_path)) as tf:
                page0 = tf.pages[0]
                # Numeric resolution
                try:
                    xr_tag = page0.tags.get("XResolution")
                    yr_tag = page0.tags.get("YResolution")
                    ru_tag = page0.tags.get("ResolutionUnit")
                    if xr_tag:
                        xres = _ratio_to_float(xr_tag.value)
                    if yr_tag:
                        yres = _ratio_to_float(yr_tag.value)
                    if ru_tag:
                        try:
                            ru_val = int(ru_tag.value)
                            resunit_in = {1: "NONE", 2: "INCH", 3: "CENTIMETER"}.get(ru_val, None)
                        except Exception:
                            resunit_in = None
                except Exception:
                    pass
                # Preserve description/DateTime/Software
                try:
                    orig_desc = getattr(page0, "description", None)
                    orig_datetime = getattr(page0, "datetime", None)
                    orig_software = getattr(page0, "software", None)
                    if orig_datetime is None:
                        try:
                            orig_datetime = page0.tags.get("DateTime").value
                        except Exception:
                            pass
                    if orig_software is None:
                        try:
                            orig_software = page0.tags.get("Software").value
                        except Exception:
                            pass
                except Exception:
                    pass
                # Read array and axes to preserve layer/channel order
                series = tf.series[0]
                axes = getattr(series, "axes", None)
                arr = None
                try:
                    arr = series.asarray()
                except Exception:
                    arr = None
                # If axes expose channels, reorder to YXC
                if arr is not None and axes:
                    axis_chars = list(axes)
                    channel_like = ('C', 'I', 'S')
                    # Build slices: crop Y/X, keep C/I/S, squeeze others
                    sl = []
                    for ax in axis_chars:
                        if ax == 'Y':
                            sl.append(slice(y0, y0 + ch))
                        elif ax == 'X':
                            sl.append(slice(x0, x0 + cw))
                        elif ax in channel_like:
                            sl.append(slice(None))
                        else:
                            sl.append(0)
                    cropped = arr[tuple(sl)]
                    # Integer indexing removed every axis other than Y/X/C/I/S,
                    # so axis positions must be looked up among the axes that
                    # remain, not in the original axes string.
                    kept_axes = [ax for ax in axis_chars if ax in ('Y', 'X') + channel_like]
                    chan_positions = [i for i, ax in enumerate(kept_axes) if ax in channel_like]
                    if chan_positions:
                        y_idx = kept_axes.index('Y')
                        x_idx = kept_axes.index('X')
                        region_arr = np.transpose(cropped, [y_idx, x_idx] + chan_positions)
                        if len(chan_positions) > 1:
                            # Several channel-like axes: flatten them into one
                            # trailing channel axis so the result is H x W x C.
                            n_chan = int(np.prod(region_arr.shape[2:]))
                            region_arr = region_arr.reshape(region_arr.shape[0], region_arr.shape[1], n_chan)
                    else:
                        # No channel axis: reorder to YX and drop others
                        try:
                            y_idx = kept_axes.index('Y')
                            x_idx = kept_axes.index('X')
                            region_arr = np.transpose(cropped, (y_idx, x_idx))
                        except Exception:
                            region_arr = cropped.squeeze()

                # Additional fallback: read full array via tifffile.imread when series.asarray fails
                if region_arr is None and arr is None:
                    try:
                        from tifffile import imread as _tf_imread
                        arr_full = _tf_imread(str(in_path))
                        if arr_full is not None:
                            arr = arr_full
                            if hasattr(arr, 'ndim'):
                                if arr.ndim == 2:
                                    region_arr = arr[y0:y0+ch, x0:x0+cw][..., np.newaxis]
                                elif arr.ndim == 3:
                                    # Assume last axis is channels
                                    region_arr = arr[y0:y0+ch, x0:x0+cw, :]
                    except Exception:
                        pass

                # If axes don't include channels, stack pages as channels
                if region_arr is None:
                    try:
                        pages = list(tf.pages)
                        # Filter out reduced-resolution (pyramidal) pages and keep only top-level size
                        target_shape = getattr(page0, 'shape', None)
                        filtered_pages = []
                        for p in pages:
                            try:
                                subf = getattr(p, 'subfiletype', None)
                                is_reduced = (subf == 1) or getattr(p, 'is_reduced', False)
                            except Exception:
                                is_reduced = False
                            if not is_reduced and (target_shape is None or getattr(p, 'shape', None) == target_shape):
                                filtered_pages.append(p)
                        use_pages = filtered_pages if len(filtered_pages) > 0 else pages
                        if len(use_pages) > 1:
                            page_arrays = []
                            for p in use_pages:
                                pa = p.asarray()
                                if pa.ndim == 2:
                                    pa = pa[..., np.newaxis]
                                elif pa.ndim == 3:
                                    # If per-page has samples, keep them (H x W x S)
                                    pass
                                else:
                                    pa = np.squeeze(pa)
                                    if pa.ndim == 2:
                                        pa = pa[..., np.newaxis]
                                page_arrays.append(pa)
                            # Concatenate along channel axis
                            try:
                                arr_pages = np.concatenate(page_arrays, axis=-1)
                            except Exception:
                                # Fallback: stack then reshape to HxWxC if possible
                                arr_pages = np.dstack(page_arrays)
                            # Crop Y/X; keep all stacked channels
                            region_arr = arr_pages[y0:y0+ch, x0:x0+cw, :]
                        elif arr is not None:
                            # Single page: ensure channel axis exists
                            if arr.ndim == 2:
                                arr = arr[..., np.newaxis]
                            region_arr = arr[y0:y0+ch, x0:x0+cw, :]
                    except Exception:
                        region_arr = None
        except Exception:
            region_arr = None
        # Fallback to PIL crop if tifffile failed
        if region_arr is None:
            im = he_input.image_obj
            region = im.crop((x0, y0, x0 + cw, y0 + ch)).convert("RGB")
        # If numeric resolution missing but we have OME physical sizes, derive resolution with INCH fallback
        if (xres is None or yres is None) and (mpp_x is not None and mpp_y is not None):
            fallback_unit = "INCH"
            fx, fy = compute_resolution_from_mpp(mpp_x, mpp_y, fallback_unit)
            if fx is not None and fy is not None:
                xres, yres = fx, fy
                resunit_in = fallback_unit

    # Build pyramids
    if region_arr is not None:
        pyr_arrays = build_pyramid_arrays_nd(region_arr, levels)
    else:
        pyr_arrays = build_pyramid_arrays(region, levels)
    levels_written = len(pyr_arrays)

    out_ext = in_path.suffix if ext is None else ext
    out_path = in_path.parent / f"{in_path.stem}_{cw}x{ch}{out_ext}"

    # Description: keep the sanitized original when there is one; otherwise
    # fall back to a synthesized header. The fallback is sanitized too, so it
    # never reintroduces non-ASCII text.
    safe_desc = _sanitize_ascii(orig_desc)
    if safe_desc:
        safe_desc = safe_desc.rstrip("\n") + "\n"
    else:
        default_desc = (
            "Aperio Image Library 1.2.0\n"
            + f"ImageWidth = {cw}\nImageHeight = {ch}\n"
            + f"Filename = {in_path.name}\n"
            + "Comment = Center-cropped (tifffile pyramid)"
        )
        safe_desc = _sanitize_ascii(default_desc) or default_desc

    safe_datetime = _sanitize_ascii(orig_datetime)
    safe_software = _sanitize_ascii(orig_software)
    from tifffile import TiffWriter

    # Determine axes metadata from the level-0 array
    arr0 = pyr_arrays[0]
    axes_meta = 'YXC' if (hasattr(arr0, 'ndim') and arr0.ndim == 3) else 'YX'
    metadata_axes = {'axes': axes_meta}

    with TiffWriter(str(out_path), bigtiff=True) as tif:
        write_ifd(
            tif,
            pyr_arrays[0],
            tile=tile,
            q=jpegQ,
            is_level0=True,
            subifds_remaining=levels_written - 1,
            desc=safe_desc,
            xres=xres,
            yres=yres,
            resunit=resunit_in or None,
            datetime=safe_datetime,
            software=safe_software,
            metadata=metadata_axes,
        )
        for arr in pyr_arrays[1:]:
            write_ifd(
                tif,
                arr,
                tile=tile,
                q=jpegQ,
                is_level0=False,
                subifds_remaining=0,
                desc=None,
                xres=xres,
                yres=yres,
                resunit=resunit_in or None,
                metadata=metadata_axes,
            )

    meta = {
        "levelsWritten": levels_written,
        "tile": tile,
        "jpegQ": jpegQ,
        "xResolution": xres,
        "yResolution": yres,
        "resolutionUnit": resunit_in or None,
        "mppX": mpp_x,
        "mppY": mpp_y,
        "physicalSizeX": mpp_x,
        "physicalSizeY": mpp_y,
        "physicalSizeXUnit": "µm" if mpp_x is not None else None,
        "physicalSizeYUnit": "µm" if mpp_y is not None else None,
    }
    return out_path, meta


def main():
    """Command-line entry point: parse arguments, crop, and print a summary.

    The input path may be given positionally or via ``--inputPath``. Both the
    camelCase flags (``--jpegQ``, ``--snapToTile``) and their hyphenated
    spellings (``--jpeg-q``, ``--snap-to-tile``) are accepted.

    Exits with status 1 when the input is not an existing ``.svs`` file or
    when cropping fails; argparse exits with status 2 on usage errors.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input", nargs="?", default=None, help="Path to input .svs (alternative to --inputPath)")
    ap.add_argument("--inputPath", type=str, default=None, help="Path to input .svs")
    ap.add_argument("--width", type=int, required=True, help="Crop width (level-0 px)")
    ap.add_argument("--height", type=int, required=True, help="Crop height (level-0 px)")
    ap.add_argument("--levels", type=int, default=6, help="Total pyramid levels to write (>=1)")
    ap.add_argument("--tile", type=int, default=256, help="Tile size (square)")
    ap.add_argument("--jpegQ", "--jpeg-q", dest="jpegQ", type=int, default=90, help="JPEG quality (1–100)")
    ap.add_argument("--ext", type=str, default=None, help="Output extension (e.g., .svs or .tif). Defaults to input file extension.")
    ap.add_argument("--snapToTile", "--snap-to-tile", dest="snapToTile", action="store_true", help="Snap crop width/height up to tile multiples")
    args = ap.parse_args()

    # --inputPath wins when both forms are supplied.
    input_arg = args.inputPath if args.inputPath is not None else args.input
    if input_arg is None:
        ap.error("an input path is required (positional or --inputPath)")

    in_path = Path(input_arg).resolve()
    if not in_path.exists() or in_path.suffix.lower() != ".svs":
        print("Error: input must be an existing .svs file", file=sys.stderr)
        sys.exit(1)

    try:
        out_path, meta = cropImage(
            inputPath=str(in_path),
            width=args.width,
            height=args.height,
            levels=args.levels,
            tile=args.tile,
            jpegQ=args.jpegQ,
            ext=args.ext,
            snapToTile=args.snapToTile,
        )
    except Exception as e:
        print(f"Failed to crop SVS: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"Saved pyramidal BigTIFF: {out_path}")
    print(f"- Levels written: {meta['levelsWritten']}")
    print(f"- Tile: {meta['tile']}x{meta['tile']}, JPEG Q={meta['jpegQ']}")
    if meta["xResolution"] and meta["yResolution"] and meta["resolutionUnit"]:
        print(f"- Resolution ({meta['resolutionUnit']}): XResolution={meta['xResolution']:.2f}, YResolution={meta['yResolution']:.2f}")
    elif meta["resolutionUnit"]:
        print(f"- Resolution unit preserved: {meta['resolutionUnit']} (no numeric resolution available)")
    else:
        print("- No resolution unit found in source; wrote without resolution tags.")

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
