import ntpath
import os
import time
import uuid

import numpy as np
import pandas as pd
import pathlib2
from osgeo import gdal, gdal_array, osr

from dselib.dse_gdal.dot_name import get_n_admin_levels

class ClippingException(Exception):
    """Raised when GDAL fails to clip (translate) a raster to a requested window."""


class RasterizationException(Exception):
    """Raised when GDAL fails to burn (rasterize) a shape onto a raster band."""


class NoDataException(Exception):
    """Raised when a raster band that must have a nodata value has none set."""

# Defaults applied when the corresponding RasterHelper arguments are omitted
DEFAULT_BURN_VALUE: int = 0                 # value burned by burn_shape_to_raster when none is given
DEFAULT_OUTPUT_DIRECTORY: str = '.'         # where create_unassigned_pixel_raster writes its maps
DEFAULT_NAME_FIELD: str = 'dot_name'        # feature item holding the shape name
DEFAULT_DATA_FIELD: str = 'population'      # result column for aggregated pixel data

# Byte pixel values of the in-shape / out-of-shape / nodata map (create_unassigned_pixel_raster)
NEW_NODATA_VALUE: int = 0
IN_SHAPE_VALUE: int = 255
OUT_OF_SHAPE_VALUE: int = 127

# RGB colors assigned to the map values above in the map's color table
BLACK: tuple = (0, 0, 0)
WHITE: tuple = (255, 255, 255)
RED: tuple = (255, 0, 0)

