import math

import ee
import numpy as np

from . import tseb_utils
from . import utils

def tseb_invert(
        et_alexi,
        t_air,
        t_rad,
        t_air0,
        ea,
        u,
        p,
        z,
        rs_1,
        rs24,
        vza,
        aleafv,
        aleafn,
        aleafl,
        adeadv,
        adeadn,
        adeadl,
        albedo,
        ndvi,
        lai,
        clump,
        leaf_width,
        hc_min,
        hc_max,
        datetime,
        lon=None,
        lat=None,
        a_pt_in=1.32,
        stabil_iter=None,
        albedo_iter=10,
):
    """Priestley-Taylor TSEB inversion for air temperature.

    Calculates the Priestley-Taylor TSEB fluxes using a single observation of
    composite radiometric temperature and resistances in series.  Unlike the
    forward model, the total latent heat flux is fixed from the ALEXI ET
    estimate.  The total sensible heat flux is the energy-balance residual
    ``rn - le_tot - g``.  The air temperature is back-calculated at every
    stability iteration as ``t_ac - h * r_ah / (r_air * cp)``.

    Parameters
    ----------
    et_alexi : ee.Image
        ALEXI ET.  It is scaled to an instantaneous latent heat flux as
        ``(et_alexi * 2.45 / (rs24 * 0.0864 / 24)) * rs_1``.
    t_air : ee.Image
        Air temperature [K].  It is the first-iteration guess for the air,
        canopy and soil temperatures.  It is also used for the saturation
        vapour pressure slope, latent heat of vaporization and emissivity.
    t_rad : ee.Image
        Composite radiometric temperature [K].
    t_air0 : ee.Image
        Meteorological air temperature [K], used only for the vapour
        pressure deficit.
    ea : ee.Image
        Actual vapour pressure.  Presumably in kPa, since it is subtracted
        from a kPa saturation vapour pressure.
    u, p, z : ee.Image
        Wind speed, air pressure and elevation, passed to the resistance,
        psychrometric and air-density terms.
    rs_1, rs24 : ee.Image or number
        Instantaneous and daily solar radiation.  Both are broadcast onto the
        ``lai`` footprint.
    vza : ee.Image or number
        View zenith angle, used inside ``cos()`` for the view-dependent
        fractional cover.
    aleafv, aleafn, aleafl, adeadv, adeadn, adeadl, albedo : ee.Image or number
        Leaf/dead-leaf spectral parameters and surface albedo.  They are
        passed to ``tseb_utils.albedo_separation``.
    ndvi : ee.Image
        Used with ``albedo`` to flag water pixels (``ndvi < 0`` and
        ``albedo < 0.05``).
    lai, clump : ee.Image
        Leaf area index and clumping index.
    leaf_width, hc_min, hc_max : ee.Image or number
        Leaf width and the canopy height range.  Canopy height is
        interpolated between ``hc_min`` and ``hc_max`` by fractional cover.
    datetime : ee.Date
        Acquisition time, used for the solar zenith and soil heat flux terms.
    lon, lat : ee.Image, optional
        Longitude and latitude in degrees.  They default to
        ``ee.Image.pixelLonLat()``.
    a_pt_in : float
        Minimum Priestley-Taylor coefficient.
    stabil_iter : int, optional
        Number of stability iterations.  The loop is unrolled client-side, so
        only a Python ``int`` is honoured.  ``None`` or a server-side value
        (e.g. ``ee.Number``) uses the default of 6 iterations.
    albedo_iter : int
        Iteration count passed to ``tseb_utils.albedo_separation``.

    Returns
    -------
    ta : ee.Image
        Air temperature [K], masked outside 173.16-473.16 K, band ``'ta'``.

    """
    mask = lai.double().multiply(0).rename(['mask'])

    # ************************************************************************
    # Apply met bands directly to Landsat image
    # CGM - This can probably be removed if Rs is resampled/smoothed in disalexi.py
    rs_1 = mask.add(rs_1).rename(['rs'])
    rs24 = mask.add(rs24).rename(['rs'])

    # Estimate total daily latent heat (not in tseb_pt())
    le_tot = et_alexi.expression(
        '(et_alexi * 2.45 / (Rsd * 0.0864 / 24)) * Rs', {'et_alexi': et_alexi, 'Rsd': rs24, 'Rs': rs_1}
    )

    # ************************************************************************
    if lat is None:
        lat = ee.Image.pixelLonLat().select(['latitude'])
    if lon is None:
        lon = ee.Image.pixelLonLat().select(['longitude'])

    zs = tseb_utils.solar_zenith(
        datetime=datetime,
        lon=lon.multiply(math.pi / 180),
        lat=lat.multiply(math.pi / 180),
    )

    # ************************************************************************
    # Correct Clumping Factor
    f_green = 1.0

    # LAI for leaf spherical distribution
    f = lai.multiply(clump)

    # Fraction cover at nadir (view=0)
    fc = f.multiply(-0.5).exp().multiply(-1).add(1.0).clamp(0.01, 0.9)

    # Compute canopy height and roughness parameters
    hc = hc_max.subtract(hc_min).multiply(fc).add(hc_min)

    # LAI relative to canopy projection only
    lai_c = lai.divide(fc)

    # Houborg modification (according to Anderson et al. 2005)
    fc_q = et_alexi.expression('1 - (exp(-0.5 * f / cos(vza)))', {'f': f, 'vza': vza}).clamp(0.05, 0.90)

    # Brutsaert (1982)
    z0m = hc.multiply(0.123)
    d0 = hc.multiply(2.0 / 3.0)

    # Correction of roughness parameters for water bodies
    # (NDVI < 0 and albedo < 0.05)
    water_mask = ndvi.lt(0).And(albedo.lt(0.05))
    d0 = d0.where(water_mask, 0.00001)
    z0m = z0m.where(water_mask, 0.00035)

    # Check to avoid division by 0 in the next computations
    z0m = z0m.where(z0m.eq(0), 0.01)

    z_u = 30.0
    z_t = 30.0

    # Modify z0m when using it at alexi scale (not in tseb_pt())
    z0m = et_alexi.expression(
        '1.0 / ((log((z_U - D0) / z0m)) * (log((z_U - D0) / z0m)))',
        {'z_U': z_u, 'D0': d0, 'z0m': z0m}
    )
    # Further for z0m
    z0m = et_alexi.expression('(z_U - D0) / (exp(1.0 / sqrt(z0m)))', {'z_U': z_u, 'D0': d0, 'z0m': z0m})

    # z0h is defined as a copy of the rescaled z0m.  Any earlier z0h
    # adjustments would be overwritten here, so they are not computed.
    z0h = z0m.add(0)

    # Parameters for In-Canopy Wind Speed Extinction
    leaf = et_alexi.expression(
        '(0.28 * (f ** (0.66667)) * (hc ** (0.33333)) * (leaf_width ** (-0.33333)))',
        {'f': f, 'hc': hc, 'leaf_width': leaf_width}
    )
    leaf_c = et_alexi.expression(
        '(0.28 * (lai_c ** (0.66667)) * (hc ** (0.33333)) * (leaf_width ** (-0.33333)))',
        {'lai_c': lai_c, 'hc': hc, 'leaf_width': leaf_width}
    )
    leaf_s = et_alexi.expression(
        '(0.28 * (0.1 ** (0.66667)) * (hc ** (0.33333)) * (leaf_width ** (-0.33333)))',
        {'hc': hc, 'leaf_width': leaf_width}
    )

    # ************************************************************************
    # Atmospheric Parameters
    # Saturation vapour pressure [kPa] (FAO56 3-8)
    # Yun modified to use METEO air temperature
    e_s0 = t_air0.expression(
        '0.6108 * exp((17.27 * (t_air - 273.16)) / ((t_air - 273.16) + 237.3))', {'t_air': t_air0}
    )
    vpd = e_s0.subtract(ea)

    # Saturation vapor pressure [kpa] using iterated air temperature
    e_s = t_air.expression(
        '0.6108 * exp((17.27 * (t_air - 273.16)) / ((t_air - 273.16) + 237.3))', {'t_air': t_air}
    )

    # Slope of the saturation vapor pressure [kPa] (FAO56 3-9)
    Ss = t_air.subtract(273.16).add(237.3).pow(-2).multiply(e_s).multiply(4098)

    # Latent heat of vaporization (~2.45 at 20 C) [MJ kg-1] (FAO56 3-1)
    # lambda1 = (2.501 - (2.361e-3 * (t_air - 273.16)))
    lambda1 = t_air.subtract(273.16).multiply(2.361e-3).multiply(-1).add(2.501)

    # Psychrometric constant [kPa C-1] (FAO56 3-10)
    gamma = p.multiply(1.615E-3).divide(lambda1)

    # Priestley-Taylor canopy latent heat factor: Ss / (Ss + gamma).
    # It does not change between iterations, so compute it once.
    pt_factor = Ss.add(gamma).pow(-1).multiply(Ss)

    # ************************************************************************
    # Priestley-Taylor coefficient increases with VPD, bounded to [a_pt_in, 2.5]
    a_pt = mask.add(a_pt_in)
    a_pt = (
        a_pt.add(vpd.subtract(2.0).multiply(0.4))
        .max(a_pt_in).min(2.5)
        .rename('a_pt')
    )

    # The stability loop below is unrolled client-side, so the iteration count
    # must be a Python int.  Server-side values (e.g. an ee.Number) cannot drive
    # a Python loop and fall back to the default of 6 iterations.
    if isinstance(stabil_iter, int) and not isinstance(stabil_iter, bool):
        n_iter = stabil_iter
    else:
        n_iter = 6

    # ************************************************************************
    rs_c, rs_s, albedo_c, albedo_s, taudl, tausolar = tseb_utils.albedo_separation(
        albedo, rs_1, f, fc, aleafv, aleafn, aleafl, adeadv, adeadn, adeadl, zs, albedo_iter
    )

    e_atm = tseb_utils.emissivity(t_air)

    # Density of air? (kg m-3)
    # CGM - Intentionally making this a function of t_air instead of z
    r_air = (
        t_air.subtract(z.multiply(0.0065)).divide(t_air).pow(5.26).divide(t_air)
        .multiply(101.3 / 1.01 / 0.287)
    )
    cp = 1004.16

    # Assume neutral conditions on first iteration (use t_air for Ts and Tc)
    u_attr = tseb_utils.compute_u_attr(u=u, d0=d0, z0m=z0m, z_u=z_u, fm=0)

    t_c = t_air.multiply(1)
    # Modified from tseb_pt() approach
    t_s = et_alexi.expression(
        '(((t_rad - 273.16) - (fc_q * (t_c - 273.16))) / (1 - fc_q)) + 273.16',
        {'t_rad': t_rad, 't_c': t_c, 'fc_q': fc_q}
    )

    # CGM - Initialize to match t_air shape
    ef_s = t_air.multiply(0)

    def _iterate(a_pt_i, ef_s_i, t_c_i, t_s_i, u_attr_i, ta_i):
        """Run one stability / water-stress iteration and return the new state."""
        # Net radiation partitioned between canopy and soil
        rn_c = tseb_utils.compute_Rn_c(albedo_c, ta_i, t_c_i, t_s_i, e_atm, rs_c, f)
        rn_s = tseb_utils.compute_Rn_s(albedo_s, ta_i, t_c_i, t_s_i, e_atm, rs_s, f)
        rn = rn_c.add(rn_s)
        g = tseb_utils.compute_G0(rn, rn_s, ef_s_i, water_mask, lon, datetime)

        # Total sensible heat is the residual given the fixed ALEXI latent heat
        h = rn.subtract(le_tot).subtract(g)
        le_c = pt_factor.multiply(a_pt_i).multiply(f_green).multiply(rn_c).max(0)
        h_c = rn_c.subtract(le_c)
        h_s = h.subtract(h_c)
        le_s = rn_s.subtract(g).subtract(h_s)

        # Stability corrections, then resistances from the UPDATED friction velocity
        fh = tseb_utils.compute_stability_fh(h, t_rad, u_attr_i, r_air, z_t, d0, cp)
        fm = tseb_utils.compute_stability_fm(h, t_rad, u_attr_i, r_air, z_u, d0, z0m, cp)
        fm_h = tseb_utils.compute_stability_fm_h(h, t_rad, u_attr_i, r_air, hc, d0, z0m, cp)
        u_attr_new = tseb_utils.compute_u_attr(u=u, d0=d0, z0m=z0m, z_u=z_u, fm=fm)
        r_ah = tseb_utils.compute_r_ah(u_attr=u_attr_new, d0=d0, z0h=z0h, z_t=z_t, fh=fh)
        r_s = tseb_utils.compute_r_s(
            u_attr=u_attr_new, t_s=t_s_i, t_c=t_c_i, hc=hc, f=lai, d0=d0, z0m=z0m,
            leaf=leaf, leaf_s=leaf_s, fm_h=fm_h)
        r_x = tseb_utils.compute_r_x(
            u_attr=u_attr_new, hc=hc, f=lai, d0=d0, z0m=z0m, xl=leaf_width, leaf_c=leaf_c, fm_h=fm_h)

        # Temperature separation and back-calculated air temperature
        t_c_new = tseb_utils.temp_separation_tc(h_c, fc_q, ta_i, t_rad, r_ah, r_s, r_x, r_air, cp)
        t_s_new = tseb_utils.temp_separation_ts(t_c_new, fc_q, ta_i, t_rad)
        t_ac = tseb_utils.temp_separation_tac(t_c_new, t_s_new, fc_q, ta_i, r_ah, r_s, r_x)
        ta_new = t_ac.subtract(h.multiply(r_ah).divide(r_air.multiply(cp)))

        # Reduce a_pt where soil latent heat is non-positive.  The floor is
        # applied to the reduced value so a_pt never drops to or below zero.
        a_pt_new = a_pt_i.where(le_s.lte(0), a_pt_i.subtract(0.05))
        a_pt_new = a_pt_new.where(a_pt_new.lte(0), 0.01)

        # Soil evaporative fraction; mask pixels where the denominator is 0
        den_s = rn_s.subtract(g)
        den_s = den_s.updateMask(den_s.neq(0))
        ef_s_new = le_s.divide(den_s)

        return a_pt_new, ef_s_new, t_c_new, t_s_new, u_attr_new, ta_new

    # Start Loop for Stability Correction and Water Stress
    ta = t_air
    for _ in range(n_iter):
        a_pt, ef_s, t_c, t_s, u_attr, ta = _iterate(a_pt, ef_s, t_c, t_s, u_attr, ta)

    # ************************************************************************
    # Mask physically implausible air temperatures
    ta = ta.updateMask(ta.gte(173.16))
    ta = ta.updateMask(ta.lte(473.16))

    return ta.rename(['ta'])


