import polars as pl
import pydantic
from pyproj import CRS, Transformer
from shapely.geometry import Point
from shapely.wkb import dumps as wkb_dumps
from shapely.wkb import loads as wkb_loads

from tacotoolbox.sample.datamodel import SampleExtension

# Soft dependency: imported at module load if installed, but only required when check_antimeridian=True
try:
    import antimeridian

    HAS_ANTIMERIDIAN = True
except ImportError:
    HAS_ANTIMERIDIAN = False


class ISTAC(SampleExtension):
    """
    Irregular SpatioTemporal Asset Catalog (ISTAC) metadata for non-regular geometries.

    For geospatial data that cannot be represented with an affine geotransform:
    - Satellite swaths: CloudSat, CALIPSO, GPM orbital tracks
    - Flight paths: Aircraft or drone trajectories
    - Vector data: Polygons, lines, or points without underlying raster
    - Irregular samplings: Weather stations, buoy arrays, sensor networks

    Unlike STAC (designed for regular rasters with geotransform), ISTAC stores
    the complete geometry directly as WKB binary for arbitrary spatial footprints.

    Notes
    -----
    - Timestamps stored as Parquet native TIMESTAMP type (seconds precision)
    - Accepts epoch seconds (int) which are converted to proper Parquet timestamps
    - Centroid is always in EPSG:4326 regardless of source geometry CRS
    - For regular raster grids, use the STAC extension instead
    - WKB binary format for efficient storage and GeoParquet compatibility
    - time_middle is auto-computed when both start and end times exist
    - Set check_antimeridian=True for Pacific/Polar data (requires: pip install antimeridian).
      In that mode a geometry whose CRS is not EPSG:4326 is reprojected to
      EPSG:4326 *before* the antimeridian-aware centroid is computed, because
      that computation operates on longitude/latitude degrees.
    """

    # CRS of `geometry`; any string accepted by pyproj's CRS.from_string (e.g. "EPSG:4326").
    crs: str
    # Footprint geometry encoded as WKB, expressed in the coordinates of `crs`.
    geometry: bytes
    # Start of the temporal interval, in epoch seconds.
    time_start: int
    # Optional end of the temporal interval, in epoch seconds; must be >= time_start.
    time_end: int | None = None
    # Interval midpoint in epoch seconds; auto-filled (floor division) when time_end is set.
    time_middle: int | None = None
    # Centroid encoded as WKB in EPSG:4326; computed from `geometry` when omitted.
    centroid: bytes | None = None
    # When True, compute the centroid with the optional `antimeridian` package.
    check_antimeridian: bool = False

    @pydantic.model_validator(mode="after")
    def check_times(self) -> "ISTAC":
        """
        Validate that time_start <= time_end if time_end is provided.

        Raises
        ------
        ValueError
            If time_end is set and time_start is later than time_end.
        """
        if self.time_end is not None and self.time_start > self.time_end:
            raise ValueError(
                f"Invalid temporal interval: time_start ({self.time_start}) "
                f"> time_end ({self.time_end})"
            )
        return self

    @pydantic.model_validator(mode="after")
    def populate_time_middle(self) -> "ISTAC":
        """
        Auto-populate time_middle when both time_start and time_end exist.

        An explicitly supplied time_middle is left untouched. The midpoint is
        computed with integer floor division, so it stays in whole seconds.
        """
        if self.time_end is not None and self.time_middle is None:
            self.time_middle = (self.time_start + self.time_end) // 2

        return self

    @pydantic.model_validator(mode="after")
    def populate_centroid(self) -> "ISTAC":
        """
        Auto-compute centroid in EPSG:4326 if not provided.

        Default mode (check_antimeridian=False): take the shapely centroid in
        the source CRS, then reproject that single point to EPSG:4326.

        Antimeridian mode (check_antimeridian=True): reproject the whole
        geometry to EPSG:4326 first, then use ``antimeridian.centroid`` so that
        geometries split at ±180° longitude (e.g. Pacific swaths) are handled.
        The antimeridian logic works on longitude/latitude degrees, so it must
        never be fed projected (e.g. metre-based) coordinates.

        Raises
        ------
        ImportError
            If check_antimeridian=True but the 'antimeridian' package is not installed.
        """
        if self.centroid is not None:
            return self

        geom = wkb_loads(self.geometry)

        # Fail fast on the missing optional dependency before any reprojection work.
        if self.check_antimeridian and not HAS_ANTIMERIDIAN:
            raise ImportError(
                "check_antimeridian=True requires the 'antimeridian' package.\n"
                "Install with: pip install antimeridian\n"
                "Or set check_antimeridian=False to use fast mode (works for most geometries)."
            )

        # Reprojector to lon/lat (x=lon, y=lat thanks to always_xy);
        # None when the source CRS is already EPSG:4326.
        to_wgs84 = None
        if self.crs.upper() != "EPSG:4326":
            to_wgs84 = Transformer.from_crs(
                CRS.from_string(self.crs), CRS.from_epsg(4326), always_xy=True
            )

        if self.check_antimeridian:
            if to_wgs84 is not None:
                # Local import: only needed on this (opt-in) path.
                from shapely.ops import transform as reproject

                # Bring the geometry into degrees BEFORE the antimeridian-aware
                # computation; its result is then already in EPSG:4326.
                geom = reproject(to_wgs84.transform, geom)
            # NOTE(review): the antimeridian package documents centroid() for
            # Polygon/MultiPolygon inputs — confirm behaviour for line/point geometries.
            centroid_geom = antimeridian.centroid(geom)
        else:
            # Fast path: plain shapely centroid in the source CRS.
            centroid_geom = geom.centroid
            # An empty geometry yields an empty centroid; keep it empty rather
            # than pushing coordinate-less values through the transformer.
            if to_wgs84 is not None and not centroid_geom.is_empty:
                x, y = to_wgs84.transform(centroid_geom.x, centroid_geom.y)
                centroid_geom = Point(x, y)

        # Store as WKB
        self.centroid = wkb_dumps(centroid_geom)
        return self

    def get_schema(self) -> dict[str, pl.DataType]:
        """
        Return the expected Polars schema for this extension.

        The three time columns are second-precision, timezone-naive Datetime
        columns; geometry and centroid are raw WKB bytes.
        """
        return {
            "istac:crs": pl.Utf8(),
            "istac:geometry": pl.Binary(),
            "istac:time_start": pl.Datetime(time_unit="s", time_zone=None),
            "istac:time_end": pl.Datetime(time_unit="s", time_zone=None),
            "istac:time_middle": pl.Datetime(time_unit="s", time_zone=None),
            "istac:centroid": pl.Binary(),
        }

    def _compute(self, sample) -> pl.DataFrame:
        """
        Actual computation logic - only called when schema_only=False.

        Builds a single-row DataFrame from the validated fields. ``sample`` is
        not used here (presumably part of the SampleExtension interface — verify).
        The integer epoch-second values are handed to polars as-is and are
        expected to be interpreted through the Datetime("s") columns declared
        in get_schema().
        """
        return pl.DataFrame(
            {
                "istac:crs": [self.crs],
                "istac:geometry": [self.geometry],
                "istac:time_start": [self.time_start],
                "istac:time_end": [self.time_end],
                "istac:time_middle": [self.time_middle],
                "istac:centroid": [self.centroid],
            },
            schema=self.get_schema(),
        )