import numpy as np
import pandas as pd
# from PetThermoTools.Barom import *
# from PetThermoTools.Liq import *
# from PetThermoTools.Crystallise import *
from PetThermoTools.MELTS import *
from PetThermoTools.Compositions import *
# try:
#     from PetThermoTools.Holland import *
# except:
#     pass

# Column suffix attached to each MELTS phase name when stich_work builds the
# combined output table. Entry order is kept as-is because stich_work iterates
# over this mapping.
Names = {
    # first instance of each phase
    "liquid1": "_Liq",
    "olivine1": "_Ol",
    "orthopyroxene1": "_Opx",
    "clinopyroxene1": "_Cpx",
    "garnet1": "_Grt",
    "spinel1": "_Sp",
    "k-feldspar1": "_Kspar",
    "quartz1": "_Qtz",
    "rhm-oxide1": "_Rhm",
    "apatite1": "_Apa",
    # additional coexisting instances
    "olivine2": "_Ol2",
    "plagioclase1": "_Plag",
    "clinopyroxene2": "_Cpx2",
    "plagioclase2": "_Plag2",
    "spinel2": "_Sp2",
    "k-feldspar2": "_Kspar2",
    "garnet2": "_Grt2",
    "rhm-oxide2": "_Rhm2",
    "quartz2": "_Qtz2",
    "orthopyroxene2": "_Opx2",
    "apatite2": "_Apa2",
    # additional liquids
    "liquid2": "_Liq2",
    "liquid3": "_Liq3",
    "liquid4": "_Liq4",
}

# Column suffix attached to each phase name when stich_work handles results
# from a non-MELTS model. Entry order is kept as-is because stich_work
# iterates over this mapping.
Names_MM = {
    # first instance of each phase
    "liq1": "_Liq",
    "ol1": "_Ol",
    "opx1": "_Opx",
    "cpx1": "_Cpx",
    "g1": "_Grt",
    "spl1": "_Sp",
    "fsp1": "_Plag",
    # additional coexisting instances
    "ol2": "_Ol2",
    "cpx2": "_Cpx2",
    "opx2": "_Opx2",
    "g2": "_Grt2",
    "fsp2": "_Plag2",
    "spl2": "_Sp2",
    # additional liquids
    "liq2": "_Liq2",
    "liq3": "_Liq3",
    "liq4": "_Liq4",
}

def to_float(x):
    '''
    Convert numeric input to floating point where the type is recognised.

    Parameters:
    ----------
    x: None, int, float, NumPy scalar, list, tuple or np.ndarray
        Value(s) to convert.

    Returns:
    ----------
    None if x is None; a Python float for Python or NumPy real scalars; a list
    of floats for a list or tuple; a float array for an np.ndarray. Any other
    type is returned unchanged.
    '''
    if x is None:
        return None
    # np.integer / np.floating cover NumPy scalars such as np.int64 and
    # np.float32, which are not subclasses of the built-in int / float.
    if isinstance(x, (int, float, np.integer, np.floating)):
        return float(x)
    if isinstance(x, (list, tuple)):
        return [float(v) for v in x]
    if isinstance(x, np.ndarray):
        return x.astype(float)
    return x  # leave unchanged if unexpected type

def label_results(Result,label):
    '''
    Re-key a dictionary of results by one of the input variables and sort it.

    Parameters:
    ----------
    Result: dict
        Results keyed by run; each entry must hold an 'Input' dictionary
        containing the variable selected by label.

    label: string
        "CO2", "H2O", "fO2", or one of "pressure" / "P" / "P_bar".

    Returns:
    ----------
    new_out: dict
        Shallow copies of the entries keyed by a descriptive string
        (e.g. 'P = 1000 bars') and sorted by the numeric value in that string.
        If label is not recognised, or Result is empty, a shallow copy of
        Result is returned.
    '''
    Results = Result.copy()

    # For each label: how to build the new key, and which space-separated token
    # after the '=' holds the number used for sorting.
    if label == "CO2":
        make_key = lambda run: 'CO2 = ' + str(run['Input']['comp']['CO2_Liq']) + ' wt%'
        token = 1
    elif label == "pressure" or label == "P" or label == "P_bar":
        make_key = lambda run: 'P = ' + str(run['Input']['P_bar']) + ' bars'
        token = 1
    elif label == "fO2":
        make_key = lambda run: 'fO2 = ' + run['Input']['fO2_buffer'] + ' ' + str(round(run['Input']['fO2_offset'],2))
        token = 2
    elif label == 'H2O':
        make_key = lambda run: 'H2O = ' + str(run['Input']['comp']['H2O_Liq']) + ' wt%'
        token = 1
    else:
        make_key = None
        token = None

    relabelled = {}
    if make_key is not None:
        for run_name in Results:
            relabelled[make_key(Results[run_name])] = Results[run_name].copy()

        def numeric_part(item):
            return float(item[0].split('=')[1].split(' ')[token])

        relabelled = dict(sorted(relabelled.items(), key=numeric_part))

    if not relabelled:
        relabelled = Results.copy()

    return relabelled

