import os
import re
import math
import warnings
import subprocess
from collections import defaultdict

import numpy as np
import pandas as pd

from sfctools.core.singleton import Singleton
from sfctools.config import read_config
from sfctools import Clock


"""
sfctools interface for generating and running BIMETS models in R (bimets_model.R).
WARNING: THIS IS A BETA FEATURE. NO WARRANTY OF ANY KIND.

__author__: TB, DLR-VE
__date__: Feb24

The resulting folder structure of the bimets model should be as follows:

    Bimets
        |____ data
            |___ my_data1.xlsx
            |___ my_data2.xlsx
            |___ my_data3.xlsx

        |____ models
            |_______ model_name
                    |___ model.txt         <- BimetsModel.model_txt
                    |____ data_prep.R      <- BimetsModel.dataprep_str
                    |____ write_output.R   <- BimetsModel.write_out_str
                    |____ bimets_model.R   <- BimetsModel.model_code_est
"""

def _fmt_coeff_names(names):
        if names is None:
            return None
        if isinstance(names, str):
            # already a single string like "alpha beta1 beta2"
            return names.strip()
        if isinstance(names, (list, tuple)):
            # convert ['alpha','beta1','beta2'] -> "alpha beta1 beta2"
            return " ".join(str(n).strip() for n in names if str(n).strip())
        # last resort
        return str(names).strip()   



def gen_DataPrep(fname, data_args, skiprows=0, omit_nans=False):
    """
    Generate the R source of ``data_prep.R``.

    The generated script defines ``tser()`` (robust numeric coercion followed
    by ``TIMESERIES()``), ``lag()`` and ``prepare_data()``. ``prepare_data()``
    reads ``fname`` with ``readxl::read_excel`` and returns a named list with
    one annual time series per spreadsheet column, plus one per extra entry of
    ``data_args``.

    :param fname: path to the Excel file. Its first row must hold the variable
        names, and one column must be called ``Year``.
    :param data_args: dict mapping series names to R expressions. Each
        expression is inserted verbatim as ``tser(<expr>)``. A key equal to an
        existing column name replaces that column's default definition.
    :param skiprows: passed to readxl as ``skip``, i.e. the number of leading
        file rows, the header row included, that R skips before reading data
        (e.g. 1 when the data start directly below the header).
    :param omit_nans: if True, the R code applies ``na.omit(df)`` after reading.
    :returns: the R source code as a string.
    :raises KeyError: if the file has no ``Year`` column.
    :raises ValueError: if no numeric ``Year`` value occurs in the rows R reads.
    """
    df = pd.read_excel(fname)

    # Ordered list of series: every spreadsheet column, followed by any extra
    # series that only exist through data_args.
    cols = list(df.columns)
    for extra in data_args:
        if extra not in cols:
            cols.append(extra)

    def backtick(s: str) -> str:
        s = str(s).replace("`", "\\`")  # escape backticks for R
        return f"`{s}`"

    my_data_args = []
    for col in cols:
        lhs = backtick(col)  # list element name
        if col in data_args:
            # user-supplied R expression/string (leave as-is)
            my_data_args.append(f"{lhs} = tser({data_args[col]})")
        else:
            rhs = backtick(col)  # df column accessor
            my_data_args.append(f"{lhs} = tser(df${rhs})")

    # START of every TIMESERIES must be the year of the FIRST row R reads.
    # pandas consumed the header (file row 0), so pandas row j is file row
    # j + 1. readxl, given explicit col_names, consumes no header and reads
    # file rows from index `skiprows` on. Pandas row j is therefore element
    # (j + 1 - skiprows) of every R series. Anchor START on the first such row
    # that has a numeric Year (leading text rows become NA in tser()).
    if "Year" not in df.columns:
        raise KeyError(f"'Year' column not found in {fname}")
    years = pd.to_numeric(df["Year"], errors="coerce").reset_index(drop=True)
    first_read = max(skiprows - 1, 0)  # first pandas row that R also reads
    candidates = years.iloc[first_read:].dropna()
    if candidates.empty:
        raise ValueError(f"No numeric 'Year' value found in the rows read from {fname}")
    first_pos = int(candidates.index[0])
    start_year = int(candidates.iloc[0]) - (first_pos + 1 - skiprows)

    str_data_args = ",\n".join(["\t\t" + i for i in my_data_args])
    str_omit = "df <- na.omit(df)" if omit_nans else ""
    # forward slashes for R; escape single quotes for the '...' R literal
    fname_fixed = str(fname).replace("\\", "/").replace("'", "\\'")

    # NOTE: keep as a raw f-string; braces belong to R, {start_year} is Python
    my_str = rf"""
suppressPackageStartupMessages({{
  library(mFilter)
  library(bimets)
  library(knitr)
  library(readxl)
  library(patchwork)
  library(tidyr)
  library(readr)
  library(ggplot2)
  library(ggthemes)
  library(scales)
  library(zoo)
  library(gridExtra)
  library(openxlsx)
}})

tser <- function(xx){{
  # Flatten list-columns
  if (is.list(xx)) xx <- unlist(xx, use.names = FALSE)
  # Normalize to character, strip NBSP/whitespace, fix decimal comma
  xx <- as.character(xx)
  xx <- gsub("\u00A0", "", xx, fixed = TRUE)      # non-breaking space (literal match)
  xx <- trimws(xx)
  xx <- gsub(",", ".", xx, fixed = TRUE)
  # Coerce to numeric (NA where not parseable)
  xx <- suppressWarnings(as.numeric(xx))
  return(TIMESERIES(xx, START=c({start_year},1), FREQ=1))
}}

lag <- function(xx){{ c(NA, head(xx, -1)) }}

prepare_data <- function(){{
  fname <- '{fname_fixed}'
  myCols <- as.character(read_excel(fname, n_max = 1, col_names = FALSE))
  # Quiet the "New names:" chatter and warnings during read
  df <- suppressMessages(suppressWarnings(
          tryCatch(
            read_excel(fname, skip={skiprows}, col_names=myCols, .name_repair = "minimal"),
            error = function(e) read_excel(fname, skip={skiprows}, col_names=myCols)
          )))

  {str_omit}

  My_modelData = list(
{str_data_args})
  
  return(My_modelData)
}}
"""
    return my_str

