import sys, os
import numpy as np
import pandas as pd
import xarray as xr

from enum import Enum

from abc import ABC, abstractmethod

class calcType(Enum):
    """
    Names of the raster calculations understood by RasterCalculator.calc_factory.

    Every member's value equals its name, and ``str(member)`` returns that value.
    """
    average = 'average'
    total = 'total'
    kelvin2c = 'kelvin2c'
    f2c = 'f2c'
    meter2mm = 'meter2mm'
    firstin = 'firstin'
    lastin = 'lastin'

    def __str__(self):
        return self.value

    @staticmethod
    def from_string(s):
        """
        Return the calcType member named ``s``.

        :param s: a member name such as ``'average'``; an existing calcType
            member is returned unchanged.
        :raises ValueError: if ``s`` does not name a member.
        """
        if isinstance(s, calcType):
            return s
        try:
            return calcType[s]
        except KeyError:
            # f-string rather than '+' so non-string input still produces a
            # ValueError instead of a TypeError from the concatenation.
            raise ValueError(f'Invalid calcType: {s}') from None

    @staticmethod
    def all(exclude=()):
        """
        Return the names of all members, in definition order.

        :param exclude: member names (or members) to leave out of the result.
        :returns: list of member-name strings.
        """
        # Iterate the enum itself; the class __dict__ also holds Enum internals
        # and the helper methods, which are not calculation types.
        return [member.name for member in calcType
                if member.name not in exclude and member not in exclude]

class RasterCalculator(ABC):
    """
    RasterCalculator is the abstract base class for performing raster calculations.

    RasterCalculators perform different operations on rasters, such as averaging,
    totals, and unit conversion. In many cases the values of a raster must be
    clipped to a range (for example, rain data is never negative). Subclasses
    support this through optional ``clip_min`` / ``clip_max`` arguments; when
    both are ``None`` no clipping is performed.

    Subclasses must override both abstract methods:
        calculate(*args)
            operate on the array(s) passed positionally.
        calculate_index(values, indices, clip_min, clip_max)
            operate on a stacked array, with ``indices`` identifying the layers
            of interest (element-wise calculators may ignore it).

    Use ``RasterCalculator.calc_factory`` to obtain a calculator by type.
    """

    @staticmethod
    def calc_factory(calc_type: 'calcType'):
        """
        Create the calculator that corresponds to ``calc_type``.

        :param calc_type: a calcType member, or the name of one (e.g. ``'average'``).
        :returns: a new RasterCalculator subclass instance, or ``None`` if the
            type has no calculator registered below.
        :raises ValueError: if ``calc_type`` is a string naming no calcType member.
        """
        # Parse once. Enum members are accepted directly instead of relying on
        # calcType.from_string to handle them.
        if isinstance(calc_type, calcType):
            kind = calc_type
        else:
            kind = calcType.from_string(calc_type)
        # Built at call time because the subclasses are defined after this class.
        factories = {
            calcType.average: RasterAggregateAverage,
            calcType.total: RasterAggregateTotals,
            calcType.kelvin2c: RasterKelvinToC,
            calcType.f2c: RasterFToC,
            calcType.meter2mm: RasterMetersTomm,
            calcType.firstin: RasterFirstIn,
            calcType.lastin: RasterLastIn,
        }
        factory = factories.get(kind)
        return factory() if factory is not None else None

    @abstractmethod
    def calculate_index(self, values, indices, clip_min, clip_max):
        """Apply the calculation to ``values`` using the layers named by ``indices``."""
        pass

    @abstractmethod
    def calculate(self, *args):
        """Apply the calculation to the array(s) given positionally."""
        pass

