import numpy as np
from whitebox_workflows import Raster, WbEnvironment

class IUH:
    """Per-cell travel-time statistics for an IUH (presumably "instantaneous
    unit hydrograph" - confirm), derived from a set of co-registered rasters.

    All rasters are indexed as ``raster[row, col]`` and are expected to share
    the DEM's grid; the DEM's NoData value decides which cells are processed.
    """

    # Column / row offsets of the downstream neighbour, indexed by log2 of the
    # flow-direction code (code 1 -> index 0, 2 -> 1, ..., 128 -> 7).
    # Presumably the Whitebox D8 pointer encoding - confirm against the tool
    # that produced flow_dir_raster.
    _D8_DX = (1, 1, 1, 0, -1, -1, -1, 0)
    _D8_DY = (-1, 0, 1, 1, 1, 0, -1, -1)

    def __init__(self, dem_raster:Raster, slope_raster:Raster, flow_dir_raster:Raster, flow_acc_raster:Raster, manning_raster:Raster, stream_network_raster:Raster):
        # Elevation raster; supplies grid size, resolution and the NoData mask.
        self.dem_raster = dem_raster
        # Slope raster; values are divided by 100 before use (presumably percent).
        self.slope_raster = slope_raster
        # Flow-direction codes; a value <= 0 ends a downstream walk.
        self.flow_dir_raster = flow_dir_raster
        # Flow accumulation; cells whose integer value is <= 1 start a walk.
        self.flow_acc_raster = flow_acc_raster
        # Manning roughness coefficient per cell.
        self.manning_raster = manning_raster
        # Stream raster; a value >= 1 on a start cell suppresses its walk.
        self.stream_network_raster = stream_network_raster

    @staticmethod
    def _step_downstream(row:int, col:int, flow_dir:float)->tuple[int, int]:
        """Return the (row, col) of the neighbour that ``flow_dir`` points to.

        The direction codes are powers of two, so the offset index is the
        position of the highest set bit. Integer arithmetic is used instead of
        ``log(code) / log(2)`` so the result never depends on float rounding.
        """
        idx = int(flow_dir).bit_length() - 1
        return row + IUH._D8_DY[idx], col + IUH._D8_DX[idx]

    @staticmethod
    def calculate_travel_time_average_standard_deviation(cell_area_km:float,
                                                         slope:float,
                                                         flow_acc:float,
                                                         manning:float,
                                                         radiusA:float,
                                                         radiusB:float,
                                                         maxV:float,
                                                         minV:float)->tuple[float, float]:
        """Compute the travel-time term and the variance term of a single cell.

        Args:
            cell_area_km: Area of one cell (computed by the caller as
                resolution_x * resolution_y / 1e6).
            slope: Cell slope; divided by 100 and floored at 0.001.
            flow_acc: Flow accumulation of the cell.
            manning: Manning coefficient; floored at 0.001.
            radiusA, radiusB: Coefficient and exponent of the power law used
                for the radius (from the subarea table in parameter.db3).
            maxV, minV: Upper and lower clamp for the velocity (from the
                discharge table in parameter.db3).

        Returns:
            ``(t, delta)`` where ``t = 1 / c`` and ``delta = 2 * d / c**3``,
            with ``c`` the celerity and ``d`` the coefficient defined below.
            Both are per unit length; the caller scales them by the cell size.
        """
        # Floor the slope so the square root and the division stay finite.
        slope_fraction = max(slope / 100.0, 0.001)

        # radius, p35 2.28
        radius = radiusA * (flow_acc * cell_area_km) ** radiusB

        # velocity from Manning's formula, clamped to [minV, maxV]
        # (the upper bound is applied first, then the lower bound)
        roughness = max(manning, 0.001)
        velocity = 1.0 / roughness * radius ** (2.0 / 3.0) * slope_fraction ** 0.5
        velocity = max(min(velocity, maxV), minV)

        # celerity, p34 2.26
        celerity = 5.0 / 3.0 * velocity

        # d, p35 2.27
        d_coeff = velocity * radius / 2.0 / slope_fraction

        return (1.0 / celerity, 2.0 * d_coeff / celerity ** 3.0)

    def generate_travel_time_average_standard_deviation(self, radiusA:float,radiusB:float,maxV:float,minV:float)->tuple[Raster,Raster]:
        """
        Replace C++ function IUHWaterShed::getMeanAndStandardDeviation

        The original function use flowpath as input which is generated by WhiteBox UI function: HDF5WritingUtil.createFlowPath

        The flow path is just the cell index of downstream cell based on flow direction. There is no need to generate that as we could
        use flow direction raster directly to find the downstream cell like other tools.

        Steps:
          1. Compute the per-cell (t, delta) pair for every DEM cell that is not NoData.
          2. Starting from every cell whose integer flow accumulation is <= 1,
             walk downstream and accumulate t and delta along the path, then
             walk the same path again to give every not-yet-finished cell on
             it its own accumulated totals.
          3. Scale the totals: t by resolution_x / 3600 and the square root of
             delta by sqrt(resolution_x) / 3600 (negative delta becomes 0).

        Returns:
            (t_s_raster, delta_s_raster): the scaled accumulated travel time
            and the scaled square root of the accumulated delta, per cell.
        """

        wbe = WbEnvironment()

        dem = self.dem_raster
        flow_dir_raster = self.flow_dir_raster
        flow_acc_raster = self.flow_acc_raster
        configs = dem.configs

        # NOTE(review): every work/output raster inherits the DEM's configs -
        # confirm these describe a floating-point raster, otherwise the values
        # stored below may be quantised.
        t_raster = wbe.new_raster(configs)
        delta_raster = wbe.new_raster(configs)

        # flag == 1 marks a cell whose accumulated totals are final.
        flag_s_raster = wbe.new_raster(configs)
        t_s_raster = wbe.new_raster(configs)
        delta_s_raster = wbe.new_raster(configs)

        rows = configs.rows
        cols = configs.columns
        noData = configs.nodata
        cell_area_km = configs.resolution_x * configs.resolution_y / 1000000.0

        # per-cell t and delta, plus initialization of the accumulators
        for row in range(rows):
            for col in range(cols):
                if dem[row, col] == noData:
                    continue

                t_cell, delta_cell = IUH.calculate_travel_time_average_standard_deviation(
                    cell_area_km,
                    self.slope_raster[row, col],
                    flow_acc_raster[row, col],
                    self.manning_raster[row, col],
                    radiusA,
                    radiusB,
                    maxV,
                    minV)
                t_raster[row, col] = t_cell
                delta_raster[row, col] = delta_cell

                flag_s_raster[row, col] = 0
                t_s_raster[row, col] = 0
                delta_s_raster[row, col] = 0

        for row in range(rows):
            for col in range(cols):
                # skip the no data
                if dem[row, col] == noData:
                    continue

                # only consider the most upstream ones
                if int(flow_acc_raster[row, col]) > 1:
                    continue

                # NOTE(review): this test reads the START cell, not the cell
                # the walk is currently on, so it is constant for a whole walk:
                # a start cell off the stream is followed until the flow
                # direction is <= 0. Confirm against the C++ original whether
                # the walk was meant to stop on reaching a stream cell.
                off_stream = self.stream_network_raster[row, col] < 1

                t_s_raster[row, col] = t_raster[row, col]
                delta_s_raster[row, col] = delta_raster[row, col]

                # First walk: accumulate the totals of the start cell.
                y, x = row, col
                flow_dir = flow_dir_raster[row, col]
                while flow_dir > 0 and off_stream:
                    y, x = IUH._step_downstream(y, x, flow_dir)

                    # Leaving the grid or entering a NoData cell ends the path;
                    # such a cell has no t/delta and must not be accumulated.
                    if not (0 <= y < rows and 0 <= x < cols) or dem[y, x] == noData:
                        break

                    if flag_s_raster[y, x] == 1:
                        # Downstream totals are already known: reuse them.
                        t_s_raster[row, col] += t_s_raster[y, x]
                        delta_s_raster[row, col] += delta_s_raster[y, x]
                        break

                    t_s_raster[row, col] += t_raster[y, x]
                    delta_s_raster[row, col] += delta_raster[y, x]
                    flow_dir = flow_dir_raster[y, x]

                flag_s_raster[row, col] = 1

                # Totals that remain once the current cell's own share is removed,
                # i.e. the totals of its downstream neighbour.
                t_down = t_s_raster[row, col] - t_raster[row, col]
                delta_down = delta_s_raster[row, col] - delta_raster[row, col]

                # Second walk: hand the remaining totals to every cell on the
                # path that has not been finished yet.
                y, x = row, col
                flow_dir = flow_dir_raster[row, col]
                while flow_dir > 0 and off_stream:
                    y, x = IUH._step_downstream(y, x, flow_dir)

                    if not (0 <= y < rows and 0 <= x < cols) or dem[y, x] == noData:
                        break

                    if flag_s_raster[y, x] == 1:
                        break

                    t_s_raster[y, x] = t_down
                    delta_s_raster[y, x] = delta_down
                    flag_s_raster[y, x] = 1

                    t_down = t_s_raster[y, x] - t_raster[y, x]
                    delta_down = delta_s_raster[y, x] - delta_raster[y, x]

                    flow_dir = flow_dir_raster[y, x]

        # s -> hour
        # Every step is scaled by resolution_x, whatever its direction.
        cont3 = configs.resolution_x / 3600.0
        cont4 = pow(float(configs.resolution_x), 0.5) / 3600.0

        for row in range(rows):
            for col in range(cols):
                if dem[row, col] == noData:
                    continue
                t_s_raster[row, col] = t_s_raster[row, col] * cont3

                delta_total = delta_s_raster[row, col]
                if delta_total < 0:
                    delta_s_raster[row, col] = 0
                else:
                    delta_s_raster[row, col] = pow(delta_total, 0.5) * cont4

        return (t_s_raster, delta_s_raster)