def gen_WriteOutput():
    return r"""
write_coeffs <- function(model){

  # --- helper: read COEFF names for a given behavioral eq from model.txt ---
  safe_grep <- function(pat, x) {
    which(grepl(pat, x, perl = TRUE))
  }
  escape_re <- function(s) {
    gsub("([][(){}^$*+?|\\.^-])", "\\\\\\1", s, perl = TRUE)
  }
  get_coeff_names_from_modeltxt <- function(eq_name) {
    lines <- try(readLines("model.txt", warn = FALSE), silent = TRUE)
    if (inherits(lines, "try-error")) return(NULL)

    # find each "BEHAVIORAL> eq_name" block
    pat_start <- paste0("^\\s*BEHAVIORAL>\\s*", escape_re(eq_name), "\\s*$")
    starts <- safe_grep(pat_start, lines)
    if (length(starts) == 0) return(NULL)

    # for each block, look ahead until next header and pick the first COEFF> line
    for (s in starts) {
      # next header (another BEHAVIORAL>, IDENTITY>, or END)
      next_hdr_rel <- safe_grep("^\\s*(BEHAVIORAL>|IDENTITY>|END)", lines[(s+1):length(lines)])
      stop_idx <- if (length(next_hdr_rel) > 0) s + next_hdr_rel[1] else length(lines) + 1
      block <- lines[(s+1):(stop_idx-1)]
      coeff_line <- grep("^\\s*COEFF>", block, value = TRUE)
      if (length(coeff_line) > 0) {
        raw <- sub("^\\s*COEFF>\\s*", "", coeff_line[1])
        toks <- unlist(strsplit(raw, "\\s+"))
        toks <- toks[nzchar(toks)]
        if (length(toks) > 0) return(toks)
      }
    }
    NULL
  }

  df_all_coeffs <- data.frame(
    eq    = character(0),
    param = character(0),
    coeff = numeric(0),
    pval  = numeric(0),
    stringsAsFactors = FALSE
  )

  df_all_stats <- data.frame(
    eq               = character(0),
    `Durbin-Watson`  = numeric(0),
    `Adj. R-Squared` = numeric(0),
    `Est. Technique` = character(0),
    stringsAsFactors = FALSE
  )

  for (name in model$vendogBehaviorals){

    var <- get(name, model$behaviorals)

    coef_vec <- as.numeric(var$coefficients)
    terms    <- names(var$coefficients)

    # choose readable names
    need_terms <- (is.null(terms) || length(terms) != length(coef_vec) || any(is.na(terms) | terms == ""))

    if (need_terms) {
      from_txt <- get_coeff_names_from_modeltxt(name)
      if (!is.null(from_txt) && length(from_txt) == length(coef_vec)) {
        terms <- from_txt
      } else if (!is.null(var$regressors) && !is.null(colnames(var$regressors)) &&
                 length(colnames(var$regressors)) == length(coef_vec)) {
        terms <- colnames(var$regressors)
      } else {
        terms <- paste0("b", seq_along(coef_vec)) # last resort
      }
    }

    # p-values
    pvals <- suppressWarnings(as.numeric(var$statistics$CoeffPvalues))
    if (is.null(pvals) || length(pvals) != length(coef_vec)) {
      pvals <- rep(NA_real_, length(coef_vec))
    }

    df_all_coeffs <- rbind(
      df_all_coeffs,
      data.frame(eq = name, param = terms, coeff = coef_vec, pval = pvals, stringsAsFactors = FALSE)
    )

    stats <- var$statistics
    df_all_stats <- rbind(
      df_all_stats,
      data.frame(
        eq               = name,
        `Durbin-Watson`  = as.numeric(stats$DurbinWatson),
        `Adj. R-Squared` = as.numeric(stats$AdjustedRSquared),
        `Est. Technique` = as.character(stats$estimationTechnique),
        stringsAsFactors = FALSE
      )
    )
  }

  openxlsx::write.xlsx(df_all_coeffs, "df_coeffs.xlsx", rowNames = FALSE)
  openxlsx::write.xlsx(df_all_stats,  "df_stats.xlsx",  rowNames = FALSE)
}
"""