def supCalc(Model = "MELTSv1.0.2", bulk = None, phase = None, T_C = None, P_bar = None,
             Fe3Fet_Liq = None, H2O_Liq = None, CO2_Liq = None, fO2_buffer = None, fO2_offset = None, 
             melts = None):
    '''
    Normalise a bulk composition with comp_fix and pass it to supCalc_MELTS.

    Parameters:
    ----------
    Model: string
        Model name. Used for the composition check and forwarded to
        supCalc_MELTS.

    bulk: dict or pd.Series
        Bulk composition. It is copied, so the caller's object is not modified
        here. A Series is converted to a dict.

    phase, T_C, P_bar, fO2_buffer, fO2_offset, melts:
        Forwarded unchanged to supCalc_MELTS.

    Fe3Fet_Liq, H2O_Liq, CO2_Liq:
        Optional overrides forwarded to comp_fix.

    Returns:
    ----------
    Results:
        Whatever supCalc_MELTS returns.
    '''
    comp = bulk.copy()

    if isinstance(comp, pd.Series):
        comp = comp.to_dict()

    comp = comp_fix(Model = Model, comp = comp, H2O_Liq = H2O_Liq, CO2_Liq = CO2_Liq, Fe3Fet_Liq = Fe3Fet_Liq)

    # Forward the requested Model instead of a hard-coded "MELTSv1.0.2", so the
    # calculation uses the same model the composition was prepared for.
    Results = supCalc_MELTS(Model = Model, comp = comp, phase = phase, T_C = T_C, P_bar = P_bar,
                            fO2_buffer = fO2_buffer, fO2_offset = fO2_offset,
                            melts = melts)

    return Results

def comp_check(comp_lith, Model, MELTS_filter, Fe3Fet):
    '''
    Resolve and normalise a composition before a calculation.

    Parameters:
    ----------
    comp_lith: string, dict, pd.Series or pd.DataFrame
        Either a composition, or the name of one. For any Model other than
        "pyMelt" a name is looked up in Compositions; for "pyMelt" the name is
        passed through unchanged.

    Model: string
        Model name. For anything other than "pyMelt" the composition is passed
        through comp_fix.

    MELTS_filter: True/False
        If True and "MELTS" is in Model, K2O_Liq, P2O5_Liq, H2O_Liq and CO2_Liq
        are set to zero.

    Fe3Fet: float or np.ndarray
        Forwarded to comp_fix as Fe3Fet_Liq.

    Returns:
    ----------
    comp: string, dict or DataFrame
        The composition to use. The input, and the entry stored in
        Compositions, are left unmodified.
    '''
    if isinstance(comp_lith, str):
        if Model != "pyMelt":
            # Copy the stored entry: the steps below rename and overwrite keys,
            # which would otherwise permanently alter the shared Compositions
            # table for every later call.
            comp = Compositions[comp_lith].copy()
        else:
            comp = comp_lith
    else:
        comp = comp_lith.copy()

    # if comp is entered as a pandas series, it must first be converted to a dict
    if Model != "pyMelt":
        if isinstance(comp, pd.Series):
            comp = comp.to_dict()

        comp = comp_fix(Model = Model, comp = comp, Fe3Fet_Liq = Fe3Fet)

    if "MELTS" in Model and MELTS_filter == True:
        filtered = ('K2O_Liq', 'P2O5_Liq', 'H2O_Liq', 'CO2_Liq')
        if isinstance(comp, pd.DataFrame):
            n_rows = len(comp['SiO2_Liq'])
            for column in filtered:
                comp[column] = np.zeros(n_rows)
        else:
            for column in filtered:
                comp[column] = 0

    return comp


