# src/samudra_ai/utils.py

import json
import numpy as np
import xarray as xr
import datetime
from sklearn.metrics import mean_absolute_error

def standardize_dims(ds: xr.DataArray) -> xr.DataArray:
    """Rename latitude/longitude dimensions to ``lat`` and ``lon``.

    Recognised aliases, in priority order, are ``latitude``/``y``/``j`` for
    ``lat`` and ``longitude``/``x``/``i`` for ``lon``. Only the first alias
    found is renamed, and nothing is renamed when the target dimension
    already exists, so the rename can never create a duplicate dimension.
    When ``time``, ``lat`` and ``lon`` are all present, the result is
    reordered to ``(time, ..., lat, lon)``: any other dimensions are kept
    between ``time`` and the horizontal ones.

    Args:
        ds: Input array. It is not modified.

    Returns:
        A new DataArray named ``ds.name``, or ``"variable"`` if unnamed.
    """
    var_name = ds.name or "variable"
    # Pass the name explicitly instead of assigning ``ds.name`` so the
    # caller's DataArray is left untouched.
    ds_temp = ds.to_dataset(name=var_name)

    aliases = {
        "lat": ("latitude", "y", "j"),
        "lon": ("longitude", "x", "i"),
    }
    rename_dict = {}
    for target, candidates in aliases.items():
        if target in ds_temp.dims:
            # Already standard; renaming an alias too would collide with it.
            continue
        found = next((name for name in candidates if name in ds_temp.dims), None)
        if found is not None:
            rename_dict[found] = target

    da = ds_temp.rename(rename_dict)[var_name]
    # Ensure (time, lat, lon) ordering when all three exist. The Ellipsis keeps
    # extra dims (e.g. a vertical level) instead of raising ValueError.
    if all(dim in da.dims for dim in ("time", "lat", "lon")):
        da = da.transpose("time", ..., "lat", "lon")
    return da

def compute_metrics(obs, model):
    """Compute agreement statistics between observed and modelled values.

    Both inputs are flattened and compared element-wise. Positions where
    either side is NaN are discarded first.

    Args:
        obs: Observation values (``xr.DataArray``, ``np.ndarray`` or any
            array-like such as a list).
        model: Model values with the same number of elements as ``obs``.

    Returns:
        Tuple ``(correlation, rmse, bias, mae)``, where bias is
        ``mean(model - obs)``.
        - All four are NaN when no valid pairs remain.
        - Correlation alone is NaN when fewer than two pairs remain or
          either side is constant, since Pearson r is undefined there.
    """
    # When both inputs are labelled arrays (e.g. xr.DataArray) holding the same
    # dimensions in a different order, reorder ``model`` to match ``obs``.
    # Otherwise flattening would silently pair unrelated grid cells.
    obs_dims = getattr(obs, "dims", None)
    model_dims = getattr(model, "dims", None)
    if (
        isinstance(obs_dims, tuple)
        and isinstance(model_dims, tuple)
        and obs_dims != model_dims
        and set(obs_dims) == set(model_dims)
    ):
        model = model.transpose(*obs_dims)

    # np.asarray handles DataArrays (same data as ``.values``), ndarrays and lists.
    obs_vals = np.asarray(obs).ravel()
    model_vals = np.asarray(model).ravel()
    mask = ~np.isnan(obs_vals) & ~np.isnan(model_vals)
    obs_clean, model_clean = obs_vals[mask], model_vals[mask]
    if obs_clean.size == 0:
        return np.nan, np.nan, np.nan, np.nan

    diff = model_clean - obs_clean
    rmse = np.sqrt(np.mean(diff ** 2))
    bias = np.mean(diff)
    mae = np.mean(np.abs(diff))
    # Pearson r needs at least 2 points and non-zero variance on both sides.
    # Guard explicitly rather than letting np.corrcoef emit RuntimeWarnings.
    if obs_clean.size < 2 or np.std(obs_clean) == 0 or np.std(model_clean) == 0:
        correlation = np.nan
    else:
        correlation = np.corrcoef(obs_clean, model_clean)[0, 1]
    return correlation, rmse, bias, mae

def save_to_netcdf(da: xr.DataArray, path: str, attrs: dict = None):
    """Write a DataArray to a NetCDF file with standard metadata and compression.

    The caller's DataArray is not modified. The name and attribute defaults
    are applied to a shallow copy, which shares the data buffer.

    Args:
        da: Array to write. Unnamed arrays are stored as ``"variable"``.
        path: Destination file path.
        attrs: Extra attributes merged over ``da.attrs``. The ``title``,
            ``description`` and ``history`` attributes are filled in only
            when still absent.
    """
    # Shallow copy: own name/attrs, shared data, so the metadata edits below
    # do not leak back into the caller's object.
    out = da.copy(deep=False)
    out.name = da.name or "variable"
    out.attrs.update(attrs or {})
    out.attrs.setdefault("title", "Hasil Koreksi SAMUDRA-AI")
    out.attrs.setdefault("description", "Data hasil koreksi bias menggunakan CNN-BiLSTM")
    out.attrs.setdefault("history", f"Generated by SAMUDRA-AI on {datetime.datetime.now().isoformat()}")

    # Per-variable encoding: zlib compression at level 4 to keep the file
    # small, and -9999.0 as the on-disk _FillValue.
    encoding = {
        out.name: {
            "zlib": True,
            "complevel": 4,
            "_FillValue": -9999.0,
        }
    }
    out.to_netcdf(path, encoding=encoding, engine="h5netcdf")
    print(f"💾 Data berhasil disimpan ke: {path}")

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also understands NumPy scalars and arrays.

    NumPy scalar types (``np.generic`` subclasses) are converted to the
    equivalent Python scalar via ``.item()``. ``np.ndarray`` objects are
    converted to (nested) lists via ``.tolist()``. Any other object is
    delegated to the base encoder, which raises ``TypeError``.

    Usage: ``json.dumps(obj, cls=NumpyEncoder)``.
    """

    def default(self, obj):
        converters = (
            (np.generic, lambda value: value.item()),
            (np.ndarray, lambda value: value.tolist()),
        )
        for numpy_type, convert in converters:
            if isinstance(obj, numpy_type):
                return convert(obj)
        return super().default(obj)