def point_image_value(image, xy, scale=None, transform=None, crs=None):
    """Extract the value of the first band of an image at a single point.

    Parameters
    ----------
    image : ee.Image or anything accepted by ``ee.Image()``
        Only the first band is sampled.
    xy : list of float
        ``[x, y]`` coordinates passed to ``ee.Geometry.Point``.
    scale : float, optional
        Sampling scale.  If both *scale* and *transform* are given, *scale*
        wins and the transform is dropped.  Defaults to 1 when neither is
        given.
    transform : list of float, optional
        Affine ``crsTransform`` used instead of the default scale.
    crs : str, optional
        Projection of the reduction.  Defaults to ``'EPSG:4326'``.

    Returns
    -------
    The sampled value, fetched client-side with ``getInfo()``.  This makes a
    blocking server round trip on every call.
    """
    rr_args = {
        'reducer': ee.Reducer.first(),
        'geometry': ee.Geometry.Point(xy),
        'crs': 'EPSG:4326',
        'scale': 1,
    }
    if transform:
        rr_args['crsTransform'] = transform
        del rr_args['scale']
    if scale:
        # scale and crsTransform are mutually exclusive; an explicit scale wins.
        # pop() rather than del: crsTransform is only present if a transform
        # was given, and del raised KeyError for scale-only calls.
        rr_args['scale'] = scale
        rr_args.pop('crsTransform', None)
    if crs:
        rr_args['crs'] = crs

    return ee.Image(image).select([0], ['value']).reduceRegion(**rr_args).get('value').getInfo()


def point_image_values(image, band_name):
    """Print the value of *image* at each debug sample point.

    Each line is printed as ``'<band_name>: <value>'``.  The values are
    sampled on a fixed grid transform with ``point_image_value``.
    """
    # Fixed sampling grid transform used for the debug comparisons
    grid_transform = [0.04, 0, -125.02, 0, -0.04, 49.78]
    sample_points = [
        [-106.60, 39.72],  # "bad" pixel under investigation
        # Neighbouring "good" pixels, kept for comparison:
        # [-106.56, 39.72],
        # [-106.60, 39.76],
        # [-106.64, 39.72],
        # [-106.60, 39.68],
    ]
    for point in sample_points:
        value = point_image_value(image, point, transform=grid_transform)
        print(f'{band_name}: {value}')