class RasterHelper:
    """
    Stateless GDAL/OGR helpers for clipping rasters with shapes, burning shapes onto rasters, aggregating pixel data
    per named shape, and reformatting/resampling rasters. All methods are static.
    """

    @staticmethod
    def write_to_png(raster, out_filename):
        """
        Creates a png-format gdal dataset and writes it to the specified filename
        :param raster: a raster to convert (e.g. a tif raster)
        :param out_filename: the png filename to write
        :return: nothing
        """
        result = gdal.Translate(out_filename, raster, format='PNG')
        # dropping the reference closes the PNG dataset, flushing it to disk
        del result

    @staticmethod
    def _release_vsimem(vsimem_name):
        """
        Deletes a /vsimem/ (in-memory) file if it exists. Paths outside /vsimem/ are never touched.
        :param vsimem_name: the /vsimem/ path to remove
        :return: nothing
        """
        # VSIStatL returns None for missing files, so Unlink is only attempted on existing in-memory files
        if vsimem_name and vsimem_name.startswith('/vsimem/') and gdal.VSIStatL(vsimem_name) is not None:
            gdal.Unlink(vsimem_name)

    @staticmethod
    def _nodata_mask(data, nodata_value):
        """
        Returns a boolean array marking the pixels of data equal to nodata_value.
        A NaN nodata value is matched with np.isnan (NaN never compares equal to itself); a None nodata value marks
        no pixels.
        :param data: a numpy array of pixel values
        :param nodata_value: the band nodata value (or None)
        :return: a boolean numpy array shaped like data
        """
        if nodata_value is None:
            return np.zeros(data.shape, dtype=bool)
        if np.isnan(nodata_value):
            return np.isnan(data)
        return data == nodata_value

    @staticmethod
    def _feature_mask_like(raster, layer, feature_number):
        """
        Returns a boolean array shaped like the bands of raster, True where the given feature (shape) is burned onto a
        blank in-memory raster sharing raster's size, geotransform and projection.
        :param raster: the GDAL raster dataset whose grid to use
        :param layer: the layer object containing the feature
        :param feature_number: the feature (shape) to rasterize
        :return: a boolean numpy array (True = in shape)
        """
        mask_raster = gdal.GetDriverByName('MEM').Create('', raster.RasterXSize, raster.RasterYSize, 1, gdal.GDT_Byte)
        mask_raster.SetGeoTransform(raster.GetGeoTransform())
        mask_raster.SetProjection(raster.GetProjection())
        mask_band = mask_raster.GetRasterBand(1)
        # start from all zeros so that only burned (in-shape) pixels become non-zero
        mask_band.Fill(0)
        mask_band = None
        RasterHelper.burn_shape_to_raster(raster=mask_raster, layer=layer, feature_number=feature_number,
                                          burn_values=[1], band_number=1)
        in_shape = mask_raster.GetRasterBand(1).ReadAsArray() != 0
        mask_raster = None
        return in_shape

    @staticmethod
    def _iter_feature_names(layer, name_field):
        """
        Yields (feature_number, feature_name) for feature numbers 0 .. layer.GetFeatureCount() - 1.
        NOTE(review): assumes feature ids run contiguously from 0; confirm for the layer's driver.
        :param layer: the layer object containing the features
        :param name_field: the feature item holding the feature name
        """
        for feature_number in range(layer.GetFeatureCount()):
            feature = layer.GetFeature(feature_number)
            yield feature_number, feature.items()[name_field]

    @staticmethod
    def clip_raster_with_window(raster, window, format=None):
        """
        Clip a raster with a bounding box (window).
        The clipped raster is backed by a uniquely named /vsimem/ (in-memory) file. When done with the result, drop all
        references to it and call gdal.Unlink(clipped_raster.GetDescription()); otherwise the in-memory file persists.
        :param raster: a GDAL raster dataset to be clipped
        :param window: a bounding box (ulx, uly, lrx, lry) in the raster's georeferenced coordinates
        :param format: result will be returned in this format (default: format of provided raster)
        :return: A clipped raster
        :raises ClippingException: if GDAL fails to produce the clipped raster
        """
        format = format or raster.GetDriver().ShortName

        filename = raster.GetDescription()
        fn, ext = os.path.splitext(os.path.basename(filename))
        # a uuid (rather than a timestamp) guarantees distinct names even for calls within the same clock tick
        vmem_name = '/vsimem/%s%s%s' % (fn, uuid.uuid4().hex, ext)

        clipped_raster = gdal.Translate(vmem_name, raster, projWin=window, format=format)
        if clipped_raster is None:
            # do not leave a partially written in-memory file behind
            RasterHelper._release_vsimem(vmem_name)
            raise ClippingException('Failed to trim raster to window: %s .' % str(window))
        return clipped_raster

    @staticmethod
    def clip_raster_with_shape(raster, layer, feature_number, band_number=1):
        """
        Returns a numpy MaskedArray of raster (specified band) data covering the bounding box of the specified shape.
        Pixels inside the shape hold their raster values and pixels outside it hold 0. Data is masked (invalidated)
        where values are exactly equal to the nodata value of the band (NaN nodata values are matched as NaN).
        :param raster: a GDAL Dataset (from a raster file)
        :param layer: an OGR Layer (from a shapefile)
        :param feature_number: The feature number (shape) in the given OGR Layer (layer) to clip raster with
        :param band_number: The dataset band to clip with the specified shape
        :return: a 2D numpy MaskedArray with the dtype of the raster band
        :raises NoDataException: if the selected band has no nodata value set
        :raises ClippingException: if the raster cannot be clipped to the shape's bounding box
        """
        nodata_value = raster.GetRasterBand(band_number).GetNoDataValue()
        if nodata_value is None:
            raise NoDataException('A nodata value must be set on the selected raster band: %d' % band_number)

        # optimization: only consider pixels within a tight bounding box of the shape,
        # converted to an upper-left / lower-right corner specification
        feature = layer.GetFeature(feature_number)
        min_x, max_x, min_y, max_y = feature.GetGeometryRef().GetEnvelope()
        window = (min_x, max_y, max_x, min_y)

        clipped_raster = RasterHelper.clip_raster_with_window(raster, window=window)
        clipped_name = clipped_raster.GetDescription()
        try:
            # read only the requested band; a dataset-level ReadAsArray would stack every band
            data = clipped_raster.GetRasterBand(band_number).ReadAsArray()
            # Rasterize the shape onto a separate mask instead of burning into the data, so in-shape pixels stay
            # distinguishable from nodata pixels even when the nodata value equals the burn value (e.g. 0).
            in_shape = RasterHelper._feature_mask_like(clipped_raster, layer, feature_number)
        finally:
            clipped_raster = None
            RasterHelper._release_vsimem(clipped_name)

        nodata = RasterHelper._nodata_mask(data, nodata_value)
        # raster values inside the shape (and at nodata pixels, which are masked anyway); zero outside the shape
        values = np.where(in_shape | nodata, data, np.zeros_like(data))
        return np.ma.masked_array(values, mask=nodata)

    @staticmethod
    def burn_shape_to_raster(raster, layer, feature_number, burn_values=None, band_number=1):
        """
        Rasterizes (burns) the specified feature_number from the given layer to band band_number of the provided raster.
        Values burned are as provided or [DEFAULT_BURN_VALUE] by default. Any attribute filter previously set on layer
        is cleared (not restored) afterwards.
        :param raster: a GDAL raster dataset to be burned
        :param layer: the layer object containing a feature (shape) to burn
        :param feature_number: the particular feature (shape) to burn onto the raster band
        :param burn_values: the value(s) to burn. Must be a list of one or more numbers.
        :param band_number: the band_number of the provided raster to burn
        :return: Nothing
        :raises RasterizationException: if gdal.RasterizeLayer reports an error
        """
        burn_values = burn_values or [DEFAULT_BURN_VALUE]
        try:
            # we only want to use one particular feature
            layer.SetAttributeFilter("FID=%d" % feature_number)
            output = gdal.RasterizeLayer(raster, [band_number], layer, burn_values=burn_values)
        finally:
            # prevent pollution of the provided layer object
            layer.SetAttributeFilter('')

        # Check for error; output is 0 on success
        if output != 0:
            raise RasterizationException('Failed to burn values: %s for feature number: %d from provided shape '
                                         'to raster band: %d' % (burn_values, feature_number, band_number))

    @staticmethod
    def burn_shapes_to_raster(raster, layer, feature_numbers, burn_values, band_number=1):
        """
        Rasterizes (burns) the specified feature_numbers from the given layer to band band_number of the provided raster.
        Falsy per-feature burn values fall back to [DEFAULT_BURN_VALUE] (see burn_shape_to_raster).
        :param raster: a GDAL raster dataset to be burned
        :param layer: the layer object containing features (shapes) to burn
        :param feature_numbers: an iterable of feature numbers (shapes) to burn onto the raster band
        :param burn_values: the value(s) to burn. An iterable of 1+ length lists/tuples, one per feature number.
        :param band_number: the band_number of the provided raster to burn
        :return: Nothing
        :raises ValueError: if the number of burn value lists differs from the number of feature numbers
        """
        # materialize once so that iterators/generators are not exhausted by the length check
        feature_numbers = list(feature_numbers)
        burn_values = list(burn_values)

        if len(burn_values) != len(feature_numbers):
            raise ValueError('A burn value list must be provided for each feature number (a list of lists).')

        for feature_number, feature_burn_values in zip(feature_numbers, burn_values):
            RasterHelper.burn_shape_to_raster(raster=raster, layer=layer, feature_number=feature_number,
                                              burn_values=feature_burn_values, band_number=band_number)

    @staticmethod
    def _pixel_map_stats(pixel_map):
        """
        Counts nodata / in-shape / out-of-shape pixels of a pixel map and reports them with their fractions, over all
        pixels and over valid (non-nodata) pixels only. Fractions with an empty denominator are reported as 0.0.
        :param pixel_map: a numpy array holding NEW_NODATA_VALUE, IN_SHAPE_VALUE and OUT_OF_SHAPE_VALUE pixels
        :return: a pixel status stats dict
        """
        n_in_shape = int(np.count_nonzero(pixel_map == IN_SHAPE_VALUE))
        n_out_of_shape = int(np.count_nonzero(pixel_map == OUT_OF_SHAPE_VALUE))
        n_nodata = int(np.count_nonzero(pixel_map == NEW_NODATA_VALUE))
        n_pixels = n_in_shape + n_out_of_shape + n_nodata
        n_pixels_with_data = n_in_shape + n_out_of_shape

        def fraction(count, total):
            # e.g. a raster that is entirely nodata has no valid pixels
            return count / total if total else 0.0

        return {
            'all_pixels': {
                'n_no_data': n_nodata,
                'fraction_no_data': fraction(n_nodata, n_pixels),
                'n_in_shape': n_in_shape,
                'fraction_in_shape': fraction(n_in_shape, n_pixels),
                'n_out_of_shape': n_out_of_shape,
                'fraction_out_of_shape': fraction(n_out_of_shape, n_pixels)
            },
            'valid_pixels': {
                'n_in_shape': n_in_shape,
                'fraction_in_shape': fraction(n_in_shape, n_pixels_with_data),
                'n_out_of_shape': n_out_of_shape,
                'fraction_out_of_shape': fraction(n_out_of_shape, n_pixels_with_data)
            }
        }

    @staticmethod
    def create_unassigned_pixel_raster(raster, layer, admin_level, name_field=None,
                                       band_number=1, output_directory=None):
        """
        Creates a 3-value (in shape, out of shape, no data) clone of the provided raster and reports some basic statistics
        regarding pixel counts in these categories. Values are determined by applying the shapes/features of the selected
        admin level of the provided shape file layer to a blank copy of the input raster.
        The result is written as a single-band Byte GeoTIFF plus a PNG, named admin_level_<level>_map_<timestamp>,
        colored black (no data), white (in shape) and red (out of shape).
        :param raster: a GDAL raster dataset to create a in/out/nodata pixel map for
        :param layer: the layer object containing shapes to utilize
        :param admin_level: features/shapes of this admin_level will be used
        :param name_field: the feature/shape item field to get admin_level info from (default: DEFAULT_NAME_FIELD)
        :param band_number: the band_number of the provided raster to work with
        :param output_directory: the directory where tif and png result rasters will be written
        :return: a pixel status stats dict ('valid_pixels' fractions are 0.0 when no pixel holds data)
        """
        output_directory = output_directory or DEFAULT_OUTPUT_DIRECTORY
        os.makedirs(output_directory, exist_ok=True)
        name_field = name_field or DEFAULT_NAME_FIELD

        # setup output filenames
        processing_time = time.strftime('%Y%m%d-%H%M%S')
        output_tif_name = os.path.join(output_directory, 'admin_level_%d_map_%s.tif' % (admin_level, processing_time))
        output_png_name = os.path.splitext(output_tif_name)[0] + '.png'

        # classify input pixels as nodata/data, then release the (possibly large) input array
        band = raster.GetRasterBand(band_number)
        data = band.ReadAsArray()
        no_data_mask = RasterHelper._nodata_mask(data, band.GetNoDataValue())
        band = None
        del data

        # Create a single-band Byte GeoTIFF on the same grid as the input raster.
        # Byte-type is required for the ColorTable assignment to work/mean anything.
        out_raster = gdal.GetDriverByName('GTiff').Create(output_tif_name, xsize=raster.RasterXSize,
                                                          ysize=raster.RasterYSize, bands=1, eType=gdal.GDT_Byte)
        out_raster.SetGeoTransform(raster.GetGeoTransform())
        out_raster.SetProjection(raster.GetProjection())

        # every pixel starts as out-of-shape (red); in-shape and nodata values are applied on top below
        out_band = out_raster.GetRasterBand(1)
        out_band.WriteArray(np.full(no_data_mask.shape, OUT_OF_SHAPE_VALUE, dtype=np.uint8))
        out_band = None

        # discover which features are at the requested admin level
        feature_numbers = [feature_number
                           for feature_number, feature_name in RasterHelper._iter_feature_names(layer, name_field)
                           if get_n_admin_levels(feature_name) == admin_level + 1]

        # burn shapes of the selected admin level as in-shape (white); the output raster only has band 1
        RasterHelper.burn_shapes_to_raster(raster=out_raster, layer=layer, feature_numbers=feature_numbers,
                                           burn_values=[[IN_SHAPE_VALUE] for _ in feature_numbers], band_number=1)

        # set nodata (black) last so it takes precedence over burned shapes, and copy back to the raster
        out_band = out_raster.GetRasterBand(1)
        pixel_map = out_band.ReadAsArray()
        pixel_map[no_data_mask] = NEW_NODATA_VALUE
        out_band.WriteArray(pixel_map)

        # set the color table for interpreting the existing value scale
        colors = gdal.ColorTable()
        colors.SetColorEntry(NEW_NODATA_VALUE, BLACK)
        colors.SetColorEntry(IN_SHAPE_VALUE, WHITE)
        colors.SetColorEntry(OUT_OF_SHAPE_VALUE, RED)
        out_band.SetRasterColorTable(colors)
        out_band.SetRasterColorInterpretation(gdal.GCI_PaletteIndex)
        out_band = None

        # convert to a more manageable format
        RasterHelper.write_to_png(raster=out_raster, out_filename=output_png_name)
        # dropping the last reference closes the GeoTIFF, flushing it to disk
        out_raster = None

        # pixel_map holds exactly the values written to the output band
        stats = RasterHelper._pixel_map_stats(pixel_map)
        print('Wrote files:\n%s' % '\n'.join([output_tif_name, output_png_name]))
        return stats

    @staticmethod
    def _write_clip(raster, feature_data, feature_name, feature_number):
        """
        Writes a shape-clipped array next to the raster's first file for inspection (a debugging aid), named
        <file root>-clipped-<feature name>-<feature number><ext> (':' in the feature name replaced by '-'), using the
        raster's driver format. An existing file of that name is replaced.
        :param raster: the GDAL raster dataset the data was clipped from
        :param feature_data: the clipped numpy (Masked)Array
        :param feature_name: the name of the clipping feature
        :param feature_number: the number of the clipping feature
        :return: nothing
        """
        # in-memory datasets have no file list; fall back to the dataset description
        raster_filename = (raster.GetFileList() or [raster.GetDescription()])[0]
        file_root, ext = os.path.splitext(raster_filename)
        # sanitize only the feature name, so path components such as Windows drive letters stay intact
        safe_feature_name = str(feature_name).replace(':', '-')
        raster_outfilename = '%s-clipped-%s-%d%s' % (file_root, safe_feature_name, feature_number, ext)
        if os.path.exists(raster_outfilename):
            os.remove(raster_outfilename)
        print('writing file: %s' % raster_outfilename)

        feature_ds = gdal_array.OpenArray(feature_data)
        out_ds = gdal.Translate(raster_outfilename, feature_ds, format=raster.GetDriver().ShortName)
        # dropping the references closes the datasets, flushing the output to disk
        out_ds = None
        feature_ds = None

    @staticmethod
    def extract_admin_level_data(raster, layer, data_operation, name_field=None, data_field=None, band_number=1,
                                 write_clips=False):
        """
        Clips a raster band with shapes (features) in a shape file layer and aggregates the contained pixel data by
        a provided method (e.g. np.ma.sum). Results are reported in a DataFrame, one row per admin dot_name, so if there
        are 2+ shapes with the same dot_name, all pixels for these shapes are combined before aggregation to a single value.
        The data_operation always receives a single flattened (1D) numpy MaskedArray per name.
        :param raster: a GDAL raster dataset to generate admin-aggregated data from
        :param layer: the layer object containing shapes to utilize. Features in this layer must contain items() with
                      a name_field (argument value) entry.
        :param data_operation: The method to use to aggregate pixel data to a single value per name_field value.
        :param name_field: the layer item() that contains feature/shape names (e.g. 'dot_name')
        :param data_field: the column name in the result DataFrame for aggregated pixel data
        :param band_number: the band_number of the provided raster to work with
        :param write_clips: write the individual shape-clipped image files used in aggregation. Useful for debugging.
                            True/False.
        :return: a DataFrame with name_field and data_field columns (defined by method arguments) of aggregated pixel
                 data; names are lower-cased. The columns are present even when the layer has no features.
        """
        name_field = name_field or DEFAULT_NAME_FIELD
        data_field = data_field or DEFAULT_DATA_FIELD

        # first, group features by the naming column so that names can be processed one at a time
        feature_names = {}
        for feature_number, feature_name in RasterHelper._iter_feature_names(layer, name_field):
            feature_names.setdefault(feature_name, []).append(feature_number)

        # now clip the raster data with the provided shapes, aggregating by feature_name and applying the selected
        # data function to the clipped rasters.
        rows = []
        for feature_name, feature_numbers in feature_names.items():
            clips = []
            for feature_number in feature_numbers:
                feature_data = RasterHelper.clip_raster_with_shape(raster, layer, feature_number=feature_number,
                                                                   band_number=band_number)
                if write_clips:
                    RasterHelper._write_clip(raster, feature_data, feature_name, feature_number)
                clips.append(feature_data.ravel())

            # combine the features using the same name as a single, flattened numpy MaskedArray (one concatenation)
            combined_feature_data = np.ma.concatenate(clips)
            rows.append({name_field: feature_name.lower(), data_field: data_operation(combined_feature_data)})

        return pd.DataFrame(rows, columns=[name_field, data_field])

    @staticmethod
    def reformat_raster(raster_in, raster_out, format=None):
        """
        Copies the raster file raster_in to raster_out in the given GDAL format, creating raster_out's directory if
        needed. Does nothing when raster_in does not exist.
        :param raster_in: path of the raster file to convert
        :param raster_out: path of the raster file to write
        :param format: GDAL driver short name of the output (default: the driver of raster_in)
        :return: nothing
        :raises RuntimeError: if raster_in exists but GDAL cannot open it
        """
        if not os.path.exists(raster_in):
            return

        ds = gdal.Open(raster_in)
        if ds is None:
            raise RuntimeError('Failed to open raster: %s' % raster_in)

        os.makedirs(os.path.dirname(os.path.abspath(raster_out)), exist_ok=True)
        file_format = ds.GetDriver().ShortName if format is None else format
        out_ds = gdal.Translate(raster_out, ds, format=file_format)
        # dropping the references closes the datasets, flushing the output to disk
        out_ds = None
        ds = None

    @staticmethod
    def resample_raster(ds, pixelSize, epsgFrom=4326, resampleMethod=gdal.GRA_NearestNeighbour, epsgTo=4326):
        """
        Reprojects (epsgFrom -> epsgTo) and resamples a GDAL dataset to square pixels of side pixelSize (in epsgTo
        units). The caller is expected to know the input projection and choose a suitable pixel size; e.g. for a WGS84
        raster with ~5 km (~2.5 min) pixels at the equator, ~1 km (~30 sec) pixels need a pixelSize about 0.2 times
        the current one. The result is written to an in-memory (MEM) single-band Float32 dataset, which is returned
        and not saved.
        The procedure is:

        1. Set up the two spatial reference systems (using x/y, i.e. lon/lat, axis order).
        2. Read the original dataset's geotransform and size.
        3. Project the upper-left and lower-right corners into the target system.
        4. Compute the number of output pixels for the requested spacing (at least one per axis).
        5. Create the in-memory output raster.
        6. Perform the reprojection/resampling.
        :param ds: the GDAL dataset to resample; when falsy nothing is done and None is returned
        :param pixelSize: output pixel size in target-projection units
        :param epsgFrom: EPSG code of the input projection
        :param resampleMethod: a gdal.GRA_* resampling method (e.g. GRA_NearestNeighbour, GRA_Bilinear, GRA_Cubic,
                               GRA_CubicSpline, GRA_Lanczos)
        :param epsgTo: EPSG code of the output projection
        :return: the resampled MEM dataset, or None when ds is falsy
        """
        # 1. Set up the two spatial reference systems and the transformation between them
        proj_to = osr.SpatialReference()
        proj_to.ImportFromEPSG(epsgTo)
        proj_from = osr.SpatialReference()
        proj_from.ImportFromEPSG(epsgFrom)
        # GDAL >= 3 honors authority axis order (lat/lon for EPSG:4326); force x/y order so the geotransform
        # corners passed to TransformPoint below are interpreted as (x, y).
        if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
            proj_to.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
            proj_from.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        tx = osr.CoordinateTransformation(proj_from, proj_to)

        if not ds:
            return None

        # 2. read the original geotransform and size
        geo_t = ds.GetGeoTransform()
        x_size = ds.RasterXSize
        y_size = ds.RasterYSize

        # 3. calculate boundaries of the new dataset in the target projection
        ulx, uly, _ = tx.TransformPoint(geo_t[0], geo_t[3])
        lrx, lry, _ = tx.TransformPoint(geo_t[0] + geo_t[1] * x_size, geo_t[3] + geo_t[5] * y_size)

        # 4. output size for the new pixel spacing; at least one pixel so tiny extents still yield a valid raster
        dest_x_size = max(1, int((lrx - ulx) / pixelSize))
        dest_y_size = max(1, int((uly - lry) / pixelSize))

        # 5. create the in-memory, single-band Float32 destination raster
        dest = gdal.GetDriverByName('MEM').Create('', dest_x_size, dest_y_size, 1, gdal.GDT_Float32)
        dest.SetGeoTransform((ulx, pixelSize, geo_t[2], uly, geo_t[4], -pixelSize))
        dest.SetProjection(proj_to.ExportToWkt())

        # 6. perform the projection/resampling
        gdal.ReprojectImage(ds, dest, proj_from.ExportToWkt(), proj_to.ExportToWkt(), resampleMethod)
        return dest