def _set_liquid_column(comp, column, value):
    '''
    Write an override value into one column of a composition (helper for comp_fix).

    If comp is a dict and value is an np.ndarray, the dict is expanded into a
    DataFrame with one row per entry of value. Otherwise the value is written
    into the existing dict or DataFrame. A value of None leaves comp untouched.
    '''
    if value is None:
        return comp

    if isinstance(comp, dict):
        if not isinstance(value, np.ndarray):
            comp[column] = value
            return comp
        # One identical row per requested value; only `column` varies.
        frame = pd.DataFrame([comp]*len(value))
        frame[column] = value
        return frame

    # DataFrame: a scalar is broadcast to every row, an array is used row by row.
    comp[column] = np.zeros(len(comp.index)) + value
    return comp


def comp_fix(Model = None, comp = None, Fe3Fet_Liq = None, H2O_Liq = None, CO2_Liq = None):
    '''
    Ensure that the input composition contains the correct column headers.

    Headers without a suffix (e.g. SiO2) are renamed with the '_Liq' suffix,
    required headers that are missing are added with a value of zero, and, when
    both FeO and Fe2O3 are given, FeOt and Fe3Fet are calculated if absent.
    The input is copied first, so the caller's object is never modified.

    Parameters:
    ----------
    Model: string
        "MELTSvx.x.x" or "Holland" determines which list of headers is used.
        If None, "MELTSv1.0.2" is used.

    comp: dict, pd.Series or pd.DataFrame
        Input composition for calculations. A Series is converted to a dict.

    Fe3Fet_Liq: float or np.ndarray
        Fe 3+/total ratio. If comp is a dict and Fe3Fet_Liq is an np.ndarray, a new DataFrame is constructed with bulk compositions varying only in their Fe3Fet_Liq value. If comp is a pd.DataFrame, a single value (float) is used for all starting compositions, or an array with one value per composition sets a different Fe redox state for each sample. If None, the value in comp (or zero if absent) is kept.

    H2O_Liq: float or np.ndarray
        H2O content of the initial melt phase. Handled in the same way as Fe3Fet_Liq. If None, the value in comp (or zero if absent) is kept.

    CO2_Liq: float or np.ndarray
        CO2 content of the initial melt phase. Handled in the same way as Fe3Fet_Liq. If None, the value in comp is kept.

    Returns:
    ---------
    comp: dict or DataFrame
        New composition with the correct headers.
    '''
    if Model is None:
        Model = "MELTSv1.0.2"

    # Work on a copy so the caller's dict / DataFrame is never changed in place.
    if isinstance(comp, pd.Series):
        comp = comp.to_dict()
    else:
        comp = comp.copy()

    # Factor used throughout this function to express Fe2O3 as FeO.
    fe2o3_as_feo = 71.844/(159.69/2)
    for suffix in ('_Liq', ''):
        feo, fe2o3 = 'FeO' + suffix, 'Fe2O3' + suffix
        feot, fe3fet = 'FeOt' + suffix, 'Fe3Fet' + suffix
        present = list(comp.keys())
        if feo in present and fe2o3 in present:
            if feot not in present:
                comp[feot] = comp[feo] + fe2o3_as_feo*comp[fe2o3]
            if fe3fet not in present:
                comp[fe3fet] = 1 - comp[feo]/(comp[feo] + fe2o3_as_feo*comp[fe2o3])

    # Required headers; the order sets the order in which missing ones are added.
    if "MELTS" in Model:
        required = ['SiO2_Liq', 'TiO2_Liq', 'Al2O3_Liq', 'Cr2O3_Liq', 'FeOt_Liq', 'MnO_Liq', 'MgO_Liq', 'CaO_Liq', 'Na2O_Liq', 'K2O_Liq', 'P2O5_Liq', 'H2O_Liq', 'CO2_Liq', 'Fe3Fet_Liq']
    else:
        required = ['SiO2_Liq', 'TiO2_Liq', 'Al2O3_Liq', 'FeOt_Liq', 'MgO_Liq', 'CaO_Liq', 'Na2O_Liq', 'K2O_Liq', 'Cr2O3_Liq', 'H2O_Liq', 'Fe3Fet_Liq']
    # The same headers without the suffix are the ones that need renaming.
    unsuffixed = {name[:-len('_Liq')] for name in required}

    if isinstance(comp, pd.DataFrame):
        comp = comp.rename(columns = {el: el + '_Liq' for el in comp.columns if el in unsuffixed})
        for el in required:
            if el not in list(comp.keys()):
                comp[el] = np.zeros(len(comp.index))

    elif isinstance(comp, dict):
        for el in list(comp):
            if el in unsuffixed:
                comp[el + '_Liq'] = comp.pop(el)
        for el in required:
            if el not in comp:
                comp[el] = 0.0

    # set the liquid Fe redox state and volatile contents if specified separately
    # from the bulk composition
    comp = _set_liquid_column(comp, 'Fe3Fet_Liq', Fe3Fet_Liq)
    comp = _set_liquid_column(comp, 'H2O_Liq', H2O_Liq)
    comp = _set_liquid_column(comp, 'CO2_Liq', CO2_Liq)

    return comp