class RasterAggregateTotals(RasterCalculator):
    """
    RasterAggregateTotals aggregates multiple arrays by summing them.

    Usage:
        calculate(array1, array2, ...)
            returns a new array that is the element-wise sum of all the supplied
            arrays, or None when no arrays are given.
        calculate_index(array, indices)
            returns a new array, one rank lower than ``array``, holding the sum
            of the layers ``array[i, :, :]`` for every ``i`` in ``indices``, or
            None when ``indices`` is empty.

    Optional clip_min/clip_max clip each input layer before it is summed.
    Input arrays are never modified; the result dtype follows NumPy promotion.
    """
    def __init__(self):
        super().__init__()

    def calculate(self, *Arrays, clip_min=None, clip_max=None):
        if not Arrays:
            return None
        clipping = clip_min is not None or clip_max is not None
        # A fresh zero array as the accumulator: the result never aliases an input.
        results = np.zeros_like(Arrays[0])
        for Array in Arrays:
            layer = np.clip(Array, a_min=clip_min, a_max=clip_max) if clipping else Array
            # Out-of-place addition lets NumPy promote the dtype (e.g. an integer
            # raster clipped to float bounds) instead of failing an in-place cast.
            results = results + layer
        return results

    def calculate_index(self, Array, indices, clip_min=None, clip_max=None):
        # Array[index, :, :] is a view into the caller's data. Accumulating into a
        # separate array (via calculate) guarantees the input is never written to.
        return self.calculate(*(Array[index, :, :] for index in indices),
                              clip_min=clip_min, clip_max=clip_max)


class RasterAggregateAverage(RasterCalculator):
    """
    Given a set of rasters, return a raster holding their element-wise mean.

    Methods:
        calculate(Array1, Array2, ...) -> returns (Array1 + Array2 + ...)/N, where
            N is the number of arrays passed in, or None when none are given.
        calculate_index(Array, indices) -> returns the mean of the layers
            ``Array[i, :, :]`` for every ``i`` in ``indices`` (one rank lower
            than ``Array``), or None when ``indices`` is empty.

    Optional clip_min/clip_max clip each input layer before it is averaged.
    Input arrays are never modified, and integer rasters yield a floating-point
    mean.
    """
    def __init__(self):
        super().__init__()

    def calculate(self, *Arrays, clip_min=None, clip_max=None):
        if not Arrays:
            return None
        clipping = clip_min is not None or clip_max is not None
        # A fresh zero array as the accumulator: the result never aliases an input.
        results = np.zeros_like(Arrays[0])
        for Array in Arrays:
            layer = np.clip(Array, a_min=clip_min, a_max=clip_max) if clipping else Array
            # Out-of-place so NumPy can promote the dtype when layers differ.
            results = results + layer
        # Out-of-place division: an in-place '/=' raises for integer rasters.
        return results / len(Arrays)

    def calculate_index(self, Array, indices, clip_min=None, clip_max=None):
        # Average the selected layers. calculate() accumulates into its own array,
        # so the (view) slices of the caller's Array are never written to.
        return self.calculate(*(Array[index, :, :] for index in indices),
                              clip_min=clip_min, clip_max=clip_max)


class RasterKelvinToC(RasterCalculator):
    """
    Convert a raster of temperatures in kelvin to degrees Celsius.

    Methods:
        calculate -> converts one raster, optionally clipping the result.
        calculate_index -> converts the whole ``Array``; ``indices`` is accepted
                           only for interface compatibility and is not used.
    """
    def __init__(self):
        super().__init__()

    def calculate(self, Array, clip_min=None, clip_max=None):
        celsius = Array - 273.15
        if clip_min is None and clip_max is None:
            return celsius
        return np.clip(celsius, a_min=clip_min, a_max=clip_max)

    def calculate_index(self, Array, indices, clip_min=None, clip_max=None):
        # Element-wise conversion of every layer; the same rule as calculate().
        return self.calculate(Array, clip_min=clip_min, clip_max=clip_max)

