from typing import Any, Annotated as A, TypeVar
import pandas as pd
from pydantic import Field as F

from ..api_query import Query
from ..client import PvradarClient
from pvlib.location import Location
from ...modeling.decorators import datasource, pvradar_resource_type
from ...modeling.utils import auto_attr_table, convert_series_unit, rate_to_cumulative, ureg
from ..pvradar_resources import pvradar_resource_annotations, PvradarResourceType, SeriesConfigAttrs as S
from ...modeling.basics import Attrs as P


# Maps raw MERRA2 column names to the resource metadata attached to each column.
# The mapping is handed to `auto_attr_table` as `series_name_mapping` (see `_auto_attr_table`).
# A value is either a bare resource type name or an Annotated type carrying
# SeriesConfigAttrs (resource_type and unit, optionally agg and freq).
merra2_series_name_mapping: dict[str, PvradarResourceType | A[Any, Any]] = {
    # ----------------------------------------------------
    # MERRA2 - M2I3NVAER, merra2_aerosol_mixing_table
    #
    'AIRDENS': 'air_density',
    'BCPHILIC': 'particle_mixing_ratio',
    'BCPHOBIC': 'particle_mixing_ratio',
    'DU001': 'particle_mixing_ratio',
    'DU002': 'particle_mixing_ratio',
    'DU003': 'particle_mixing_ratio',
    'DU004': 'particle_mixing_ratio',
    'OCPHILIC': 'particle_mixing_ratio',
    'OCPHOBIC': 'particle_mixing_ratio',
    'SO4': 'particle_mixing_ratio',
    'SS001': 'particle_mixing_ratio',
    'SS002': 'particle_mixing_ratio',
    'SS003': 'particle_mixing_ratio',
    'SS004': 'particle_mixing_ratio',
    'RH': A[float, S(resource_type='relative_humidity', unit='dimensionless')],
    # ----------------------------------------------------
    # MERRA2 - M2T1NXFLX, merra2_surface_flux_table
    #
    'PRECSNO': A[pd.Series, S(resource_type='snowfall_mass_rate', unit='kg/m^2/s', agg='mean', freq='h')],
    # NOTE(review): PRECTOT and PRECTOTCORR share the same resource type; only PRECTOTCORR
    # is read by merra2_rainfall_mass in this file.
    'PRECTOT': A[pd.Series, S(resource_type='rainfall_mass_rate', unit='kg/m^2/s', agg='mean', freq='h')],
    'PRECTOTCORR': A[pd.Series, S(resource_type='rainfall_mass_rate', unit='kg/m^2/s', agg='mean', freq='h')],
    # ----------------------------------------------------
    # MERRA2 - M2I1NXASM, merra2_meteo_table
    'T2M': A[float, S(resource_type='air_temperature', unit='degK')],
    # NOTE(review): U2M and V2M are both tagged 'wind_speed' - presumably they are separate
    # wind components rather than a scalar speed; confirm before combining them.
    'U2M': A[float, S(resource_type='wind_speed', unit='m/s')],
    'V2M': A[float, S(resource_type='wind_speed', unit='m/s')],
    # ----------------------------------------------------
    # MERRA2 - M2T1NXLND, merra2_land_surface_table
    'SNODP': A[float, S(resource_type='snow_depth', unit='m')],
}


def _auto_attr_table(df: pd.DataFrame) -> None:
    """Normalize a freshly fetched MERRA2 table in place and tag its columns.

    Does nothing when ``df`` is None. Otherwise the index is shifted onto whole hours
    via ``_remove_minutes_inplace``, column attrs are populated by ``auto_attr_table``
    using the MERRA2 name mapping, and every column gets ``attrs['datasource'] = 'merra2'``.
    """
    if df is not None:
        _remove_minutes_inplace(df)
        auto_attr_table(
            df,
            series_name_mapping=merra2_series_name_mapping,
            resource_annotations=pvradar_resource_annotations,
        )
        # NOTE(review): the tag is written to the Series object returned by df[column];
        # presumably pandas hands back the same cached object later - confirm under copy-on-write.
        for column in df:
            df[column].attrs['datasource'] = 'merra2'


def _remove_minutes_inplace(df: pd.DataFrame) -> None:
    if not len(df):
        return
    sample = df.index[0]
    if sample.minute != 0:
        df.index = df.index - pd.Timedelta(minutes=sample.minute)


# When requesting data from the DB, we extend the right edge of the interval to minute 59
# of its hour as a tolerance, so that records stamped with non-zero minutes are included.
# MERRA2 originally has data at 00:30, 01:30 ...
def _add_minute_tolerance(interval: pd.Interval) -> pd.Interval:
    minute = interval.right.minute
    result = pd.Interval(interval.left, interval.right + pd.Timedelta(minutes=59 - minute))
    return result


