

from datetime import datetime
import numpy as np
import numpy.ma as ma
from netCDF4 import Dataset, date2num, num2date
import pandas as pd
import xarray as xr

class NetCDFHelper():
    """
    Collection of static helpers for inspecting netCDF datasets, copying
    dimensions between datasets and converting variables to other
    representations.
    """

    @staticmethod
    def get_dimensions(dataset):
        """
        Given a netCDF dataset, find and return the dimension
        names
        """
        return list(dataset.dimensions)

    @staticmethod
    def get_variables(dataset):
        """
        Given a netCDF dataset, find and return all the variable names
        """
        return list(dataset.variables)

    @staticmethod
    def get_parameters(dataset):
        """
        Given a netCDF dataset, find and return the parameter
        (variable) names.
        ERA5 files typically have many parameters (variables). To access a variable, we
        need to know the variable name. Since we download the parameters
        into a single file, we typically have four variables: latitude, longitude, time
        and the parameter value. Latitude, longitude and time are dimensional variables, and
        when the file is queried, their names are returned as both dimensions and variables.
        The remaining entries in the variable list are the parameter variables.

        Returns a list of the variable names that are not also dimension names,
        in the order the dataset reports them.
        """
        nc_dims = set(NetCDFHelper.get_dimensions(dataset))
        return [var for var in NetCDFHelper.get_variables(dataset) if var not in nc_dims]

    @staticmethod
    def get_dimension_order_for_parameters(dataset, param):
        """
        Given a netCDF dataset and a parameter name, get the order of the dimensions
        for the variable. The names are returned in lower case.
        """
        # push everything to lower case just in case
        return [dimension.lower() for dimension in dataset.variables[param].dimensions]

    @staticmethod
    def get_param_nodata_value(dset, param):
        """
        Given a netCDF dataset and a parameter name, return the fill value
        reported by a numpy masked array built from the variable.

        NOTE(review): building the masked array reads the variable's data, and
        it is not obvious from this code whether the result reflects the file's
        own fill value attribute or numpy's default for the dtype - confirm
        before relying on it.
        """
        var = ma.array(dset.variables[param])
        return var.get_fill_value()

    @staticmethod
    def get_variable(dataset, name):
        """
        Given a netCDF dataset and a variable name return, in this order:
            variable attributes (dict of attribute name -> value)
            variable values (numpy array)
            variable type
            variable dimensions

        Raises KeyError if the dataset has no variable called `name`.
        """
        # look the variable up once instead of once per property
        variable = dataset.variables[name]
        var_attributes = {
            attribute: variable.getncattr(attribute)
            for attribute in variable.ncattrs()
        }
        var_values = np.array(variable)
        return var_attributes, var_values, variable.datatype, variable.dimensions

    @staticmethod
    def copy_dimension(dimensionName, sourceDataSet, destinationDataSet):
        """
        Copy a dimension, together with the variable of the same name (values
        and attributes), from the source dataset to the destination dataset.

        The destination dimension is created with the size the source
        dimension currently reports. The source dataset must contain a
        variable named like the dimension, otherwise a KeyError is raised.

        NOTE(review): an unlimited source dimension would presumably be
        created with a fixed size in the destination - confirm this is intended.
        """
        dimensionSource = sourceDataSet.dimensions[dimensionName]
        destinationDataSet.createDimension(dimensionName, dimensionSource.size)
        srcAttributes, srcValue, srcDataType, _ = NetCDFHelper.get_variable(sourceDataSet, dimensionName)
        # create the destination variable that holds the dimension values
        var_dest = destinationDataSet.createVariable(dimensionName, srcDataType, (dimensionName,))

        var_dest[:] = srcValue[:]
        NetCDFHelper.update_var_attributes(var_dest, srcAttributes)

    @staticmethod
    def update_var_attributes(var, attributes):
        """
        Set every key/value pair of the `attributes` dictionary as an
        attribute on the netCDF variable `var`.
        """
        for k, v in attributes.items():
            var.setncattr(k, v)

    @staticmethod
    def time_number_to_datetime(dataset, time_dimension_name='time'):
        """
        Convert the numeric values of the time variable into date objects,
        using the variable's 'units' attribute.

        If the variable has no 'calendar' attribute, the 'standard' calendar
        is used. Raises KeyError if the 'units' attribute is missing.
        """
        t_attributes, t_val, _, _ = NetCDFHelper.get_variable(dataset, time_dimension_name)
        # the calendar attribute is optional in a file, so do not fail without it
        calendar = t_attributes.get('calendar', 'standard')
        return num2date(t_val[:], t_attributes['units'], calendar=calendar)

    @staticmethod
    def get_var_xarray(dataset, name):
        """
        Given a netCDF dataset and a variable name, return the variable as an
        xarray DataArray.

        Dimensions whose name starts with 'lon' or 'lat' (case-insensitive)
        are renamed to 'longitude' and 'latitude'; other dimension names are
        kept. The coordinate values of a dimension are read from the dataset
        variable of the same name; a dimension without such a variable gets a
        positional index (0..size-1) instead.
        """
        _, var_values, _, var_dimensions = NetCDFHelper.get_variable(dataset, name)
        coord_list = []
        dim_list = []
        for axis, dim in enumerate(var_dimensions):
            lowered = dim.lower()
            if lowered.startswith('lon'):
                dim_list.append('longitude')
            elif lowered.startswith('lat'):
                dim_list.append('latitude')
            else:
                dim_list.append(dim)
            if dim in dataset.variables:
                coord_list.append(np.array(dataset.variables[dim]))
            else:
                # no coordinate variable for this dimension: fall back to an index
                coord_list.append(np.arange(var_values.shape[axis]))
        return xr.DataArray(var_values, coords=coord_list, dims=dim_list)

    @staticmethod
    def change_attributes(attributes, **kwargs):
        """
        Given a dictionary of attributes, update the attributes.
        Each keyword name represents the attribute key, the keyword value is
        the replacement value.

        The dictionary is updated in place and also returned. Only existing
        keys can be changed: if any keyword is not a key of `attributes`, a
        ValueError naming the missing keys is raised and the dictionary is
        left unmodified.
        """
        # assume we cannot change a key that does not exist; check all keys
        # before touching the dictionary so a failure leaves it unchanged
        missing_keys = [key for key in kwargs if key not in attributes]
        if missing_keys:
            mkeys = ', '.join(missing_keys)
            raise ValueError('changeAttributes: missing keys {}'.format(mkeys))
        attributes.update(kwargs)
        return attributes