def stich(Res, multi = None, Model = None, Frac_fluid = None, Frac_solid = None):
    '''
    Takes the outputs from one or more crystallisation/decompression calculations and stitches each into a single dataframe via stich_work. Additionally, the relevant suffix is added to the composition and properties of each phase (e.g., SiO2 -> SiO2_Liq for the liquid phase).

    Parameters:
    ----------
    Res: dict
        Final results from the crystallisation/decompression calculation(s).

    multi: True/False
        If True, Res is composed of multiple dictionaries, each representing a single crystallisation/decompression calculation, and each one is processed separately. If False or None (default), Res is treated as a single calculation.

    Model: string
        "MELTSvx.x.x" or "Holland" determines which function list is followed. If None, "MELTSv1.0.2" is used.

    Frac_fluid, Frac_solid: True/None
        Forwarded unchanged to stich_work.

    Returns:
    ----------
    Results: dict
        A copy of the input dict with the entries added by stich_work (e.g. the DataFrame titled 'All').
    '''
    if Model is None:
        # same default as comp_fix; avoids a TypeError on the membership test below
        Model = "MELTSv1.0.2"

    if "MELTS" in Model:
        Order = ['SiO2', 'TiO2', 'Al2O3', 'Cr2O3', 'Fe2O3', 'FeO', 'FeOt', 'MnO', 'MgO', 'CaO', 'Na2O', 'K2O', 'P2O5', 'H2O', 'CO2', 'Fe3Fet']
        family = "MELTS"
    else:
        if Model == "Weller2024":
            Order = ['SiO2', 'TiO2', 'Al2O3', 'Cr2O3', 'FeOt', 'MgO', 'CaO', 'Na2O', 'K2O', 'Fe3Fet']
        else:
            Order = ['SiO2', 'TiO2', 'Al2O3', 'Cr2O3', 'FeOt', 'MgO', 'CaO', 'Na2O', 'K2O', 'H2O', 'Fe3Fet']
        family = "Holland"

    Results = Res.copy()

    # `not multi` (rather than `multi is None`) so that an explicit multi=False
    # is handled as a single calculation, as documented.
    if not multi:
        return stich_work(Results = Results, Order = Order, Model = family, Frac_fluid = Frac_fluid, Frac_solid = Frac_solid)

    for Ind in Res:
        Results[Ind] = stich_work(Results = Res[Ind].copy(), Order = Order, Model = family, Frac_fluid = Frac_fluid, Frac_solid = Frac_solid)

    return Results