# ----------------------------------------------------
# MERRA2 tables


@pvradar_resource_type('merra2_surface_flux_table')
@datasource('merra2')
def merra2_surface_flux_table(
    location: Location,
    interval: pd.Interval,
) -> pd.DataFrame:
    """Data extracted from MERRA2 M2T1NXFLX dataset

    The requested interval is widened with ``_add_minute_tolerance``; the widened
    interval is used both for the query and as ``crop_interval`` of the response.
    The returned table is post-processed in place by ``_auto_attr_table``.
    """
    padded_interval = _add_minute_tolerance(interval)
    request = Query.from_site_environment(location=location, interval=padded_interval)
    request.set_path('datasources/merra2/raw/M2T1NXFLX/csv')
    table = PvradarClient.instance().get_df(request, crop_interval=padded_interval)
    _auto_attr_table(table)
    return table


@pvradar_resource_type('merra2_aerosol_mixing_table')
@datasource('merra2')
def merra2_aerosol_mixing_table(
    location: Location,
    interval: pd.Interval,
) -> pd.DataFrame:
    """
    Data extracted from MERRA2 M2I3NVAER dataset
    details: https://developers.google.com/earth-engine/datasets/catalog/NASA_GSFC_MERRA_aer_nv_2

    The requested interval is widened with ``_add_minute_tolerance`` before querying and
    cropping. The table is post-processed by ``_auto_attr_table`` and then marked with
    ``attrs['freq'] = '3h'``.
    """
    padded_interval = _add_minute_tolerance(interval)
    request = Query.from_site_environment(location=location, interval=padded_interval)
    request.set_path('datasources/merra2/raw/M2I3NVAER/csv')
    table = PvradarClient.instance().get_df(request, crop_interval=padded_interval)
    _auto_attr_table(table)
    # table-level marker; _merra2_3h_to_1h overwrites it with '1h' on the objects it resamples
    table.attrs['freq'] = '3h'
    return table


@pvradar_resource_type('merra2_meteo_table')
@datasource('merra2')
def merra2_meteo_table(
    location: Location,
    interval: pd.Interval,
) -> pd.DataFrame:
    """
    Data extracted from M2I1NXASM dataset

    The requested interval is widened with ``_add_minute_tolerance`` before querying and
    cropping; the resulting table is post-processed in place by ``_auto_attr_table``.
    """
    padded_interval = _add_minute_tolerance(interval)
    request = Query.from_site_environment(location=location, interval=padded_interval)
    request.set_path('datasources/merra2/raw/M2I1NXASM/csv')
    table = PvradarClient.instance().get_df(request, crop_interval=padded_interval)
    _auto_attr_table(table)
    return table


@pvradar_resource_type('merra2_land_surface_table')
@datasource('merra2')
def merra2_land_surface_table(
    location: Location,
    interval: pd.Interval,
) -> pd.DataFrame:
    """
    Data extracted from M2T1NXLND dataset

    The requested interval is widened with ``_add_minute_tolerance`` before querying and
    cropping; the resulting table is post-processed in place by ``_auto_attr_table``.
    """
    padded_interval = _add_minute_tolerance(interval)
    request = Query.from_site_environment(location=location, interval=padded_interval)
    request.set_path('datasources/merra2/raw/M2T1NXLND/csv')
    table = PvradarClient.instance().get_df(request, crop_interval=padded_interval)
    _auto_attr_table(table)
    return table


# ----------------------------------------------------
# MERRA2 series (alphabetical order)

SeriesOrDf = TypeVar('T', pd.DataFrame, pd.Series)  # type: ignore


def _merra2_3h_to_1h(df: SeriesOrDf, interval: pd.Interval) -> SeriesOrDf:
    start_datetime = interval.left
    end_datetime = interval.right
    assert isinstance(start_datetime, pd.Timestamp)
    new_index = pd.date_range(start=start_datetime, end=end_datetime, freq='1h')
    df = df.reindex(new_index).interpolate().bfill()
    df.attrs['freq'] = '1h'
    return df


@pvradar_resource_type('air_density', rename=True)
@datasource('merra2')
def merra2_air_density(
    *,
    merra2_aerosol_mixing_table: A[pd.DataFrame, P(resource_type='merra2_aerosol_mixing_table')],
) -> pd.Series:
    """Return the ``AIRDENS`` column of the MERRA2 aerosol mixing table (no copy is made here)."""
    air_density_column = merra2_aerosol_mixing_table['AIRDENS']
    return air_density_column


@pvradar_resource_type('air_temperature', rename=True)
@datasource('merra2')
def merra2_air_temperature(
    *,
    merra2_meteo_table: A[pd.DataFrame, P(resource_type='merra2_meteo_table')],
) -> pd.Series:
    """Return the ``T2M`` column of the MERRA2 meteo table converted to ``degC``."""
    temperature_column = merra2_meteo_table['T2M']
    return convert_series_unit(temperature_column, to_unit='degC')


