
import numpy as np
import plotly.graph_objects as go
import numpy as np
from PIL import Image
import plotly.graph_objects as go


def _to_uint8_rgb(img: np.ndarray) -> np.ndarray:
    """Return *img* as an (H, W, 3) uint8 array that both PIL and go.Image accept."""
    # Grayscale (+ alpha) stored with an explicit channel axis: keep the luminance.
    if img.ndim == 3 and img.shape[2] in (1, 2):
        img = img[..., 0]

    if img.ndim == 2:  # grayscale -> RGB
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[2] >= 3:
        img = img[..., :3]  # drop alpha / extra bands
    else:
        raise ValueError(
            f"Expected an image of shape (H, W) or (H, W, C); got shape {img.shape}."
        )

    if img.dtype == np.uint8:
        return img
    if img.dtype == np.bool_:
        return img.astype(np.uint8) * 255

    # Image.fromarray cannot build an RGB image from float/uint16 data, and
    # go.Image maps values against a default range of [0, 255] per channel,
    # so normalise everything to uint8 here.
    values = np.nan_to_num(img.astype(np.float64), nan=0.0, posinf=255.0, neginf=0.0)
    if np.issubdtype(img.dtype, np.floating) and values.size and values.max() <= 1.0:
        values = values * 255.0  # float images in [0, 1]
    return np.clip(np.rint(values), 0, 255).astype(np.uint8)


def plot_overlay_interactive_light(overlay, title: str = "Contrail Overlay", max_size: int = 1200):
    """
    Lightweight interactive zoom/pan viewer for an overlay image using Plotly.

    The image is converted to an 8-bit RGB array. If its longest side exceeds
    ``max_size``, it is downsampled with PIL (bilinear). It is then displayed
    via ``fig.show()`` with scroll-zoom enabled and pan as the default drag
    mode.

    Parameters
    ----------
    overlay : np.ndarray or PIL.Image.Image or xarray-like
        Image convertible with ``np.array`` to shape (H, W), (H, W, 1),
        (H, W, 2) or (H, W, C >= 3). Channels beyond the third (e.g. alpha)
        are dropped, and 1-/2-channel input is treated as grayscale (+ alpha).
        Palette/CMYK/YCbCr PIL images are converted to RGB first.
        Non-``uint8`` data is converted to ``uint8``: floating-point images
        whose maximum is <= 1.0 are scaled by 255; everything else is clipped
        to [0, 255].
    title : str
        Figure title.
    max_size : int
        Maximum size (in pixels) of the longest image dimension after downsampling.

    Returns
    -------
    None
        The figure is shown with ``fig.show()``; nothing is returned.

    Raises
    ------
    ValueError
        If ``overlay`` cannot be interpreted as a 2-D or 3-D image.
    """
    # --- 1. Convert to an (H, W, 3) uint8 NumPy array ---
    if isinstance(overlay, Image.Image) and overlay.mode in ("P", "CMYK", "YCbCr"):
        # np.array() of these modes yields palette indices / non-RGB bands.
        overlay = overlay.convert("RGB")
    img = _to_uint8_rgb(np.array(overlay))  # handles ndarray, PIL images, xarray.DataArray, ...

    h, w = img.shape[:2]

    # --- 2. Downsample if too large (to keep Plotly light) ---
    max_dim = max(h, w)
    if max_dim > max_size:
        scale = max_size / max_dim
        # round() keeps the longest side exactly at max_size despite float error;
        # max(1, ...) stops very thin images collapsing to a zero-sized axis.
        new_w = max(1, round(w * scale))
        new_h = max(1, round(h * scale))
        resized = Image.fromarray(img).resize((new_w, new_h), resample=Image.BILINEAR)
        img = np.array(resized)
        h, w = img.shape[:2]

    # --- 3. Build Plotly figure ---
    fig = go.Figure(data=[go.Image(z=img)])

    fig.update_layout(
        title=dict(
            text=f"<b>{title}</b>",
            x=0.5,
            xanchor="center",
            font=dict(size=22, color="white"),
        ),
        # Canvas is the image size plus 100 px, capped at 900 px per side.
        width=min(900, w + 100),
        height=min(900, h + 100),
        margin=dict(l=0, r=0, t=60, b=0),
        paper_bgcolor="black",
        plot_bgcolor="black",
        dragmode="pan",
    )

    fig.update_xaxes(visible=False)
    # Tie the y scale to x so pixels stay square while zooming.
    fig.update_yaxes(visible=False, scaleanchor="x")

    fig.show(config={"scrollZoom": True})
import numpy as np
import plotly.graph_objects as go