class BimetsModel:
    """
    Generate, run and post-process a BIMETS (R package ``bimets``) model.

    Typical use::

        model = BimetsModel(eqs, data_args=..., extend_args=..., exo_args=...)
        model.read_data("data/my_data.xlsx")   # builds the data_prep.R / write_output.R sources
        model.gen_model("models/my_model")     # writes data_prep.R, write_output.R, model.txt, bimets_model.R
        model.run()                            # executes bimets_model.R with Rscript
        model.print_summary()                  # tabulates df_coeffs.xlsx / df_stats.xlsx
        df = model.fetch_output()              # reads df_model.xlsx
    """

    # Directory containing the Rscript executable. It is shared by all
    # instances, and run() overwrites it from the sfctools configuration (R_PATH).
    __PATH = os.getcwd()

    @classmethod
    def set_R_path(cls, new_path):
        """Set the directory that contains the ``Rscript`` executable (class-wide)."""
        cls.__PATH = new_path

    def __init__(self, eqs, data_args=None, extend_args=None, exo_args=None, verbose=False):
        """
        Initialize a model from its equation definitions.

        :param eqs: dict mapping a variable name to an equation spec with keys
            ``type`` ("BEHAVIORAL" or "IDENTITY") and ``EQ``, and optionally
            ``COMMENT``, ``COEFF``, ``RESTRICT``, ``CONDITION``. To define
            several equations for one name, give ``type`` as a list and make
            the other entries lists of the same length.
        :param data_args: dict of series name -> R expression (see gen_DataPrep).
        :param extend_args: dict of series name -> keyword arguments for
            TSEXTEND. ``UPTO`` and ``BACKTO`` are interpreted as years.
        :param exo_args: dict of variable name -> 4-element sequence, written as
            ``c(v0,v1,v2,v3)`` for SIMULATE's Exogenize (presumably a TSRANGE).
        :param verbose: print the configuration.
        """
        self.EQS = eqs
        self.DATA_ARGS = data_args or {}
        self.EXTEND_ARGS = extend_args or {}
        self.EXO_ARGS = exo_args or {}

        self.var_names = []  # stores names of behavioral variables

        if verbose:
            print("--------------------------------")
            print("*** NEW BIMETS MODEL ***")
            print("EQS\n", self.EQS, "\n\n")
            print("DATA_ARGS\n", self.DATA_ARGS, "\n\n")
            print("EXTEND_ARGS\n", self.EXTEND_ARGS, "\n\n")
            print("EXO_ARGS\n", self.EXO_ARGS, "\n\n")
            print("-------------------------------")

        # stores model code internally
        self.dataprep_str = ""
        self.write_out_str = ""
        self.model_code_head = ""
        self.model_code_est = ""
        self.model_code_sim_config = ""
        self.model_code_sim = ""
        self.model_txt = ""  # content of model.txt, filled by gen_model()

        self.path = None  # stores model path

    def fetch_output(self, filename="df_model.xlsx"):
        """
        Read an Excel file written by the R script into a DataFrame.

        :param filename: file name inside the model directory.
        :returns: DataFrame, indexed by ``Year`` if that column exists.
        :raises RuntimeError: if gen_model() has not been called yet.
        """
        if self.path is None:
            raise RuntimeError("No path found. Did you run the model?")
        df = pd.read_excel(os.path.join(self.path, filename))
        if "Year" in df.columns:
            df = df.set_index("Year")
        return df

    def read_data(self, path, skiprows=0, omit_nans=False):
        """
        Prepare the data-loading and output-writing R sources for an Excel file.

        :param path: path or str, location of the data. The first row must hold
            the exact variable names, and a ``Year`` column is required.
        :param skiprows: number of leading rows, the header row included, that R
            skips before reading data (e.g. 1 when the data start directly below
            the header). Defaults to 0.
        :param omit_nans: if True, rows containing NA are dropped in R (``na.omit``).
        """
        path = os.path.abspath(path)
        self.dataprep_str = gen_DataPrep(path, self.DATA_ARGS, skiprows=skiprows, omit_nans=omit_nans)
        self.write_out_str = gen_WriteOutput()

    def _r_eq_list(self, eq):
        """Return a proper R eqList=... line or empty string."""
        if eq is None:
            return ""
        if isinstance(eq, str):
            return f"eqList=c('{eq}'),\n"
        if isinstance(eq, (list, tuple)):
            items = ",".join(f"'{e}'" for e in eq)
            return f"eqList=c({items}),\n"
        raise TypeError("eq must be None, str, list, or tuple")

    @staticmethod
    def _nth(spec, key, i):
        """Return ``spec[key][i]``, or None if the key is absent/None or too short."""
        seq = spec.get(key)
        if seq is None or i >= len(seq):
            return None
        return seq[i]

    @staticmethod
    def _equation_txt(name, eq_type, eq, comment=None, coeff=None, restrict=None, condition=None):
        """
        Render one equation entry of model.txt.

        The entry consists of an optional ``COMMENT>`` line, the
        ``<eq_type>> name`` header and ``EQ>`` (with newlines removed),
        followed by optional ``COEFF>``, ``RESTRICT>`` and ``IF>`` lines and
        a blank line. Falsy optional fields are omitted.
        """
        txt = ""
        if comment:
            txt += "COMMENT> " + comment + "\n"
        txt += eq_type + "> " + name + "\n"
        txt += "EQ> " + "".join(eq.split("\n")) + "\n"
        if coeff:
            coeff_str = _fmt_coeff_names(coeff)
            if coeff_str:
                txt += f"COEFF> {coeff_str}\n"
        if restrict:
            txt += "RESTRICT> " + restrict + "\n"
        if condition:
            txt += "IF> " + condition + "\n"
        return txt + "\n"

    def _build_model_txt(self):
        """
        Build the model.txt content from ``self.EQS`` and refill ``self.var_names``.

        Single-equation specs whose ``EQ`` is None are skipped with a warning.
        In multi-equation specs a None equation raises RuntimeError. Types
        other than BEHAVIORAL/IDENTITY are ignored.
        """
        model_txt = "MODEL\n\n\n"
        self.var_names = []

        for k, v in self.EQS.items():
            if isinstance(v['type'], str):
                if v["EQ"] is None:
                    warnings.warn(f"Equation {k} is None")
                    continue
                if v["type"] not in ("BEHAVIORAL", "IDENTITY"):
                    continue
                if v["type"] == "BEHAVIORAL":
                    self.var_names.append(k)
                model_txt += self._equation_txt(
                    k, v["type"], v["EQ"],
                    comment=v.get("COMMENT"),
                    coeff=v.get("COEFF"),
                    restrict=v.get("RESTRICT"),
                    condition=v.get("CONDITION"),
                )
            else:
                for i, eq_type in enumerate(v["type"]):
                    if v["EQ"][i] is None:
                        raise RuntimeError(f"Equation {k} is None")
                    if eq_type not in ("BEHAVIORAL", "IDENTITY"):
                        continue
                    if eq_type == "BEHAVIORAL":
                        self.var_names.append(k)
                    model_txt += self._equation_txt(
                        k, eq_type, v["EQ"][i],
                        comment=self._nth(v, "COMMENT", i),
                        coeff=self._nth(v, "COEFF", i),
                        restrict=self._nth(v, "RESTRICT", i),
                        condition=self._nth(v, "CONDITION", i),
                    )

        return model_txt + "\nEND"

    def gen_model(self, path, sim_type='STATIC', year_start_est=2001, year_end_est=2022,
                  year_start_sim=2001, year_end_sim=2022, eq=None,
                  JacobianDrop=None, constAdj=None, algo="NEWTON"):
        """
        Write the model directory: data_prep.R, write_output.R, model.txt and
        the driver script bimets_model.R (all UTF-8 encoded).

        Call read_data() first. Otherwise data_prep.R and write_output.R are
        written empty.

        :param path: model directory, relative to the current working directory
            (or absolute). It is created if missing.
        :param sim_type: simType for SIMULATE (e.g. 'STATIC'), or
            'RESCHECK+<simType>' to first run a RESCHECK simulation and reuse its
            constant adjustments in a <simType> simulation, or None to only
            estimate.
        :param year_start_est, year_end_est: annual TSRANGE of ESTIMATE.
        :param year_start_sim, year_end_sim: annual TSRANGE of SIMULATE. Also the
            default TSEXTEND horizon.
        :param eq: None, an equation name or list/tuple of names (ESTIMATE eqList).
        :param JacobianDrop: list of variable names for SIMULATE (not used in
            'RESCHECK+...' mode).
        :param constAdj: dict var -> {"VALUES": [...], "START": year} of constant
            adjustments (not used in 'RESCHECK+...' mode).
        :param algo: simAlgo for SIMULATE (not used in 'RESCHECK+...' mode).
        :raises RuntimeError: if an equation in a multi-equation spec is None.
        """
        path = os.path.join(os.getcwd(), path)
        self.path = path
        os.makedirs(path, exist_ok=True)

        # UTF-8 everywhere: model.txt is read with readr (UTF-8) and the R
        # helpers are sourced with encoding = "UTF-8" below.
        with open(os.path.join(path, "data_prep.R"), "w", encoding="utf-8") as file:
            file.write(self.dataprep_str)

        with open(os.path.join(path, "write_output.R"), "w", encoding="utf-8") as file:
            file.write(self.write_out_str)

        # BIMETS MODEL
        model_txt = self._build_model_txt()
        with open(os.path.join(path, "model.txt"), "w", encoding="utf-8") as file:
            file.write(model_txt)

        self.model_txt = model_txt

        # R script to run for estimation/simulation
        my_code = r"""
rm(list=ls(all=TRUE))
if(!is.null(dev.list())) dev.off()
cat("\014")

suppressPackageStartupMessages({
  library(mFilter)
  library(bimets)
  library(knitr)
  library(readxl)
  library(patchwork)
  library(tidyr)
  library(readr)
  library(ggplot2)
  library(ggthemes)
  library(scales)
  library(zoo)
  library(gridExtra)
  library(openxlsx)
  library(reshape2)
  library(stringr)
})

source("data_prep.R", encoding = "UTF-8")
source("write_output.R", encoding = "UTF-8")

model_txt <- readr::read_file("model.txt")
model <- LOAD_MODEL(modelText = model_txt, showWarnings = TRUE)
data  <- prepare_data()
model <- LOAD_MODEL_DATA(model, data)
"""

        eq_line = self._r_eq_list(eq)  # "" or like "eqList=c('kappa','...'),\n"
        my_code += f"""
model <- ESTIMATE(model,
    {eq_line}TSRANGE=c({year_start_est}, 1, {year_end_est}, 1),
    forceTSRANGE = FALSE,
    tol=1e-15,
    CHOWTEST = FALSE,
    verbose = TRUE,
    estTech = 'OLS'
)

write_coeffs(model)
"""

        # Post-estimation sanity check — catch missing/invalid coefficients before SIMULATE
        my_code += r"""
# ---- Post-estimation sanity check: all behaviorals have valid coefficients?
beh <- model$vendogBehaviorals
bad <- Filter(function(nm) {
  v <- try(get(nm, model$behaviorals), silent = TRUE)
  inherits(v, "try-error") ||
    is.null(v$coefficients) ||
    any(!is.finite(as.numeric(v$coefficients)))
}, beh)

if (length(bad) > 0) {
  openxlsx::write.xlsx(data.frame(eq = bad), "df_bad_behaviorals.xlsx", rowNames = FALSE)
  stop(sprintf("Missing/invalid coefficients for: %s (check data, dummies, and collinearity or rerun ESTIMATE with a specific eqList).",
               paste(bad, collapse = ", ")))
}
"""

        # Diagnostics block with token placeholders (no % formatting!)
        diag_block = r"""
# ---- GENERIC DIAGNOSTICS ----
START_EST_Y <- __YS_EST__
END_EST_Y   <- __YE_EST__

# 0) Per-series diagnostics in the estimation window
suppressWarnings({
  try(openxlsx::write.xlsx(as.data.frame(data), "df_data_pre_est.xlsx"), silent = TRUE)

  diag_series <- function(tsobj, startY, endY) {
    v <- try(window(tsobj, start = c(startY,1), end = c(endY,1)), silent = TRUE)
    v <- suppressWarnings(as.numeric(v))
    finite <- is.finite(v)
    vv <- v[finite]
    c(
      N = length(v),
      NA_or_Inf = sum(!finite),
      zeros = sum(vv == 0, na.rm = TRUE),
      sd = if (length(vv) > 1) stats::sd(vv) else NA_real_,
      min = if (length(vv) > 0) min(vv) else NA_real_,
      max = if (length(vv) > 0) max(vv) else NA_real_
    )
  }

  vars <- names(data)
  diag_mat <- do.call(rbind, lapply(vars, function(nm)
    diag_series(data[[nm]], START_EST_Y, END_EST_Y)
  ))
  diag_df <- as.data.frame(diag_mat)
  diag_df$variable <- vars
  diag_df <- diag_df[, c("variable","N","NA_or_Inf","zeros","sd","min","max")]
  try(openxlsx::write.xlsx(diag_df, "df_diag_est_window.xlsx"), silent = TRUE)

  bad_any <- subset(diag_df, NA_or_Inf > 0 | zeros > 0)
  if (nrow(bad_any) > 0) {
    message("[diag] Variables with NA/Inf or zeros in estimation window:\n",
            paste(capture.output(print(bad_any)), collapse = "\n"))
  }
})

# 1) Parse equations from model_txt
lines    <- unlist(strsplit(model_txt, "\n"))
eq_lines <- sub("^\\s*EQ>\\s*", "", grep("^\\s*EQ>", lines, value = TRUE))

# 2) Denominator candidates: tokens that appear after a '/' in EQ lines
den_pat  <- "/\\s*([A-Za-z_][A-Za-z0-9_]*)"
den_vars <- unique(unlist(lapply(eq_lines, function(ln) {
  m <- stringr::str_match_all(ln, den_pat)[[1]]
  if (is.null(m) || nrow(m) == 0) return(character(0))
  m[,2]
})))
den_vars <- den_vars[den_vars %in% names(data)]

if (length(den_vars) > 0) {
  den_rows <- lapply(den_vars, function(nm){
    v <- try(window(data[[nm]], start = c(START_EST_Y,1), end = c(END_EST_Y,1)), silent = TRUE)
    v <- suppressWarnings(as.numeric(v))
    data.frame(variable = nm,
               zeros = sum(v == 0, na.rm = TRUE),
               NA_or_Inf = sum(!is.finite(v)))
  })
  den_df <- do.call(rbind, den_rows)
  try(openxlsx::write.xlsx(den_df, "df_diag_denominators.xlsx"), silent = TRUE)
  bad_den <- subset(den_df, zeros > 0 | NA_or_Inf > 0)
  if (nrow(bad_den) > 0) {
    message("[diag] Potential denominator issues (zeros/NA in denominators):\n",
            paste(capture.output(print(bad_den)), collapse = "\n"))
  }
}

# 3) LOG() arguments must be > 0
log_pat  <- "(?i)\\bLOG\\s*\\(\\s*([A-Za-z_][A-Za-z0-9_]*)"
log_vars <- unique(unlist(lapply(eq_lines, function(ln){
  m <- stringr::str_match_all(ln, log_pat)[[1]]
  if (is.null(m) || nrow(m) == 0) return(character(0))
  m[,2]
})))
log_vars <- log_vars[log_vars %in% names(data)]

if (length(log_vars) > 0) {
  log_rows <- lapply(log_vars, function(nm){
    v <- try(window(data[[nm]], start = c(START_EST_Y,1), end = c(END_EST_Y,1)), silent = TRUE)
    v <- suppressWarnings(as.numeric(v))
    data.frame(variable = nm,
               nonpositive = sum(v <= 0, na.rm = TRUE),
               NA_or_Inf  = sum(!is.finite(v)))
  })
  log_df <- do.call(rbind, log_rows)
  try(openxlsx::write.xlsx(log_df, "df_diag_log_args.xlsx"), silent = TRUE)
  bad_log <- subset(log_df, nonpositive > 0 | NA_or_Inf > 0)
  if (nrow(bad_log) > 0) {
    message("[diag] LOG() domain issues (nonpositive/NA):\n",
            paste(capture.output(print(bad_log)), collapse = "\n"))
  }
}

# 4) Required burn-in from TSLAG/TSDELTA
lag_pat <- "TSLAG\\s*\\(\\s*([A-Za-z_][A-Za-z0-9_]*)\\s*(,\\s*([0-9]+))?\\s*\\)"
del_pat <- "TSDELTA\\s*\\(\\s*([A-Za-z_][A-Za-z0-9_]*)\\s*(,\\s*([0-9]+))?\\s*\\)"

get_max_k <- function(pat) {
  ks <- unlist(lapply(eq_lines, function(ln){
    m <- stringr::str_match_all(ln, pat)[[1]]
    if (is.null(m) || nrow(m) == 0) return(0)
    vals <- suppressWarnings(as.numeric(m[,4]))
    vals[is.na(vals)] <- 1
    max(vals, na.rm = TRUE)
  }))
  if (length(ks) == 0) 0 else max(ks, na.rm = TRUE)
}
req_burnin <- max(get_max_k(lag_pat), get_max_k(del_pat))
if (is.finite(req_burnin) && req_burnin > 0) {
  message(sprintf("[diag] Minimum burn-in from lags/deltas: %d (consider shifting TSRANGE start)", as.integer(req_burnin)))
}
# ---- END DIAGNOSTICS ----
"""
        # Inject numbers safely
        my_code += diag_block.replace("__YS_EST__", str(year_start_est)).replace("__YE_EST__", str(year_end_est))

        if sim_type is not None:

            for k, v in self.EXTEND_ARGS.items():
                # every key except UPTO/BACKTO is forwarded verbatim to TSEXTEND
                add_code = ""
                for k2, v2 in v.items():
                    if k2 in ["UPTO", "BACKTO"]:
                        continue
                    if isinstance(v2, str):
                        v2 = f"'{v2}'"
                    add_code += ", " + k2 + "=" + str(v2)

                year_end = v.get("UPTO", year_end_sim)
                backto_str = f"BACKTO=c({v['BACKTO']}, 1), " if "BACKTO" in v else ""
                my_code += f"model$modelData${k} = TSEXTEND(model$modelData${k}, {backto_str}UPTO=c({year_end},1){add_code})\n"

            exo_items = [f"\t{k}=c({v[0]},{v[1]},{v[2]},{v[3]})" for k, v in self.EXO_ARGS.items()]
            my_code += "exogenizeCandidates  <- list(\n" + ",\n".join(exo_items) + "\n" + ")\n"

            # only exogenize candidates that are actually endogenous in the model
            my_code += r"""
vendog_vars <- model$vendog
exogenizeList <- list()
for (var_name in names(exogenizeCandidates)) {
  if (var_name %in% vendog_vars) {
      exogenizeList[[var_name]] <- exogenizeCandidates[[var_name]]
  }
}
"""

            jdropstr = ""
            if JacobianDrop is not None:
                assert isinstance(JacobianDrop, list), "JacobianDrop should be a list of variables."
                jdropstr = "JacobianDrop=c(%s),\n    " % (",".join(["'%s'" % i for i in JacobianDrop]))

            addfact_str = ""
            addfact_args = ""
            if constAdj is not None:
                addfact_args = "\n    ConstantAdjustment=constantAdjList,\n"
                args = []
                for varname, v in constAdj.items():
                    assert "VALUES" in v, f"Argument {varname} must contain key 'VALUES'"
                    assert "START" in v, f"Argument {varname} must contain key 'START'"
                    vals = ",".join([str(i) for i in v["VALUES"]])
                    startyear = str(v["START"])
                    args.append(f"{varname} = TIMESERIES({vals}, START=c({startyear}, 1), FREQ='A')")
                argstr = ",\n".join(args)
                addfact_str += f"constantAdjList <- list(\n{argstr}\n)"

            if not (isinstance(sim_type, str) and sim_type.startswith("RESCHECK+")):

                my_code += f"""
{addfact_str}

model <- SIMULATE(model,
    simType='{sim_type}',
    TSRANGE=c({year_start_sim}, 1, {year_end_sim}, 1),
    simConvergence=1e-15,
    Exogenize=exogenizeList,{addfact_args}
    {jdropstr}simAlgo='{algo}'
)
simu1 <- model$simulation
simu2 <- head(simu1,-1)
df_s <- data.frame(simu2)
Year <- c(list({year_start_sim}:{year_end_sim}))
df_s['Year'] <- Year
# df_s <- data.frame(simu2, row.names = Year)
rownames(df_s) <- {year_start_sim}:{year_end_sim}
openxlsx::write.xlsx(df_s, "df_model.xlsx")
openxlsx::write.xlsx(data.frame(data), "df_data.xlsx")

# write outputs
write_coeffs(model)
"""
            else:
                # "RESCHECK+DYNAMIC" -> second simulation uses simType 'DYNAMIC'
                followup_sim = sim_type.split("+")[1]
                my_code += f"""

# RESCHECK to get initial tracking residuals (with error AC)
model<-SIMULATE(model,
                TSRANGE=c({year_start_sim},1,{year_end_sim},1),
                simType='RESCHECK',
                ZeroErrorAC=TRUE,
                Exogenize=exogenizeList)
initTrac<-model$ConstantAdjustmentRESCHECK

# dynamic simulation using initTrac as constant adjustments
model<-SIMULATE(model,
                simType='{followup_sim}',
                TSRANGE=c({year_start_sim},1,{year_end_sim},1),
                ConstantAdjustment=initTrac,
                Exogenize=exogenizeList)

simu1 <- model$simulation
simu2 <- head(simu1,-1)
Year <- c(list({year_start_sim}:{year_end_sim}))
df_s <- data.frame(simu2)
df_s['Year'] <- Year
openxlsx::write.xlsx(df_s, "df_model.xlsx")
openxlsx::write.xlsx(data.frame(data), "df_data.xlsx")

# write outputs
write_coeffs(model)
"""

        with open(os.path.join(path, "bimets_model.R"), "w", encoding="utf-8") as file:
            file.write(my_code)

        self.model_code_est = my_code

    def run(self):
        """
        Execute ``bimets_model.R`` with ``Rscript`` inside the model directory.

        The R directory is read from the sfctools configuration (``R_PATH``) and
        stored class-wide via set_R_path(). This process's working directory is
        not changed.

        :returns: standard output of the R process.
        :raises RuntimeError: if gen_model() has not been called, if Rscript
            cannot be started, or if the R script exits with an error.
        """
        if self.path is None:
            raise RuntimeError("Could not yet run the model. It seems as if you have not yet run gen_model().")

        self.set_R_path(read_config("R_PATH"))
        r_path = os.path.join(self.__class__.__PATH, "Rscript")

        print("[sfctools] R Script: ", r_path)
        print("[sfctools] Working directory: ", self.path)
        print("Running bimets_model.R ...")
        print("subprocess->", r_path)
        try:
            # cwd= runs R in the model directory without a global os.chdir;
            # errors="replace" tolerates non-UTF-8 console output from R.
            result = subprocess.run(
                [r_path, "bimets_model.R"],
                cwd=self.path,
                capture_output=True,
                encoding="utf-8",
                errors="replace",
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                f"Error occurred: {e.stderr}. Did you add R to the PATH environment variables?"
            ) from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error occurred: {e}") from e

        return result.stdout

    def print_summary(self, vars=None, latex=False, pandas=False, fix_greeks=True, fix_underscores=True):
        """
        Print and return a summary table of the regressions estimated by bimets.

        Reads ``df_coeffs.xlsx`` and ``df_stats.xlsx`` from the model directory
        and prints both raw tables. It then builds one row per equation with
        cells ``param = coeff`` plus significance stars (``*`` p<=0.10,
        ``**`` p<=0.05, ``***`` p<=0.01).

        :param vars: equation name or list of names. Defaults to the behavioral
            equations of the last gen_model() call.
        :param latex: produce a LaTeX table (stars as superscripts, cells in $...$).
        :param pandas: with latex=False, return the DataFrame instead of text.
        :param fix_greeks: accepted for backward compatibility; has no effect.
        :param fix_underscores: accepted for backward compatibility; has no effect.
        :returns: DataFrame (pandas=True, latex=False) or the formatted string.
        :raises RuntimeError: if gen_model() has not been called yet.
        """
        if self.path is None:
            raise RuntimeError("No path found. Did you run the model?")

        if isinstance(vars, str):
            vars = [vars]
        elif vars is None:
            vars = self.var_names

        df_coeffs = pd.read_excel(os.path.join(self.path, "df_coeffs.xlsx"))
        df_stats = pd.read_excel(os.path.join(self.path, "df_stats.xlsx"))

        print("coeffs")
        print(df_coeffs)
        print("\n")
        print("stats")
        print(df_stats)
        print("\n")

        rows = []
        for var in vars:
            df_filter = df_coeffs[df_coeffs["eq"] == var][["param", "coeff", "pval"]]

            rowdata = {"Equation": [var]}
            for k, (_, row) in enumerate(df_filter.iterrows(), start=1):
                val = "%s = %.03f" % (row["param"], row["coeff"])
                pval = row["pval"]  # NaN compares False -> no stars
                if pval <= 0.10:
                    val += "^{*" if latex else "*"
                if pval <= 0.05:
                    val += "*"
                if pval <= 0.01:
                    val += "*"
                if latex and pval <= 0.10:
                    val += "}"  # close the LaTeX superscript
                rowdata["Parameter %i" % k] = [val]

            rows.append(pd.DataFrame(rowdata))

        if rows:
            all_rows = pd.concat(rows).set_index("Equation")
        else:
            all_rows = pd.DataFrame({"Equation": []}).set_index("Equation")

        if not latex:
            if pandas:
                return all_rows
            out_str = all_rows.to_string().replace("NaN", "")
            N = len(out_str.split("\n")[0])
            sepstr1 = "\n" + "_" * N + "\n"
            sepstr2 = "\n" + "-" * N + "\n"
            row_str = sepstr1 + sepstr2.join(out_str.split("\n"))
            row_str += sepstr1
            row_str = row_str.strip()

            final_str = ""
            for i, row in enumerate(row_str.split("\n")):
                if i == 0:
                    final_str += " " + row + " " * (N - len(row)) + " \n"
                else:
                    final_str += "|" + row + " " * (N - len(row)) + "|\n"
            final_str += "\n* p < 0.1, ** p < 0.05, *** p < 0.01\n"
            print(final_str)
            return final_str

        # DataFrame.map replaces the deprecated applymap (pandas >= 2.1)
        cell_map = all_rows.map if hasattr(pd.DataFrame, "map") else all_rows.applymap
        all_rows = cell_map(lambda x: "$" + str(x) + "$")
        out_str = all_rows.to_latex().replace("NaN", "")
        for i in ["alpha", "gamma", "beta", "delta", "epsilon", "rho", "sigma", "lambda", "nu", "mu"]:
            out_str = out_str.replace(i, "\\" + i)
        out_str = out_str.replace("$nan$", "")
        print(out_str)
        return out_str