@pvradar_resource_type('particle_mixing_ratio')
@datasource('merra2')
def merra2_particle_mixing_ratio(
    *,
    merra2_aerosol_mixing_table: A[pd.DataFrame, P(resource_type='merra2_aerosol_mixing_table')],
    particle_name: str,
) -> pd.Series:
    """Return the aerosol mixing table column named ``particle_name``.

    Raises:
        ValueError: if the table has no column called ``particle_name``.
    """
    if particle_name in merra2_aerosol_mixing_table:
        return merra2_aerosol_mixing_table[particle_name]
    raise ValueError(f'Particle {particle_name} not found in aerosol mixing table')


@pvradar_resource_type('particle_volume_concentration')
@datasource('merra2')
def merra2_particle_volume_concentration(
    *,
    merra2_aerosol_mixing_table: A[pd.DataFrame, P(resource_type='merra2_aerosol_mixing_table')],
    interval: pd.Interval,
    particle_name: str,
) -> pd.Series:
    """Volume concentration of one particle species, on an hourly grid.

    The species' mixing ratio column is multiplied by the air density series, the product
    is resampled to 1h over ``interval`` and tagged with unit ``kg/m^3``.

    Raises:
        ValueError: if the table has no column called ``particle_name``.
    """
    table = merra2_aerosol_mixing_table
    if particle_name not in table:
        raise ValueError(f'Particle {particle_name} not found in aerosol mixing table')
    ratio = table[particle_name]
    density = merra2_air_density(merra2_aerosol_mixing_table=table)
    concentration = _merra2_3h_to_1h(ratio * density, interval)
    concentration.attrs.update(unit='kg/m^3', resource_type='particle_volume_concentration')
    return concentration


@pvradar_resource_type('pm10_volume_concentration', rename='pm10')
@datasource('merra2')
def merra2_pm10_volume_concentration(
    *,
    merra2_aerosol_mixing_table: A[pd.DataFrame, P(resource_type='merra2_aerosol_mixing_table')],
    interval: pd.Interval,
) -> pd.Series:
    """PM10 volume concentration derived from the MERRA2 aerosol mixing table.

    Sums the species mixing ratios (SO4 weighted by 1.375, DU004 by 0.74, all others by 1),
    multiplies by ``AIRDENS``, resamples to 1h over ``interval`` and tags the result
    with unit ``kg/m^3``.
    """
    table = merra2_aerosol_mixing_table
    # terms are accumulated in a fixed order: SO4, carbon species, DU001-DU003, DU004, sea salt
    mixing_sum = 1.375 * table['SO4']
    for species in ('BCPHOBIC', 'BCPHILIC', 'OCPHOBIC', 'OCPHILIC', 'DU001', 'DU002', 'DU003'):
        mixing_sum = mixing_sum + table[species]
    mixing_sum = mixing_sum + 0.74 * table['DU004']
    for species in ('SS001', 'SS002', 'SS003', 'SS004'):
        mixing_sum = mixing_sum + table[species]
    concentration = _merra2_3h_to_1h(mixing_sum * table['AIRDENS'], interval)
    concentration.attrs.update(resource_type='pm10_volume_concentration', unit='kg/m^3')
    return concentration


@pvradar_resource_type('pm2_5_volume_concentration', rename='pm2_5')
@datasource('merra2')
def merra2_pm2_5_volume_concentration(
    *,
    merra2_aerosol_mixing_table: A[pd.DataFrame, P(resource_type='merra2_aerosol_mixing_table')],
    interval: pd.Interval,
) -> pd.Series:
    """PM2.5 volume concentration derived from the MERRA2 aerosol mixing table.

    Sums the species mixing ratios (SO4 weighted by 1.375, DU003 by 0.58, the rest by 1;
    DU004, SS003 and SS004 are not included), multiplies by ``AIRDENS``, resamples to 1h
    over ``interval`` and tags the result with unit ``kg/m^3``.
    """
    table = merra2_aerosol_mixing_table
    # terms are accumulated in a fixed order: SO4, carbon species, DU001-DU002, DU003, sea salt
    mixing_sum = 1.375 * table['SO4']
    for species in ('BCPHOBIC', 'BCPHILIC', 'OCPHOBIC', 'OCPHILIC', 'DU001', 'DU002'):
        mixing_sum = mixing_sum + table[species]
    mixing_sum = mixing_sum + 0.58 * table['DU003']
    for species in ('SS001', 'SS002'):
        mixing_sum = mixing_sum + table[species]
    concentration = _merra2_3h_to_1h(mixing_sum * table['AIRDENS'], interval)
    concentration.attrs.update(resource_type='pm2_5_volume_concentration', unit='kg/m^3')
    return concentration