def _exterior_rings(geom):
    """Return the exterior ring of every polygon in a Polygon/MultiPolygon geometry."""
    geom_type = geom.get("type")
    coords = geom.get("coordinates") or []
    if geom_type == "Polygon":
        polygons = [coords]
    elif geom_type == "MultiPolygon":
        polygons = coords
    else:
        return []
    # rings[0] is the exterior ring; any further rings are holes and are not drawn.
    return [rings[0] for rings in polygons if rings]


def plot_contrails_plotly_geo_dark(
    geojson_fc,
    title="Contrail Detection",
    subtitle="Interactive polygon view",
    max_points_per_polygon: int = 300,
):
    """
    Build a dark-themed Plotly geo figure outlining contrail polygons.

    Parameters
    ----------
    geojson_fc : dict
        GeoJSON FeatureCollection-like mapping with a ``"features"`` list.
        ``Polygon`` and ``MultiPolygon`` geometries are drawn as outlines of
        their exterior rings; holes are ignored. Features with other geometry
        types, a null geometry, or rings shorter than 4 positions are skipped.
        Positions may carry extra elements (e.g. altitude); only lon/lat are used.
    title : str
        Bold main title.
    subtitle : str
        Smaller grey line rendered under the title.
    max_points_per_polygon : int
        Exterior rings longer than this are evenly subsampled to this many
        vertices; the first and last vertices are always kept.

    Returns
    -------
    plotly.graph_objects.Figure
        One ``Scattergeo`` line trace per drawn ring. Each trace is named after
        the feature's ``properties["id"]``, or ``"contrail"`` when absent.

    Raises
    ------
    ValueError
        If no feature yields a drawable ring.
    """
    background = "rgba(5, 5, 15, 1.0)"

    fig = go.Figure()
    all_lons = []
    all_lats = []

    for feature in geojson_fc["features"]:
        # GeoJSON permits "geometry": null and "properties": null.
        geom = feature.get("geometry") or {}
        props = feature.get("properties") or {}
        feature_id = str(props.get("id", "contrail"))

        for shell in _exterior_rings(geom):
            if len(shell) < 4:
                continue

            if len(shell) > max_points_per_polygon:
                idx = np.linspace(0, len(shell) - 1, max_points_per_polygon, dtype=int)
                shell = [shell[i] for i in idx]

            # Index the positions rather than zip(*shell): positions may be
            # [lon, lat, alt], which would not unpack into two sequences.
            lons = [pt[0] for pt in shell]
            lats = [pt[1] for pt in shell]
            all_lons.extend(lons)
            all_lats.extend(lats)

            fig.add_trace(
                go.Scattergeo(
                    lon=lons,
                    lat=lats,
                    mode="lines",  # outlines only, no fill
                    line=dict(width=1),
                    name=feature_id,
                    hoverinfo="text",
                    text=feature_id,
                )
            )

    if not all_lons or not all_lats:
        raise ValueError("No valid polygon coordinates to plot.")

    min_lon, max_lon = min(all_lons), max(all_lons)
    min_lat, max_lat = min(all_lats), max(all_lats)
    # Pad the view by 5% of the extent, but never by less than one degree.
    pad_lon = max(1.0, 0.05 * (max_lon - min_lon))
    pad_lat = max(1.0, 0.05 * (max_lat - min_lat))

    fig.update_geos(
        projection_type="natural earth",  # or "equirectangular" / "mercator"
        lonaxis_range=[min_lon - pad_lon, max_lon + pad_lon],
        # Latitude is only defined on [-90, 90]; keep the padded range inside it.
        lataxis_range=[max(-90.0, min_lat - pad_lat), min(90.0, max_lat + pad_lat)],
        showcoastlines=True,
        showcountries=True,
        # The geo subplot ignores plot_bgcolor and otherwise uses the template's
        # white background / light land, so colour it explicitly for the dark theme.
        bgcolor=background,
        showland=True,
        landcolor="rgb(25, 25, 35)",
        showocean=True,
        oceancolor="rgb(10, 10, 25)",
        showlakes=True,
        lakecolor="rgb(10, 10, 25)",
        coastlinecolor="rgb(120, 120, 140)",
        countrycolor="rgb(70, 70, 90)",
    )

    fig.update_layout(
        paper_bgcolor=background,
        plot_bgcolor=background,
        # Light default text so the title and legend stay readable on the dark paper.
        font=dict(color="#DDDDDD"),
        title=dict(
            text=(
                f"<b>{title}</b>"
                f"<br><span style='font-size:13px; color:#AAAAAA;'>{subtitle}</span>"
            ),
            x=0.5,
            xanchor="center",
        ),
        height=720,
        margin=dict(l=10, r=10, t=80, b=10),
    )

    return fig