class bcolors:
    """ANSI terminal escape sequences for coloured and styled console output."""

    # foreground colours (bright variants)
    HEADER = "\x1b[95m"
    OKBLUE = "\x1b[94m"
    OKCYAN = "\x1b[96m"
    OKGREEN = "\x1b[92m"
    WARNING = "\x1b[93m"
    FAIL = "\x1b[91m"
    # reset and text attributes
    ENDC = "\x1b[0m"
    BOLD = "\x1b[1m"
    UNDERLINE = "\x1b[4m"


class Equations(Singleton):
    """
    Registry and evaluator for a system of model equations.

    Equations are stored as ``LHS = RHS`` strings, e.g. read from a bimets
    model text file via ``read_bimets``. When an equation is evaluated, the
    identifiers occurring in its right-hand side are resolved from (later
    sources override earlier ones) coefficients, exogenized data series,
    linked getters and, if set, attributes of ``self.namespace``. The result
    is written back through the setter linked to the equation's left-hand side.
    """

    # Names provided by eval_rhs itself; never resolved as model symbols.
    _BUILTIN_SYMBOLS = frozenset({
        "TSDELTA", "TSDELTAP", "TSDELTALOG", "TSLAG", "MOVAVG", "LOG", "EXP",
        "log", "exp", "abs", "max", "min", "sum", "SQRT", "np", "nan",
        "True", "False", "None",
    })

    def __init__(self):
        # Guard against re-initialization: the Singleton base may hand out the
        # same instance again, whose state must then be kept.
        if hasattr(self, "initialized"):
            return
        self.initialized = True

        self._eq_dict = {}               # {eq_name: (eq_name, rhs_str)}
        self._coeff_dict = {}            # {coefficient_name: value}
        self._var_links_getter = {}      # {symbol: zero-arg getter (or a raw value, see register_variable)}
        self._var_links_setter = {}      # {symbol: setter(value, t)}

        self.data = None                 # raw data columns, see load_data / read_data
        self.exo_dict = {}               # {name: series} of exogenized variables

        self.namespace = None            # optional object whose attributes are exposed to equations

        # caches (invalidated per equation in `add`)
        self._compiled_exprs = {}        # varname -> compiled code
        self._symbols_needed = {}        # varname -> set(symbols)
        self._ident_rx = re.compile(r'\b[A-Za-z_]\w*\b')

    @property
    def coeff_dict(self):
        """dict of coefficient name -> value"""
        return self._coeff_dict

    @property
    def eq_dict(self):
        """dict of eq_name -> (eq_name, rhs_str)"""
        return self._eq_dict

    def set_namespace(self, namespace):
        """expose the attributes of `namespace` to all equations (see eval_rhs)"""
        self.namespace = namespace

    def init_paths(self, data_path, model_path, fname_model="model.txt", fname_coeffs="df_coeffs.xlsx"):
        """
        initialize paths and the equation system based on the
        BimetsModel output files from a previous calibration step

        :param data_path: excel data file (see load_data)
        :param model_path: folder containing the model and coefficient files
        :param fname_model: bimets model text file inside `model_path`
        :param fname_coeffs: excel coefficient file inside `model_path`
        """
        self.load_data(data_path)
        self.read_bimets(os.path.join(model_path, fname_model))
        self.set_coeffs_from_excel(os.path.join(model_path, fname_coeffs))

    def add(self, eq, coeffs=None, eq_name=None, namespace=None):
        """
        add a new equation ("LHS = RHS"), replacing any equation with the same name

        The equation is split at the FIRST '=' only, so keyword arguments
        (e.g. ``MOVAVG(X, L=3)``) and comparisons stay part of the RHS.

        :param eq: equation string
        :param coeffs: optional dict of coefficient values to merge in
        :param eq_name: storage key, defaults to the stripped LHS
        :param namespace: if given, link `eq_name` to this object (see link_attr)
        :raises RuntimeError: if `eq` contains no '='
        """
        lhs, sep, rhs = eq.partition("=")
        if not sep:
            raise RuntimeError("eq not properly defined (= not found)")

        if eq_name is None:
            eq_name = lhs.strip()
        if namespace:
            self.link_attr(eq_name, namespace)
        self._eq_dict[eq_name] = (eq_name, rhs.strip())
        # invalidate caches for this var
        self._compiled_exprs.pop(eq_name, None)
        self._symbols_needed.pop(eq_name, None)

        if coeffs is not None:
            self._coeff_dict.update(coeffs)

    def read_bimets(self, fname, verbose=False, namespace=None):
        """
        add equations from a R-Bimets model file (text format)

        Every line starting with 'EQ>' is passed to `add`. Equations that cannot
        be added are skipped; the error is printed when `verbose` is set and
        issued as a warning otherwise.

        :param fname: path of the model text file
        :param verbose: print progress for each equation
        :param namespace: forwarded to `add` (links each LHS to this object)
        """
        if verbose:
            print("loading", fname, "...")
        with open(fname, "r") as file:
            for line in file:
                if not line.startswith("EQ>"):
                    continue
                mydef = line.split("EQ>")[1].strip()
                eq_name = mydef.partition("=")[0].strip()
                try:
                    self.add(mydef, namespace=namespace, eq_name=eq_name)
                except Exception as e:
                    if verbose:
                        print("....", mydef, bcolors.FAIL + " [ERROR]" + bcolors.ENDC)
                        print("      > ", str(e))
                    else:
                        warnings.warn(f"Could not add equation '{mydef}': {e}")
                    continue
                # report success only after the equation was actually added
                if verbose:
                    print("....", mydef, bcolors.OKGREEN + "[OK]" + bcolors.ENDC)
        if verbose:
            print("done.")

    def link(self, var_name, ref_val):
        """see register_variable"""
        return self.register_variable(var_name, ref_val)

    def register_variable(self, var_name, ref_val):
        """
        register an equation variable link (getter only, no setter)

        `ref_val` may be a zero-argument callable or a plain value; eval_rhs
        calls it if possible and otherwise uses it directly.
        """
        self._var_links_getter[var_name] = ref_val

    def read_data(self, data_or_path, format="xlsx", index="Year"):
        """
        read data directly to the Equations system.

        Each column is stored in ``self.data`` as a numpy array.

        :param data_or_path: a DataFrame, or a path to an xlsx/csv file
        :param format: "xlsx" or "csv" (only used for paths)
        :param index: column to use as index; required in files, optional for DataFrames
        :raises RuntimeError: for unsupported formats
        """
        if isinstance(data_or_path, pd.DataFrame):
            df = data_or_path
            try:
                df = df.set_index(index)
            except Exception:
                # index column absent (e.g. already used as the index): keep as is
                pass
        elif format == "xlsx":
            df = pd.read_excel(data_or_path).set_index(index)
        elif format == "csv":
            df = pd.read_csv(data_or_path).set_index(index)
        else:
            raise RuntimeError("Please use excel or csv format. Other formats are not supported!")

        if not self.data:
            self.data = {}

        for colname in df.columns:
            self.data[colname] = np.array(df[colname])

    def link_namespace_dict(self, namespace, filter=None):
        """
        adds a certain namespace to the setter/getter mechanism

        :param namespace: object whose instance attributes (``__dict__``) are linked
        :param filter: optional collection of attribute names to restrict linking to
        """
        for k in namespace.__dict__:
            if (not filter) or k in filter:
                self.link_attr(k, namespace)

    def extract_original_varname(self, varname):
        """
        Detect wrapped expressions like LOG(...), TSDELTA(...) or TSDELTAP(...)

        :returns: (original_varname, detected) where detected is
            0 for LOG/Log/log, 1 for TSDELTA, 2 for TSDELTAP and -1 otherwise
            (in which case original_varname == varname)
        """
        log_match = re.match(r"(Log|LOG|log)\((.*)\)", varname)
        tsdelta_match = re.match(r"TSDELTA\((.*)\)", varname)
        tsdeltap_match = re.match(r"TSDELTAP\((.*)\)", varname)

        if log_match:
            return log_match.group(2).strip(), 0
        if tsdelta_match:
            return tsdelta_match.group(1).strip(), 1
        if tsdeltap_match:
            return tsdeltap_match.group(1).strip(), 2
        return varname, -1

    def link_attr(self, varname, obj, attr_name=None, transfer_data=True):
        """
        register an attribute link to a python object, e.g. an sfctools-Agent

        Wrapped symbols get setters that convert the equation result back to
        the level of the underlying attribute x:

        - LOG(x):      getter returns np.log(x); setter stores exp(value)
        - TSDELTA(x):  getter returns x;         setter stores x[t-1] + value
        - TSDELTAP(x): getter returns x;         setter stores x[t-1] * (1 + value/100)
        - otherwise:   getter returns x;         setter stores value

        :param varname: symbol as used in the equations (possibly wrapped)
        :param obj: object holding the attribute; the attribute must support
            item assignment by time index (presumably a numpy array)
        :param attr_name: attribute name, defaults to the unwrapped `varname`
        :param transfer_data: copy matching data from ``self.data`` into the
            attribute (best effort, failures are ignored)
        """
        original_varname, detected = self.extract_original_varname(varname)
        if attr_name is None:
            attr_name = original_varname

        if detected == 0:  # LOG(...): keep special LHS setter; getter returns log(series)
            self._var_links_getter[varname] = (lambda obj=obj, an=attr_name:
                                               np.log(getattr(obj, an)))
            def f_setter(an, x, t):
                getattr(obj, an).__setitem__(t, np.exp(x))
            self._var_links_setter[varname] = lambda x, t, an=attr_name: f_setter(an, x, t)

        elif detected == 1:  # TSDELTA(...): make getter cheap (base array); keep smart setter
            self._var_links_getter[varname] = (lambda obj=obj, an=attr_name:
                                               getattr(obj, an))
            def f_setter(an, x, t):
                x_prev = getattr(obj, an)[t - 1]
                getattr(obj, an).__setitem__(t, x + x_prev)
            self._var_links_setter[varname] = lambda x, t, an=attr_name: f_setter(an, x, t)

        elif detected == 2:  # TSDELTAP(...): getter = base array; setter = % change to level
            self._var_links_getter[varname] = (lambda obj=obj, an=attr_name:
                                               getattr(obj, an))
            def f_setter(an, pct_change, t):
                x_prev = getattr(obj, an)[t - 1]
                getattr(obj, an).__setitem__(t, x_prev * (1 + pct_change/100.0))
            self._var_links_setter[varname] = lambda x, t, an=attr_name: f_setter(an, x, t)

        else:  # plain variable
            self._var_links_getter[varname] = (lambda obj=obj, an=attr_name:
                                               getattr(obj, an))
            def f_setter2(an, x, t):
                getattr(obj, an).__setitem__(t, x)
            self._var_links_setter[varname] = lambda x, t, an=attr_name: f_setter2(an, x, t)

        # transfer known data to the object
        if transfer_data:
            if not self.data:
                self.data = {}
            if original_varname in self.data:
                series = self.data[original_varname]
                try:
                    # {index: value} dict as stored by load_data
                    dat = np.array(list(series.values()))
                except Exception:
                    # array-like as stored by read_data
                    dat = np.array(series)
                try:
                    old_values = getattr(obj, attr_name)
                    n = min(len(dat), len(old_values))
                    new_values = np.copy(old_values)
                    new_values[:n] = dat[:n]
                    if n > 0 and len(old_values) > len(dat):
                        # pad the remaining periods with the last data value
                        new_values[n:] = new_values[n - 1]
                    setattr(obj, attr_name, new_values)
                except Exception:
                    # best effort: attribute may be missing or not array-like
                    pass

    def _referenced_symbols(self, varname, rhs):
        """return (and cache) the identifiers used in `rhs`, minus built-in helper names"""
        syms = self._symbols_needed.get(varname)
        if syms is not None:
            return syms
        pruned = set(self._ident_rx.findall(rhs)) - self._BUILTIN_SYMBOLS
        self._symbols_needed[varname] = pruned
        return pruned

    def eval(self, varname, t=None, verbose=False):
        """
        evaluate equation `varname` at time `t` and store the result via its setter

        :param t: time index, defaults to the current Clock time
        :param verbose: print the evaluation and the stored value before/after
        :raises ValueError: if the right-hand side evaluates to NaN
        """
        if t is None:
            t = Clock().get_time()
        rhs = self.eval_rhs(varname, t, verbose=verbose)
        if np.isnan(rhs):
            raise ValueError(f"NaN occurred in evaluation of '{varname}' (t={t})")
        if verbose:
            print("   old value:", self._var_links_getter[varname]()[t])
        self._var_links_setter[varname](rhs, t)
        if verbose:
            print("   new value:", self._var_links_getter[varname]()[t])

    def set_value(self, var_name, value, t=None):
        """
        set a linked variable through its setter

        :param t: time index, defaults to the current Clock time (a None index
            would otherwise be forwarded to ``array[None] = value``, overwriting
            every element)
        """
        if t is None:
            t = Clock().get_time()
        self._var_links_setter[var_name](value, t)

    def eval_rhs(self, varname, t=None, verbose=False):
        """
        Evaluate the right-hand-side of an equation (context-based, symbol-filtered).

        Only identifiers occurring in the RHS are placed in the evaluation
        context; later sources override earlier ones: coefficients, exogenized
        series (exo_dict), linked getters, attributes of self.namespace.
        bimets-style helpers (TSDELTA, TSDELTAP, TSDELTALOG, TSLAG, MOVAVG,
        LOG, EXP, SQRT, ...) are always available.

        :param varname: name under which the equation was added
        :param t: time index, defaults to the current Clock time
        :returns: ``out[t]`` if the RHS evaluates to a sequence/array, otherwise the
            scalar result (numpy scalar types converted to float)
        :raises KeyError: if no equation is registered under `varname`
        :raises RuntimeError: if evaluating the expression fails
        """
        if verbose:
            print("eval rhs of %s at t=" % varname, t)
        if t is None:
            t = Clock().get_time()

        # ---- Vector helpers (vectorized; keep semantics) ----
        def TSDELTA(x, i=1):
            # x[t] - x[t-i]; the first i entries are NaN
            x = np.asarray(x, dtype=float)
            out = np.empty_like(x)
            out[:i] = np.nan
            out[i:] = x[i:] - x[:-i]
            return out

        def TSDELTAP(x, i=1):
            # percentage change 100 * (x[t] - x[t-i]) / x[t-i]; first i entries NaN
            x = np.asarray(x, dtype=float)
            out = np.empty_like(x)
            out[:i] = np.nan
            out[i:] = 100.0 * (x[i:] - x[:-i]) / x[:-i]
            return out

        def TSLAG(x, i=1):
            # x[t-i]; the first i entries are NaN
            x = np.asarray(x, dtype=float)
            out = np.empty_like(x)
            out[:i] = np.nan
            out[i:] = x[:-i]
            return out

        def LOG(x): return np.log(x)
        def EXP(x): return np.exp(x)

        def MOVAVG(x, L=1, DIRECTION="BACK", ignoreNA=False, avoidCompliance=False):
            # moving average over windows of length L; positions without a full window are NaN
            x = np.asarray(x, dtype=float)
            out = np.full_like(x, np.nan, dtype=float)
            if L < 1:
                raise ValueError("L must be >= 1")
            if DIRECTION not in ["AHEAD", "CENTER", "BACK", None]:
                raise ValueError("DIRECTION must be 'AHEAD', 'CENTER', or 'BACK'"
                                 )
            def m(arr): return np.nanmean(arr) if ignoreNA else (np.mean(arr) if not np.isnan(arr).all() else np.nan)
            if DIRECTION == "AHEAD":
                for j in range(len(x) - L + 1):
                    out[j] = m(x[j:j+L])
            elif DIRECTION == "CENTER":
                half = L // 2
                for j in range(half, len(x) - half):
                    out[j] = m(x[j-half:j+half+1])
            else:
                for j in range(L - 1, len(x)):
                    out[j] = m(x[j-L+1:j+1])
            return out

        def TSDELTALOG(x, L=1, avoidCompliance=False):
            # log(x[t]) - log(x[t-L]); NaN where either value is non-positive or t < L
            x = np.asarray(x, dtype=float)
            out = np.full_like(x, np.nan, dtype=float)
            if L < 1:
                raise ValueError("L must be >= 1")
            for j in range(L, len(x)):
                if x[j - L] > 0 and x[j] > 0:
                    out[j] = np.log(x[j]) - np.log(x[j - L])
            return out

        if varname not in self._eq_dict:
            raise KeyError(f"Could not find expression '{varname}' in the list of equations.")
        _, eq = self._eq_dict[varname]

        # ---- Build locals context (arrays & scalars) — only what's needed ----
        ctx = {
            "TSDELTA": TSDELTA, "TSDELTAP": TSDELTAP, "TSDELTALOG": TSDELTALOG,
            "TSLAG": TSLAG, "MOVAVG": MOVAVG, "LOG": LOG, "EXP": EXP,
            "log": np.log, "exp": np.exp,
            "abs": abs, "max": max, "min": min, "nan": np.nan, "sum": sum,
            "SQRT": np.sqrt, "np": np,
        }

        needed = self._referenced_symbols(varname, eq)

        # coefficients (scalars)
        for name in needed.intersection(self._coeff_dict.keys()):
            ctx[name] = self._coeff_dict[name]

        # exogenous variables (arrays)
        for name in needed.intersection(self.exo_dict.keys()):
            ctx[name] = np.asarray(self.exo_dict[name])

        # linked variables (arrays or scalars)
        for name in needed.intersection(self._var_links_getter.keys()):
            getter = self._var_links_getter[name]
            try:
                v = getter()
            except Exception:
                # register_variable may store a plain (non-callable) value
                v = getter
            ctx[name] = np.asarray(v) if not np.isscalar(v) else v

        # optional namespace exposure (only needed names)
        if self.namespace is not None:
            for name in needed:
                if hasattr(self.namespace, name):
                    ctx[name] = getattr(self.namespace, name)

        # ---- Compile once per equation ----
        code = self._compiled_exprs.get(varname)
        if code is None:
            code = compile(eq, f"<eq:{varname}>", "eval")
            self._compiled_exprs[varname] = code

        # ---- Evaluate with builtins removed ----
        # NOTE: an empty __builtins__ only limits accidental name lookups; it is
        # NOT a security sandbox. Never evaluate model files from untrusted sources.
        try:
            out = eval(code, {"__builtins__": {}}, ctx)
        except Exception as e:
            raise RuntimeError(f"Could not process expression:\n{eq}\n{e}") from e

        # ---- scalar vs vector result ----
        if isinstance(out, (list, tuple)):
            out = np.asarray(out)

        if np.isscalar(out) or (isinstance(out, np.ndarray) and out.ndim == 0):
            return float(out) if isinstance(out, np.generic) else out

        return out[t]

    def set_coeffs_from_excel(self, fname: str):
        """
        read coefficients from an excel file

        Expects columns 'param' and 'coeff'; if 'param' is missing, the first
        column is used as the parameter name (unless it is 'coeff' itself).

        :raises ValueError: if the file layout cannot be parsed
        """
        df = pd.read_excel(fname)
        if {"param", "coeff"}.issubset(df.columns):
            coeffs = dict(zip(df["param"].astype(str), df["coeff"]))
        elif "coeff" in df.columns and df.columns[0] != "coeff":
            coeffs = dict(zip(df.iloc[:, 0].astype(str), df["coeff"]))
        else:
            raise ValueError("Cannot parse coefficients file: expected columns 'param' and 'coeff'.")
        for k, v in coeffs.items():
            self.set_coeff(k, v)

    def set_coeffs(self, coeff_dict):
        """add coefficient name-value pairs from a dictionary"""
        for k, v in coeff_dict.items():
            self.set_coeff(k, v)

    def set_coeff(self, coeff_name, coeff_val):
        """set a coefficient value"""
        self._coeff_dict[coeff_name] = coeff_val

    def load_df(self, df, verbose=False):
        """
        apply data from a pandas DataFrame

        Every column with a registered setter is written element by element,
        using the positional row number as time index t; other columns are
        skipped (and reported when `verbose`).
        """
        for c in df:
            if verbose:
                notif_str = " ".join(["Apply", str(c), "<-", str(list(df[c])[:5]), "..."])
            if c in self._var_links_setter:
                setter = self._var_links_setter[c]
                for t, x in enumerate(df[c]):  # could be optimized with vector assign if needed
                    setter(x, t)
                if verbose:
                    notif_str += "  " + bcolors.OKGREEN + "[OK]" + bcolors.ENDC
                    print(notif_str)
            elif verbose:
                notif_str += "  " + bcolors.FAIL + "[NOT FOUND]" + bcolors.ENDC
                print(notif_str)

    def load_data(self, data_path, index="Year"):
        """
        load data from an excel file

        Replaces ``self.data`` with ``{column: {index_value: value}}``.
        """
        df = pd.read_excel(data_path).set_index(index)
        self.data = df.to_dict()

    def exogenize(self, names):
        """
        make variable(s) exogenous: their data series from ``self.data`` are
        exposed to equations via ``exo_dict`` (already exogenized names are kept)

        :param names: a single name or an iterable of names
        """
        if isinstance(names, str):
            names = [names]
        for name in names:
            if name in self.exo_dict:
                continue
            data = self.data[name]
            if isinstance(data, pd.Series):
                data = list(data)
            elif isinstance(data, dict):
                # load_data stores {index: value} dicts; keep the values only
                data = list(data.values())
            self.exo_dict[name] = data