@pvradar_resource_type('rainfall_mass_rate', use_std_unit=True)
@datasource('merra2')
def merra2_rainfall_mass(
    merra2_surface_flux_table: A[pd.DataFrame, P(resource_type='merra2_surface_flux_table')],
) -> pd.Series:
    """Return a copy of the ``PRECTOTCORR`` column of the MERRA2 surface flux table."""
    corrected_precipitation = merra2_surface_flux_table['PRECTOTCORR']
    return corrected_precipitation.copy()


@pvradar_resource_type('rainfall_rate', use_std_unit=True, rename=True)
@datasource('merra2')
def merra2_rainfall_rate(
    rainfall_mass_rate: A[pd.Series, P(resource_type='rainfall_mass_rate')],
) -> pd.Series:
    """Re-label a rainfall mass rate series as a rainfall (depth) rate.

    The values are copied unchanged; only ``attrs['unit']`` and ``attrs['resource_type']``
    are rewritten on the copy.
    """
    rate = rainfall_mass_rate.copy()

    # 1 kg/m^2 of water corresponds to a 1 mm layer, so the numbers stay the same
    # and only the unit is swapped from mass-per-area to millimeters
    mass_rate_unit = ureg(rainfall_mass_rate.attrs['unit'])
    unit_object = mass_rate_unit / ureg('kg/m^2') * ureg('mm')
    rate.attrs.update(unit=str(unit_object), resource_type='rainfall_rate')
    return rate


@pvradar_resource_type('rainfall', rename=True)
@datasource('merra2')
def merra2_rainfall(
    rainfall_rate: A[pd.Series, P(resource_type='rainfall_rate')],
) -> pd.Series:
    """Convert the rainfall rate series with ``rate_to_cumulative`` into a ``rainfall`` resource."""
    return rate_to_cumulative(rainfall_rate, resource_type='rainfall')


@pvradar_resource_type('relative_humidity', rename=True, use_std_unit=True)
@datasource('merra2')
def merra2_relative_humidity(
    *,
    merra2_aerosol_mixing_table: A[pd.DataFrame, F(), P(resource_type='merra2_aerosol_mixing_table')],
    interval: pd.Interval,
) -> pd.Series:
    """Return the ``RH`` column of the aerosol mixing table resampled to 1h over ``interval``."""
    humidity_3h = merra2_aerosol_mixing_table['RH']
    return _merra2_3h_to_1h(humidity_3h, interval)


@pvradar_resource_type('snow_depth', use_std_unit=True, rename=True)
@datasource('merra2')
def merra2_snow_depth(
    *,
    merra2_land_surface_table: A[pd.DataFrame, F(), P(resource_type='merra2_land_surface_table')],
) -> pd.Series:
    """Return a copy of the ``SNODP`` column of the MERRA2 land surface table."""
    snow_depth_column = merra2_land_surface_table['SNODP']
    return snow_depth_column.copy()


@pvradar_resource_type('snowfall_mass_rate', use_std_unit=True, rename=True)
@datasource('merra2')
def merra2_snowfall_mass_rate(
    *,
    merra2_surface_flux_table: A[pd.DataFrame, F(), P(resource_type='merra2_surface_flux_table')],
) -> pd.Series:
    """Return a copy of the ``PRECSNO`` column of the MERRA2 surface flux table."""
    snowfall_column = merra2_surface_flux_table['PRECSNO']
    return snowfall_column.copy()


@pvradar_resource_type('snowfall_rate', use_std_unit=True, rename=True)
@datasource('merra2')
def merra2_snowfall_rate(
    *,
    snowfall_mass_rate: A[pd.Series, F(), P(resource_type='snowfall_mass_rate')],
    snow_density_value: A[float, F()] = 100,
) -> pd.Series:
    """Turn a snowfall mass rate into a snowfall (depth) rate.

    The mass rate values are divided by ``snow_density_value`` (default 100). The result's
    unit is the input unit divided by ``kg/m^3``, so the divisor is treated as a density
    expressed in kg/m^3.
    """
    depth_rate = snowfall_mass_rate / snow_density_value
    mass_rate_unit = ureg(snowfall_mass_rate.attrs['unit'])
    unit_object = mass_rate_unit / ureg('kg/m^3')
    depth_rate.attrs.update(unit=str(unit_object), resource_type='snowfall_rate')
    return depth_rate


@pvradar_resource_type('snowfall', use_std_unit=True, rename=True)
@datasource('merra2')
def merra2_snowfall(
    *,
    snowfall_rate: A[pd.Series, F(), P(resource_type='snowfall_rate')],
) -> pd.Series:
    """Convert the snowfall rate series with ``rate_to_cumulative`` into a ``snowfall`` resource."""
    return rate_to_cumulative(snowfall_rate, resource_type='snowfall')