class RasterMetersTomm(RasterCalculator):
    """
    Convert a raster of lengths in metres to millimetres (mm).

    Methods:
        calculate -> converts one raster, optionally clipping the result.
        calculate_index -> converts the whole ``Array``; ``indices`` is accepted
                           only for interface compatibility and is not used.
    """
    def __init__(self):
        super().__init__()

    def calculate(self, Array, clip_min=None, clip_max=None):
        millimetres = Array * 1000.00
        if clip_min is None and clip_max is None:
            return millimetres
        return np.clip(millimetres, a_min=clip_min, a_max=clip_max)

    def calculate_index(self, Array, indices, clip_min=None, clip_max=None):
        # Element-wise conversion of every layer; the same rule as calculate().
        return self.calculate(Array, clip_min=clip_min, clip_max=clip_max)

class RasterFToC(RasterCalculator):
    """
    Convert a raster of temperatures in degrees Fahrenheit to degrees Celsius.

    Methods:
        calculate -> converts one raster, optionally clipping the result.
        calculate_index -> converts the whole ``Array``; ``indices`` is accepted
                           only for interface compatibility and is not used.
    """
    def __init__(self):
        super().__init__()
        # Scale applied after removing the 32-degree offset: C = (F - 32) * 5/9.
        self.factor = 5.0 / 9.0

    def calculate(self, Array, clip_min=None, clip_max=None):
        celsius = (Array - 32.0) * self.factor
        if clip_min is None and clip_max is None:
            return celsius
        return np.clip(celsius, a_min=clip_min, a_max=clip_max)

    def calculate_index(self, Array, indices, clip_min=None, clip_max=None):
        # Element-wise conversion of every layer; the same rule as calculate().
        return self.calculate(Array, clip_min=clip_min, clip_max=clip_max)

class RasterFirstIn(RasterCalculator):
    """
    Given a time series of rasters, return the raster for the first time period.

    Methods:
        calculate(array1, array2, ...) -> returns the first array passed in
            (optionally clipped), or None when no arrays are given.
        calculate_index(array, indices) -> returns ``array[indices[0], :, :]``,
            the first layer selected by ``indices`` (optionally clipped). Falls
            back to layer 0 when ``indices`` is None or empty.
    """
    def __init__(self):
        super().__init__()

    def calculate(self, *Arrays, clip_min=None, clip_max=None):
        if not Arrays:
            return None
        if clip_min is None and clip_max is None:
            return Arrays[0]
        return np.clip(Arrays[0], a_min=clip_min, a_max=clip_max)

    def calculate_index(self, Array, indices, clip_min=None, clip_max=None):
        # Honour indices the same way the aggregate calculators do (they read
        # Array[index, :, :] for each index), so the result is the first layer of
        # the selected period rather than always layer 0 of the whole stack.
        position = 0 if indices is None else next(iter(indices), 0)
        layer = Array[position, :, :]
        if clip_min is None and clip_max is None:
            return layer
        return np.clip(layer, a_min=clip_min, a_max=clip_max)

class RasterLastIn(RasterCalculator):
    """
    Given a time series of rasters, return the raster for the last time period.

    Methods:
        calculate(array1, array2, ...) -> returns the last array passed in
            (optionally clipped), or None when no arrays are given.
        calculate_index(array, indices) -> returns ``array[indices[-1], :, :]``,
            the last layer selected by ``indices`` (optionally clipped). Falls
            back to layer -1 when ``indices`` is None or empty.
    """
    def __init__(self):
        super().__init__()

    def calculate(self, *Arrays, clip_min=None, clip_max=None):
        if not Arrays:
            return None
        if clip_min is None and clip_max is None:
            return Arrays[-1]
        return np.clip(Arrays[-1], a_min=clip_min, a_max=clip_max)

    def calculate_index(self, Array, indices, clip_min=None, clip_max=None):
        # Honour indices the same way the aggregate calculators do (they read
        # Array[index, :, :] for each index), so the result is the last layer of
        # the selected period rather than always the final layer of the stack.
        position = -1
        if indices is not None:
            selected = list(indices)
            if selected:
                position = selected[-1]
        layer = Array[position, :, :]
        if clip_min is None and clip_max is None:
            return layer
        return np.clip(layer, a_min=clip_min, a_max=clip_max)