def model_txt_to_dict(file_path):
    """
    converts a bimets model text file to a dictionary for later processing in BimetsModel

    Recognized keyword lines (after stripping whitespace):

    - ``IDENTITY> name`` / ``BEHAVIORAL> name`` start a new equation block
    - ``EQ> ...`` stores the equation text of the current block
    - ``COEFF> ...`` stores the coefficient names (BEHAVIORAL blocks only)

    All other lines are ignored. Everything after the FIRST '>' of a keyword
    line is kept, so equations containing '>' (e.g. comparisons) are not cut.

    :param file_path: path to the model text file
    :return: defaultdict {eq_name: {"type": "IDENTITY" | "BEHAVIORAL", "EQ": str, "COEFF": str}}
    """
    equations = defaultdict(lambda: {"type": "", "EQ": "", "COEFF": ""})
    current_context = None
    eq_name = None

    with open(file_path, 'r') as file:
        for raw_line in file:
            keyword, sep, payload = raw_line.strip().partition(">")
            if not sep:
                continue  # not a keyword line
            payload = payload.strip()

            if keyword in ("IDENTITY", "BEHAVIORAL"):
                current_context = keyword
                eq_name = payload

            elif keyword == "EQ":
                if eq_name:
                    equations[eq_name]["EQ"] = payload
                    equations[eq_name]["type"] = current_context

            elif keyword == "COEFF" and current_context == "BEHAVIORAL" and eq_name:
                equations[eq_name]["COEFF"] = payload

    return equations
