
from sfctools.examples.example_wrapper import Example


def run():
    """Estimate a small consumption model with bimets and simulate it in sfctools.

    Steps performed:

    1. Define a ``Consumer`` agent (consumption series ``c`` and ``c2``) and a
       ``RestOfEconomy`` agent holding dummy series ``p``, ``w1`` and ``w2``.
    2. Write a small annual data set (starting in 1920) to ``DATA_PATH`` as an
       Excel file.
    3. Estimate the behavioral equations ``cn`` and ``cn2`` with
       ``BimetsModel`` and print their summaries.
    4. Link the agents' attributes to the estimated equations.
    5. Run the simulation loop.
    6. Plot the data against both model variants.

    Side effects: creates the data directory if needed, writes
    ``DATA_PATH``, lets ``BimetsModel`` generate its model under
    ``MODEL_PATH``, prints estimation summaries and opens a matplotlib
    window.
    """
    import os

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    from sfctools import Agent, Clock
    from sfctools.api.bimets import BimetsModel, Equations

    # number of simulated periods; every agent time series has this length
    n_steps = 40

    # --- SMALL SFCTOOLS MODEL ---
    class Consumer(Agent):
        def __init__(self):
            super().__init__()

            # level of consumption for each time step
            self.c = np.zeros(n_steps)
            self.c2 = np.zeros(n_steps)

            # behavioral coefficients of this agent (left as None here)
            self.a1 = None
            self.a2 = None
            self.a3 = None
            self.a4 = None

        def consume(self):
            """Evaluate both consumption equations for the current period."""
            t = Clock().get_time()
            # evaluate the two linked consumption equations at time t
            Equations().eval("cn", t, verbose=False)
            Equations().eval("TSDELTA(cn2)", t, verbose=False)

    my_consumer = Consumer()  # create a new consumer agent

    class RestOfEconomy(Agent):
        def __init__(self):
            super().__init__()

            # other variables of the economy (deterministic dummy series)
            self.p = np.linspace(12, 24, n_steps)
            self.w1 = np.linspace(28, 54, n_steps)
            self.w2 = np.linspace(2, 9, n_steps)

    rest_of_economy = RestOfEconomy()  # rest of economy

    # --- ESTIMATE COEFFICIENTS FROM DATA ---
    DATA_PATH = "data/example2_data.xlsx"
    MODEL_PATH = "models/example2/"

    # equations to be estimated (two alternative consumption specifications)
    equations = {
        'cn': {  # consumption
            'type': 'BEHAVIORAL',
            'EQ': 'cn = a2*p + a3*TSLAG(p) + a4*(w1+w2)',
            'COEFF': ' a2 a3 a4',
        },
        'cn2': {  # consumption (alternative, deliberately nonsensical spec)
            'type': 'BEHAVIORAL',
            'EQ': 'TSDELTA(cn2) = b0*TSDELTA(p) + b1*TSDELTA(w1+w2)',
            'COEFF': 'b0 b1',
        },
    }

    # load necessary data
    data = {
        'cn': [39.8, 41.9, 45, 49.2, 50.6, 52.6, 55.1, 56.2, 57.3, 57.8,
               55, 50.9, 45.6, 46.5, 48.7, 51.3, 57.7, 58.7, 57.5, 61.6, 65, 69.7],
        'cn2': [39.8, 41.9, 45, 49.2, 50.6, 52.6, 55.1, 56.2, 57.3, 57.8,
                55, 50.9, 45.6, 46.5, 48.7, 51.3, 57.7, 58.7, 57.5, 61.6, 65, 69.7],
        'p': [12.7, 12.4, 16.9, 18.4, 19.4, 20.1, 19.6, 19.8, 21.1, 21.7,
              15.6, 11.4, 7, 11.2, 12.3, 14, 17.6, 17.3, 15.3, 19, 21.1, 23.5],
        'w1': [28.8, 25.5, 29.3, 34.1, 33.9, 35.4, 37.4, 37.9, 39.2, 41.3,
               37.9, 34.5, 29, 28.5, 30.6, 33.2, 36.8, 41, 38.2, 41.6, 45, 53.3],
        'w2': [2.2, 2.7, 2.9, 2.9, 3.1, 3.2, 3.3, 3.6, 3.7, 4, 4.2, 4.8,
               5.3, 5.6, 6, 6.1, 7.4, 6.7, 7.7, 7.8, 8, 8.5],
    }

    y0 = 1920
    yf = y0 + len(data['cn'])  # exclusive end year of the data sample
    df = pd.DataFrame(data, index=list(range(y0, yf)))
    df.index.name = 'Year'

    # to_excel does not create missing parent directories itself
    data_dir = os.path.dirname(DATA_PATH)
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)
    df.to_excel(DATA_PATH)

    # Define model
    model = BimetsModel(equations)

    model.read_data(DATA_PATH)
    model.gen_model(MODEL_PATH, year_start_est=y0+1, year_end_est=yf-1,
                    year_start_sim=y0+1, year_end_sim=yf-1)

    model.run()
    model.print_summary("cn")
    model.print_summary("cn2")
    # NOTE: the estimation above can be commented out once it has been run

    # --- LINK AGENT AND EMPIRICS ---
    Equations().init_paths(DATA_PATH, MODEL_PATH)

    # link attribute 'c' to variable 'cn' (and 'c2' to the cn2 equation)
    Equations().link_attr('cn', my_consumer, 'c')
    Equations().link_attr('TSDELTA(cn2)', my_consumer, 'c2')

    # other variables are defined in the model elsewhere.
    # here: dummy arrays. transfer_data=False to not
    # overwrite the agent's internal values with the original data
    Equations().link_attr('p', rest_of_economy, transfer_data=False)
    Equations().link_attr('w1', rest_of_economy, transfer_data=False)
    Equations().link_attr('w2', rest_of_economy, transfer_data=False)

    # model loop (only consumption): periods t = 1 .. n_steps - 1
    Clock().tick()  # set t = 1

    for _ in range(n_steps - 1):
        my_consumer.consume()
        # do more here if you like...
        Clock().tick()

    # plot results
    try:
        plt.style.use("sfctools")
    except OSError:
        # the custom style is optional; fall back to matplotlib defaults
        pass

    plt.figure(figsize=(6, 3.33))
    sim_years = np.arange(n_steps) + y0
    plt.plot(df.index, df["cn"], label="Data")
    plt.plot(sim_years[1:], my_consumer.c[1:], label="Model")
    plt.plot(sim_years[1:], my_consumer.c2[1:],
             label="Model (alternative equation)")

    # label only every 20th year to keep the axis readable
    tick_years = [year for year in sim_years if year % 20 == 0]
    ax = plt.gca()
    ax.set_xticks(tick_years)
    ax.set_xticklabels(["%i" % year for year in tick_years])

    plt.xlabel("Year")
    plt.title("Private Consumption Expenditure (Example 2)")
    plt.ylabel("(bn USD)")
    plt.legend(fancybox=False)
    plt.tight_layout()
    plt.show()


class BimetsCustomExample(Example):
    """``Example`` wrapper whose payload is this module's ``run`` function."""

    def __init__(self):
        # the lambda resolves ``run`` only when the wrapped callable is invoked
        super(BimetsCustomExample, self).__init__(lambda: run())


if __name__ == "__main__":
    # build the example wrapper and execute it immediately
    BimetsCustomExample().run()