def stich_work(Results = None, Order = None, Model = "MELTS", Frac_fluid = None, Frac_solid = None):
    '''
    Does the work required by stich for a single calculation.

    Steps, in order:
    1. Rename the 'Conditions' columns to T_C / P_bar (MELTS), or rename P_kbar to P_bar and multiply by 1000 (other models).
    2. For every phase table (keys without '_prop' that are not 'Conditions' or 'sys'): calculate FeOt and Fe3Fet, blank rows where the phase mass is zero, and keep only the columns listed in Order.
    3. Drop rows flagged as empty in 'Conditions' from every entry.
    4. Build per-phase mass tables (and volume / density for MELTS), plus cumulative masses when Frac_solid or Frac_fluid is True.
    5. Add a phase suffix to every column and concatenate everything into one table.

    Parameters:
    ----------
    Results: dict
        Output of a single calculation. As used below it must contain a 'Conditions' DataFrame and, for each phase, a composition DataFrame plus a '<phase>_prop' DataFrame with a 'mass' or 'Mass' column. This dict is modified in place.

    Order: list
        Column names (in order) to keep for each phase composition.

    Model: string
        Any string containing "MELTS" selects the MELTS branches; anything else selects the alternative branches.

    Frac_fluid, Frac_solid: True/None
        If either is True, cumulative mass columns are added to the 'Mass' table.

    Returns:
    ----------
    Results: dict
        The same dict, with 'All' and 'Mass' added (and 'Volume' and 'rho' for MELTS).
    '''
    # NOTE(review): Res is not used again in this function.
    Res = Results.copy()
    if "MELTS" in Model:    
        Results['Conditions'] = Results['Conditions'].rename(columns = {'temperature':'T_C'})
        Results['Conditions'] = Results['Conditions'].rename(columns = {'pressure':'P_bar'})
    else:
        Results['Conditions'] = Results['Conditions'].rename(columns = {'P_kbar': 'P_bar'})
        Results['Conditions']['P_bar'] = Results['Conditions']['P_bar']*1000

    # SN: names of the phase composition tables
    SN = []
    for R in Results:
        if '_prop' not in R and R != 'Conditions' and R != "sys":
            SN += [R]

    for R in SN:
        # if "_prop" not in R and R != "Conditions" and R!= "sys":
        if "MELTS" in Model:
            # total Fe as FeO, and the fraction of it that comes from Fe2O3
            Results[R].loc[:,'FeOt'] = Results[R].loc[:,'FeO'] + 71.844/(159.69/2)*Results[R].loc[:,'Fe2O3']
            Results[R].loc[:,'Fe3Fet'] = (71.844/(159.69/2)*Results[R].loc[:,'Fe2O3'])/Results[R].loc[:,'FeOt']
            # blank out rows where the phase has zero mass; the property table
            # may name the column 'mass' or 'Mass'
            try:
                Results[R][Results[R + '_prop']['mass'] == 0.0] = np.nan
            except:
                Results[R][Results[R + '_prop']['Mass'] == 0.0] = np.nan
            Results[R] = Results[R][Order]
            if R == "fluid1":
                # keep only H2O and CO2 for the fluid
                El = ['SiO2', 'TiO2', 'Al2O3', 'Cr2O3', 'FeO', 'Fe2O3', 'FeOt', 'Fe3Fet',
                      'MnO','MgO', 'CaO', 'Na2O', 'K2O', 'P2O5']
                for e in El:
                    Results[R] = Results[R].drop(columns = e)
                # presumably 18 and 44 are the molar masses of H2O and CO2, making these mole fractions - confirm
                Results[R].loc[:,'X_H2O_mol'] = (Results[R].loc[:, 'H2O']/18)/(Results[R].loc[:, 'H2O']/18 + Results[R].loc[:, 'CO2']/44)
                Results[R].loc[:,'X_CO2_mol'] = 1 - Results[R].loc[:, 'X_H2O_mol']
        else:
            Results[R] = Results[R].rename(columns = {'FeO': 'FeOt'})
            # rescale every column so that each row sums to 100
            Tot = Results[R].sum(axis = 1)
            for el in Results[R]:
                Results[R][el] = 100*Results[R][el]/Tot
            # NOTE(review): 159.9 is used here but 159.69 in the MELTS branch above - confirm which is intended
            Results[R]['Fe3Fet'] = Results[R]['O']/(((159.9/2)/71.844)*Results[R]['FeOt'] - Results[R]['FeOt'])
            try:
                Results[R][Results[R + '_prop']['mass'] == 0.0] = np.nan
            except:
                Results[R][Results[R + '_prop']['Mass'] == 0.0] = np.nan
                
            Results[R] = Results[R][Order]

    # positions of rows with no result: 'h' == 0 for MELTS, 'T_C' == 0 otherwise
    if "MELTS" in Model:
        Remove = np.where(Results['Conditions']['h'] == 0.0)[0]
    else:
        Remove = np.where(Results['Conditions']['T_C'] == 0.0)[0]

    # the positions are passed to drop() as index labels for every entry
    for R in Results:
        if len(Remove) > 0:
            Results[R] = Results[R].drop(labels = Remove)

    # if "MELTS" in Model:
    #     Remove_2 = np.where(Results['Conditions']['h'] == np.nan)[0]
    #     for R in Results:
    #         if len(Remove_2) > 0:
    #             Results[R] = Results[R].drop(labels = Remove_2)

    # one column per phase; Volume and rho are only filled in for MELTS
    Results_Mass = pd.DataFrame(data = np.zeros((len(Results['Conditions']['T_C']), len(SN))), columns = SN)
    Results_Volume = Results_Mass.copy()
    Results_rho = Results_Mass.copy()
    for n in SN:
        try:
            Results_Mass[n] = Results[n + '_prop']['mass']
        except:
            Results_Mass[n] = Results[n + '_prop']['Mass']

        if "MELTS" in Model:
            Results_Volume[n] = Results[n + '_prop']['v']
            Results_rho[n] = Results[n + '_prop']['rho']

    # cumulative masses, added only when a fractionation flag is True
    if Frac_solid is True or Frac_fluid is True:
        if Frac_solid is None:
            # only the fluid is accumulated
            Results_Mass['fluid1_cumsum'] = Results_Mass['fluid1'].cumsum()
        elif Frac_fluid is None:
            # every phase except liquid and fluid is accumulated
            for n in SN:
                if n != 'liquid1' and n!= 'fluid1' and n != 'liq1' and n != 'fl1':
                    Results_Mass[n + '_cumsum'] = Results_Mass[n].cumsum()
            # NOTE(review): n still holds the last entry of SN from the loop above, so the
            # column written below is '<last phase>_cumsum' - confirm this is the intended name
            if 'liq1' in SN:
                Results_Mass[n + '_cumsum'] = Results_Mass.loc[0, 'liq1'] - Results_Mass.loc[:,Results_Mass.columns.str.contains('_cumsum')].sum(axis = 1)
            elif 'liquid1' in SN:
                Results_Mass[n + '_cumsum'] = Results_Mass.loc[0, 'liquid1'] - Results_Mass.loc[:,Results_Mass.columns.str.contains('_cumsum')].sum(axis = 1)
        else:
            # every phase except the liquid is accumulated
            for n in SN:
                if n != 'liquid1' and n != 'liq1':
                    Results_Mass[n + '_cumsum'] = Results_Mass[n].cumsum()
            # NOTE(review): same leftover loop variable n as in the branch above - confirm
            if 'liq1' in SN:
                Results_Mass[n + '_cumsum'] = Results_Mass.loc[0, 'liq1'] - Results_Mass.loc[:,Results_Mass.columns.str.contains('_cumsum')].sum(axis = 1)
            elif 'liquid1' in SN:
                Results_Mass[n + '_cumsum'] = Results_Mass.loc[0, 'liquid1'] - Results_Mass.loc[:,Results_Mass.columns.str.contains('_cumsum')].sum(axis = 1)

    # suffix every phase / property column and append it to the combined table
    Results_All = Results['Conditions'].copy()
    for R in Results:
        if R != "Conditions" and R != "sys":
            if "MELTS" in Model:
                # substring match, so 'liquid1_prop' gets the same suffix as 'liquid1'
                if any(n in R for n in Names):
                    for n in Names:
                        if n in R:
                            Results[R] = Results[R].add_suffix(Names[n])                
                else:
                    # phases not listed in Names are suffixed with their own name
                    if '_prop' in R:
                        Results[R] = Results[R].add_suffix('_' + R[:-5])
                    else:
                        Results[R] = Results[R].add_suffix('_' + R)
            else:
                if any(n in R for n in Names_MM):
                    for n in Names_MM:
                        if n in R:
                            Results[R] = Results[R].add_suffix(Names_MM[n])
                else:       
                    if '_prop' in R:
                        Results[R] = Results[R].add_suffix('_' + R[:-5])
                    else:
                        Results[R] = Results[R].add_suffix('_' + R)
            
    # for R in Results:
    #     if R != "Conditions" and R != "sys":
    #         if "MELTS" in Model:
    #             if R in Names:  # Check for exact match in Names
    #                 Results[R] = Results[R].add_suffix(Names[R])                
    #             else:
    #                 if '_prop' in R:
    #                     Results[R] = Results[R].add_suffix('_' + R[:-5])
    #                 else:
    #                     Results[R] = Results[R].add_suffix('_' + R)
    #         else:
    #             if R in Names_MM:  # Check for exact match in Names_MM
    #                 Results[R] = Results[R].add_suffix(Names_MM[R])
    #             else:       
    #                 if '_prop' in R:
    #                     Results[R] = Results[R].add_suffix('_' + R[:-5])
    #                 else:
    #                     Results[R] = Results[R].add_suffix('_' + R)


            # still inside the `if R != "Conditions" ...` block of the loop above
            Results_All = pd.concat([Results_All, Results[R]], axis = 1)

    Results['All'] = Results_All
    Results['Mass'] = Results_Mass
    if "MELTS" in Model:
        Results['Volume'] = Results_Volume
        Results['rho'] = Results_rho

    # if the final row has zero total mass, drop the last row of every entry
    if Results['Mass'].sum(axis = 1).iloc[-1] == 0.0:
        for R in Results:
            Results[R].drop(Results[R].tail(1).index,inplace=True)

    return Results