

from sfctools.examples.example_wrapper import Example


def run():
    """Estimate and simulate the Klein model via bimets, then plot GNP.

    Original source:
    https://cran.r-project.org/web/packages/bimets/vignettes/bimets.pdf

    Model variables:
        cn:   Private Consumption Expenditure
        i:    Investment
        w1:   Wage Bill of the Private Sector (Demand for Labor)
        p:    Profits
        k:    Stock of Capital Goods
        y:    Gross National Product
        w2:   Wage Bill of the Government Sector
        time: index of the passage of time
        g:    Government Expenditure plus Net Exports
        t:    business taxes

    a1..a4, b1..b4 and c1..c4 are the coefficients to be estimated.

    Side effects: writes the data table to ``data/klein_data.xlsx``
    (creating the ``data`` directory if needed), passes ``models/klein/``
    to ``BimetsModel.gen_model``, prints the model summary and opens a
    matplotlib window comparing simulated and observed ``y``.
    """
    # Function-local imports: only needed when the example is actually run.
    import os

    import matplotlib.pyplot as plt
    import pandas as pd
    from sfctools.api.bimets import BimetsModel

    # Define paths
    DATA_PATH = "data/klein_data.xlsx"
    MODEL_PATH = "models/klein/"

    # Define equations
    equations = {
        'cn': {  # consumption
            'type': 'BEHAVIORAL',
            'EQ': 'cn = a1 + a2*p + a3*TSLAG(p) + a4*(w1+w2)',
            'COEFF': 'a1 a2 a3 a4',
        },
        'i': {  # investment
            'type': 'BEHAVIORAL',
            'EQ': 'i = b1 + b2*p + b3*TSLAG(p) + b4*TSLAG(k)',
            'COEFF': 'b1 b2 b3 b4'
        },
        'w1': {  # demand for labor
            'type': 'BEHAVIORAL',
            'EQ': 'w1 = c1 + c2*(y+t-w2) + c3*TSLAG(y+t-w2) + c4*time',
            'COEFF': 'c1 c2 c3 c4',
        },
        'y': {  # gross national product
            'type': 'IDENTITY',
            'EQ': 'y = cn + i + g - t'
        },
        'p': {  # profits
            'type': 'IDENTITY',
            'EQ': 'p = y - (w1+w2)'
        },
        'k': {  # capital stock
            'type': 'IDENTITY',
            'EQ': 'k = TSLAG(k) + i'
        },
    }

    # Define data (one value per year, starting at y0)
    data = {
        'cn': [39.8, 41.9, 45, 49.2, 50.6, 52.6, 55.1, 56.2, 57.3, 57.8,
               55, 50.9, 45.6, 46.5, 48.7, 51.3, 57.7, 58.7, 57.5, 61.6, 65, 69.7],
        'g': [4.6, 6.6, 6.1, 5.7, 6.6, 6.5, 6.6, 7.6, 7.9, 8.1, 9.4, 10.7,
              10.2, 9.3, 10, 10.5, 10.3, 11, 13, 14.4, 15.4, 22.3],
        'i': [2.7, -0.2, 1.9, 5.2, 3, 5.1, 5.6, 4.2, 3, 5.1, 1, -3.4, -6.2,
              -5.1, -3, -1.3, 2.1, 2, -1.9, 1.3, 3.3, 4.9],
        'k': [182.8, 182.6, 184.5, 189.7, 192.7, 197.8, 203.4, 207.6,
              210.6, 215.7, 216.7, 213.3, 207.1, 202, 199, 197.7, 199.8,
              201.8, 199.9, 201.2, 204.5, 209.4],
        'p': [12.7, 12.4, 16.9, 18.4, 19.4, 20.1, 19.6, 19.8, 21.1, 21.7,
              15.6, 11.4, 7, 11.2, 12.3, 14, 17.6, 17.3, 15.3, 19, 21.1, 23.5],
        'w1': [28.8, 25.5, 29.3, 34.1, 33.9, 35.4, 37.4, 37.9, 39.2, 41.3,
               37.9, 34.5, 29, 28.5, 30.6, 33.2, 36.8, 41, 38.2, 41.6, 45, 53.3],
        'y': [43.7, 40.6, 49.1, 55.4, 56.4, 58.7, 60.3, 61.3, 64, 67, 57.7,
              50.7, 41.3, 45.3, 48.9, 53.3, 61.8, 65, 61.2, 68.4, 74.1, 85.3],
        't': [3.4, 7.7, 3.9, 4.7, 3.8, 5.5, 7, 6.7, 4.2, 4, 7.7, 7.5, 8.3, 5.4,
              6.8, 7.2, 8.3, 6.7, 7.4, 8.9, 9.6, 11.6],
        # No 'time' value is given for the first year; estimation and
        # simulation below start one year later (y0 + 1).
        'time': [None, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0,
                 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        'w2': [2.2, 2.7, 2.9, 2.9, 3.1, 3.2, 3.3, 3.6, 3.7, 4, 4.2, 4.8,
               5.3, 5.6, 6, 6.1, 7.4, 6.7, 7.7, 7.8, 8, 8.5]
    }

    y0 = 1920
    yf = y0 + len(data['cn'])  # exclusive upper bound; last data year is yf - 1
    years = list(range(y0, yf))
    df = pd.DataFrame(data, index=years)
    df.index.name = 'Year'

    # to_excel does not create missing parent directories, so make sure the
    # target folder exists before writing.
    os.makedirs(os.path.dirname(DATA_PATH), exist_ok=True)
    df.to_excel(DATA_PATH)

    # Define model
    model = BimetsModel(equations)
    model.read_data(DATA_PATH)
    model.gen_model(MODEL_PATH, year_start_est=y0+1, year_end_est=yf-1,
                    year_start_sim=y0+1, year_end_sim=yf-1)

    # run model
    model.run()

    # print results
    model.print_summary()

    # visualize the results
    try:
        # (optional) matplotlib style for sfctools
        plt.style.use("sfctools")
    except OSError:
        # matplotlib raises OSError when the style is not installed;
        # fall back to the default style in that case.
        pass

    results = model.fetch_output()

    # Label every even year on the x axis.
    even_years = [year for year in results.index if year % 2 == 0]

    plt.figure(figsize=(6, 3.33))
    plt.plot(results.index, results["y"], label="Model")
    plt.plot(years, data["y"], label="Data", linestyle="dotted", marker="^")
    axes = plt.gca()
    axes.set_xticks(even_years)
    axes.set_xticklabels(["%i" % year for year in even_years])
    plt.xlabel("Year")
    plt.title("Gross National Product")
    plt.ylabel("(bn USD)")
    plt.legend(fancybox=False)
    plt.tight_layout()
    plt.show()


class BimetsKleinExample(Example):
    """Example wrapper whose callable is the module-level ``run`` function."""

    def __init__(self):
        def _launch():
            # Look up ``run`` at call time and forward its result.
            return run()

        super().__init__(_launch)


if __name__ == "__main__":
    # Script entry point: build the example wrapper and execute it.
    BimetsKleinExample().run()
