
from scipy.stats import norm
import scipy.stats
from sympy import *
from scipy.optimize import approx_fprime
import networkx as nx
from collections.abc import Mapping
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import sys
import scipy
from itertools import product
from pyDOE import lhs
from multiprocessing import Lock, Pool, cpu_count, Value, Manager
from itertools import repeat
import itertools
from scipy.stats import chi2, norm
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from .hierarchy_pos import hierarchy_pos
import warnings

"""
This is a module for multinomial nested logit (MNL) models
with arbitrary nesting. A combined brute-force and scipy.optimize algorithm
can be used to fit the parameters to data.

NOTE: highly experimental; no warranty given.
"""


def calculate_nagelkerke_r2(full_L, null_L, n):
    """
    Compute Nagelkerke's pseudo R² as
    ``(1 - (full_L / null_L) ** (2 / n)) / (1 - null_L ** (2 / n))``.

    If either input is non-positive, both are shifted by the same offset so
    that the smaller of the two becomes 1.

    NOTE(review): the textbook formula uses likelihoods and the ratio
    null/full; this computes full/null, and the shift for non-positive values
    suggests log-likelihoods may be passed in -- confirm against the callers.

    :param full_L: value for the fitted model
    :param null_L: value for the null model
    :param n: number of observations; negative values are rejected
    :return: the R² value, or None if the computation fails (the reason is
        printed)
    """

    def _finite_or_raise(value, message):
        # NaN or +/-inf intermediate results abort the calculation
        if np.isnan(value) or np.isinf(value):
            raise ValueError(message)
        return value

    try:
        if n < 0:
            raise ValueError("n<0")

        if full_L <= 0 or null_L <= 0:
            offset = 1 - min(full_L, null_L)
            full_L += offset
            null_L += offset

        ratio = _finite_or_raise(
            full_L / null_L,
            "Invalid value encountered in ratio calculation.")
        exponent = 2 / n
        ratio_power = _finite_or_raise(
            ratio ** exponent,
            "Invalid value encountered in ratio power calculation.")
        null_power = _finite_or_raise(
            null_L ** exponent,
            "Invalid value encountered in null power calculation.")
        return _finite_or_raise(
            (1 - ratio_power) / (1 - null_power),
            "Invalid value encountered in Nagelkerke's R² calculation.")

    except (ValueError, ZeroDivisionError) as err:
        # invalid inputs or intermediates are reported, not raised
        print(f"Error in calculation: {err}")
        return None


def hessian_numeric(func, x0, args=(), step=None):
    """
    Approximate the Hessian of a scalar function with central differences.

    Each entry is estimated as

        H_ij ~ [f(x + h_i e_i + h_j e_j) - f(x + h_i e_i - h_j e_j)
                - f(x - h_i e_i + h_j e_j) + f(x - h_i e_i - h_j e_j)]
               / (4 h_i h_j)

    which for i == j reduces to the usual second difference with step 2 h_i.
    Only the upper triangle is evaluated; the result is mirrored, so it is
    exactly symmetric.

    :param func: callable ``func(x, *args)`` returning a scalar
    :param x0: 1-D point at which the Hessian is evaluated (array-like;
        converted to float so integer inputs are not truncated)
    :param args: extra positional arguments for ``func`` (default: none)
    :param step: finite-difference step (scalar or one per coordinate).
        ``None`` uses ``eps**(1/4) * max(1, |x0_i|)``, the usual balance
        between truncation and rounding error for second differences.
    :return: ``(n, n)`` numpy array
    """
    x = np.asarray(x0, dtype=float).ravel()
    args = () if args is None else tuple(args)
    n = x.size

    if step is None:
        h = np.finfo(float).eps ** 0.25 * np.maximum(1.0, np.abs(x))
    else:
        h = np.broadcast_to(np.asarray(step, dtype=float), x.shape)

    def f(point):
        return func(point, *args)

    hessian = np.zeros((n, n))
    for i in range(n):
        di = np.zeros(n)
        di[i] = h[i]
        for j in range(i, n):
            dj = np.zeros(n)
            dj[j] = h[j]
            value = (f(x + di + dj) - f(x + di - dj)
                     - f(x - di + dj) + f(x - di - dj)) / (4.0 * h[i] * h[j])
            hessian[i, j] = value
            hessian[j, i] = value
    return hessian


def f_print_stats(param_names, result, alpha=0.05, f=None, args=None):
    """
    Print and return a table of statistics for the fitted parameters.

    The Hessian is taken from ``result.hess`` when present. Otherwise it is
    obtained by inverting ``result.hess_inv``. That may be a dense array (as
    returned by BFGS) or an object with ``todense()`` (as returned by
    L-BFGS-B). Standard errors are the square roots of the diagonal of the
    slightly regularized inverse Hessian. Confidence intervals and two-sided
    p-values use the normal approximation.

    :param param_names: names of the parameters
    :param result: optimization result instance (scipy.optimize.minimize)
    :param alpha: significance level for the confidence interval
        (default 0.05). The CI columns keep their '95% CI ...' labels
        regardless of alpha.
    :param f: objective function; currently unused
    :param args: extra arguments for ``f``; currently unused
    :return: pandas DataFrame with one row per parameter. Without any Hessian
        information only the 'Parameter' and 'MLE' columns are present.
    """
    assert 0 <= alpha <= 1
    mle_params = result.x

    hessian = None
    if getattr(result, 'hess', None) is not None:
        hessian = result.hess
    elif getattr(result, 'hess_inv', None) is not None:
        hess_inv = result.hess_inv
        # L-BFGS-B returns a LinearOperator exposing todense(); BFGS and
        # others return a plain ndarray
        dense = hess_inv.todense() if hasattr(hess_inv, 'todense') else hess_inv
        hessian = np.linalg.inv(np.asarray(dense, dtype=float))

    if hessian is None:
        results_df = pd.DataFrame({
            'Parameter': param_names,
            'MLE': mle_params,
        })
    else:
        # small ridge keeps the inversion stable for near-singular Hessians
        epsilon = 1e-9
        hessian_inv = np.linalg.inv(
            hessian + np.eye(len(mle_params)) * epsilon)
        standard_errors = np.sqrt(np.diag(hessian_inv))

        z = scipy.stats.norm.ppf(1.0 - alpha / 2.0)
        ci_lower = mle_params - z * standard_errors
        ci_upper = mle_params + z * standard_errors
        z_scores = mle_params / standard_errors
        t_stats = z_scores  # For large samples, t ≈ z

        p_values = 2 * (1 - norm.cdf(np.abs(t_stats)))

        results_df = pd.DataFrame({
            'Parameter': param_names,
            'MLE': mle_params,
            'Standard Error': standard_errors,
            '95% CI Lower': ci_lower,
            '95% CI Upper': ci_upper,
            'z-score': z_scores,
            't-statistic': t_stats,
            'P-value': p_values
        })

    print(results_df)
    return results_df


class ParamVector:
    """
    A parameter vector which can be 'flattened' and 'de-flattened'.

    Holds three groups of parameters: feature coefficients, alternative-
    specific constants and scale parameters (mu). Each group is a flat list
    without duplicates. Initially the entries are the objects passed in
    (e.g. sympy symbols); after ``set_values`` they are the assigned values.
    ``names`` keeps the string names computed at construction.
    """

    def __init__(self, coeffs, constants, mus):
        """
        :param coeffs: dict whose values are coefficients or (nested) lists
            of coefficients
        :param constants: dict whose values are alternative-specific constants
        :param mus: dict whose values are scale parameters
        """
        # only the dict values are used; nested lists are flattened and
        # repeated entries (e.g. one symbol shared by several alternatives)
        # are kept once
        self.coeffs = self.untangle(list(coeffs.values()))
        self.const = self.untangle(list(constants.values()))
        self.mus = self.untangle(list(mus.values()))
        self.names = [str(i) for i in self.coeffs + self.const + self.mus]

    def __repr__(self):
        return f"_______________\nParameterVector:\n| Feature coefficients: {self.coeffs}\n| Alternative-specific const.: {self.const}\n| Scale parameters (mu): {self.mus}\n________"

    def untangle(self, l, l_out=None):
        """
        Flatten the (possibly nested) list ``l`` into ``l_out``.

        Duplicates are skipped (first occurrence wins, compared with ``==``).

        :return: the flat list
        """
        if not l_out:
            l_out = []
        for li in l:
            if isinstance(li, list):
                l_out = self.untangle(li, l_out)
            elif li not in l_out:
                l_out.append(li)
        return l_out

    @property
    def count(self):
        """Total number of parameters (coefficients + constants + mus)."""
        # The groups are already de-duplicated by untangle(). Counting via
        # set() would wrongly merge distinct parameters that merely share the
        # same numeric value after set_values().
        return len(self.coeffs) + len(self.const) + len(self.mus)

    def to_serial(self):
        """Return all entries as one list: coefficients, constants, mus."""
        return list(self.coeffs) + list(self.const) + list(self.mus)

    def set_values(self, vals):
        """
        Assign new values given in ``to_serial()`` order.

        :param vals: sequence of values (list or numpy array). Raises
            AssertionError if it is too short for a group; extra trailing
            values are ignored.
        """
        k = len(self.coeffs)
        m = len(self.const)
        n = len(self.mus)
        x = vals[0:0+k]
        y = vals[k:k+m]
        z = vals[k+m:k+m+n]

        assert len(x) == len(self.coeffs)
        self.coeffs = x
        assert len(y) == len(self.const)
        self.const = y
        assert len(z) == len(self.mus)
        self.mus = z


class MyLambdify:
    """
    Picklable wrapper around a function produced by ``lambdify``.

    Only the expression and the symbol list are pickled; on unpickling the
    numpy function is rebuilt through ``__init__``. Calls run with all
    warnings suppressed.
    """

    def __init__(self, symbols, expr):
        self.expr, self.symbols = expr, symbols
        self.func = lambdify(symbols, expr, modules='numpy')

    def __call__(self, *args):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            value = self.func(*args)
        return value

    def __getstate__(self):
        # the lambdified function itself is not pickled, only its inputs
        state = (self.expr, self.symbols)
        return state

    def __setstate__(self, state):
        stored_expr, stored_symbols = state
        self.__init__(stored_symbols, stored_expr)

    @property
    def __code__(self):
        # expose the wrapped function's code object (presumably so callers
        # can introspect it like a plain function)
        return self.func.__code__


class MultinomialNestedLogitModel:
    """
    Multinomial nested logit
    NOTE highly experimental, no warranty
    """

    def __init__(self, nesting_structure: dict, numeric_indexing=False, equal_params: dict = None):
        """Instantiate a model object by parsing the nesting structure.

        Args:
            nesting_structure (dict): dict of dicts of lists, nesting structure
            numeric_indexing (bool): name coefficients by position instead of
                by attribute name (see create_model)
            equal_params (dict, optional): parameter-name mapping used to
                identify parameters with each other (see create_model)
        """
        self.nesting = nesting_structure
        self.sep_str = "_"
        self.param_vector = None     # ParamVector, filled in by create_model
        self.attribute_names = None  # attribute names, filled in by create_model

        equal_params = {} if equal_params is None else equal_params
        assert isinstance(equal_params, dict), "Please provide a dict"

        # symbolic model first, then the networkx graph used for visualization
        self.create_model(numeric_indexing, equal_params)
        self.build_tree()

    def create_model(self, numeric_indexing=False, equal_params=None):
        """
        Create the model from the given nesting structure.

        Each leaf alternative in ``self.nesting`` gets a symbolic utility
        ``V = asc + sum(beta_i * x_i)``. The utility enters the model as
        ``exp(mu * V)``, where ``mu`` is the scale parameter of the nest that
        directly contains the alternative (1 for top-level alternatives). The
        probabilities are these terms divided by their sum over all leaves.
        The first alternative encountered gets no constant (its ASC is fixed
        to 0). Finally the parameters are collected in a ``ParamVector``, the
        attribute names are extracted and the probabilities are lambdified.

        NOTE(review): no inclusive-value (logsum) term appears, so this is a
        single softmax over nest-scaled utilities rather than the textbook
        nested logit -- confirm this is the intended model.

        :param numeric_indexing: if True, numeric indexing is used, i.e.
            beta_somechoice_0, beta_somechoice_1, ...
        :param equal_params: optional dict mapping a parameter name (key) to
            another parameter name (value); the key parameter is replaced by
            the value parameter. For example {beta_1: beta_0, beta_2: beta_0}
            enforces beta_0 = beta_1 = beta_2, and beta_1, beta_2 are renamed
            to beta_0. Only coefficient (beta) names are looked up.
        """

        # store parameters
        mus = {}           # nest key -> scale parameter symbol
        constants = {}     # alternative key -> alternative-specific constant
        coefficients = {}  # alternative key -> list of coefficient symbols

        # store utilities
        utilities = {}
        exp_utilities = {}

        first_asc = True

        # 1. create symbolic model
        self.all_coeffs = {}  # coefficient name -> symbol

        def new_sym(name):
            """
            Return the symbol for ``name``, creating it on first use.

            :return: symbol, is_new
            """
            if equal_params is not None and name in equal_params:
                name = equal_params[name]  # identify with another parameter

            if name in self.all_coeffs:
                return self.all_coeffs[name], False

            new_symbol = symbols(name)
            self.all_coeffs[name] = new_symbol
            return new_symbol, True

        def parse_nest(nest, parent=''):
            nonlocal first_asc

            for key, value in nest.items():
                # qualified key: parent keys joined with "_"
                current_key = f'{parent}_{key}' if parent else str(key)

                if isinstance(value, dict):
                    # a nest: one scale parameter, then recurse into it
                    mus[current_key] = symbols(f'mu_{current_key}')
                    parse_nest(value, current_key)
                    continue

                # a leaf alternative: value is the list of its attribute names
                all_coeff_symbols = []
                x = []
                if len(value) > 0:
                    for j, v in enumerate(value):
                        if numeric_indexing:
                            sym_name = f'beta_{current_key}_{j}'
                        else:
                            sym_name = f'beta_{current_key}_{v}'
                        sym, _ = new_sym(sym_name)
                        all_coeff_symbols.append(sym)

                    # attribute (feature) symbols
                    x = symbols(" ".join(value))
                    if not isinstance(x, tuple):
                        x = (x,)

                coefficients[current_key] = all_coeff_symbols

                # alternative-specific constant; the first one is fixed to 0
                asc = symbols(f'asc_{current_key}')
                if first_asc:
                    asc = 0
                    first_asc = False
                else:
                    constants[current_key] = asc

                # Scale parameter of the directly enclosing nest. It is looked
                # up via ``parent`` instead of splitting ``current_key`` on the
                # separator, which picked the wrong nest (or none) whenever a
                # key itself contained "_".
                mu = mus.get(parent, 1)
                V = asc + sum(all_coeff_symbols[i] * x[i]
                              for i in range(len(value)))

                utilities[current_key] = V
                exp_utilities[current_key] = exp(mu * V)

        parse_nest(self.nesting)
        # the small constant guards against division by zero
        denominator = 1e-10 + sum(exp_utilities.values())
        self.probabilities = {
            alt: exp_utilities[alt] / denominator for alt in exp_utilities}

        # 2. construct linearized ParamVector and list of attributes
        self.param_vector = ParamVector(coefficients, constants, mus)
        self.attribute_names = self.extract_attributes()

        # 3. convert the symbolic expressions to numpy functions for better performance
        self.lambdify_model()

        for l_i in self.probabilities_func:
            assert isinstance(l_i, str), "Alternative must be specified as str"

    def extract_attributes(self):
        """
        Return the attribute (feature) names used anywhere in the nesting.

        Names are de-duplicated and returned in order of first appearance in
        a depth-first walk of ``self.nesting``. This keeps the order
        reproducible across interpreter runs; the order of a plain ``set`` of
        strings depends on hash randomization.
        """
        attributes = {}  # used as an insertion-ordered set

        def parse_nest(nest):
            for value in nest.values():
                if isinstance(value, dict):
                    parse_nest(value)
                else:
                    for attr in value:
                        attributes.setdefault(attr, None)

        parse_nest(self.nesting)
        return list(attributes)

    def lambdify_model(self):
        # lambdify the model from symbolic to numpy function
        params = self.param_vector

        self.attr_symbols = [symbols(attr) for attr in self.attribute_names]
        # print("attr_symbols", self.attr_symbols,
        #      "attribute names", self.attribute_names)

        self.n_parameters = len(self.param_vector.to_serial())
        self.n_mus = len(self.param_vector.mus)
        all_symbols = self.param_vector.to_serial() + self.attr_symbols

        # print("lambdify -> all symbols", all_symbols)

        # print("probabilities", self.probabilities.items())

        # self.probabilities_func = {k: MyLambdify(
        #     all_symbols, p) for k, p in self.probabilities.items()}

       # print("\n\n")
        self.probabilities_func = {}
        for k, p in self.probabilities.items():
            # print("add lambdify", k, "all symbols", all_symbols, "p", p)
            # necessary = self.flat_coefs[k]
            # print("necessary_symbols", necessary)
            # print("\n")
            self.probabilities_func[k] = MyLambdify(all_symbols, p)
        # print("\n\n")
        # print("Created MNL model with symbols", all_symbols)

    def build_tree(self):
        """
        Build a networkx DiGraph of the nesting and store it as ``self.G``.

        Nodes are the (unqualified) nest and alternative keys below an
        artificial "root" node.

        :raises RuntimeError: if the resulting graph is not a tree
        """
        graph = nx.DiGraph()
        pending = [("root", self.nesting)]
        while pending:
            parent, children = pending.pop()
            for child, subtree in children.items():
                graph.add_edge(parent, child)
                if isinstance(subtree, Mapping):
                    pending.append((child, subtree))

        self.G = graph
        if not nx.is_tree(self.G):
            raise RuntimeError("The provided nesting is not a tree.")

    def visualize_tree(self, figsize=(4, 4)):
        """
        Draw the nesting tree (``self.G``) with matplotlib and show it.

        numpy's global RNG is seeded with 8 while the layout is computed and
        drawn, as before. The caller's RNG state is restored afterwards, so
        visualizing no longer silently resets random number generation
        elsewhere in the program.

        :param figsize: matplotlib figure size (default (4, 4))
        """
        G = self.G
        plt.figure(figsize=figsize)
        rng_state = np.random.get_state()
        try:
            np.random.seed(8)
            pos = hierarchy_pos(G, "root")
            nx.draw(G, with_labels=True, pos=pos,
                    node_color="lightgray", node_size=400)
        finally:
            np.random.set_state(rng_state)
        plt.show()

    def set_parameters(self, params):
        """
        Assign new parameter values to the model.

        :param params: either a sequence (list, tuple or numpy array) of
            values in ``param_vector.names`` order, or a dict mapping
            parameter names to values. With a dict, parameters that are not
            mentioned are set to 0.
        :raises ValueError: if a dict key is not a parameter name
        :raises RuntimeError: if ``params`` has an unsupported type
        """
        if isinstance(params, (list, tuple, np.ndarray)):
            self.param_vector.set_values(list(params))

        elif isinstance(params, dict):
            # Look keys up by name: after an earlier set_values() the vector
            # holds numbers, so searching it for sympy symbols would fail.
            names = list(self.param_vector.names)
            new_params = np.zeros(len(names))
            for k, v in params.items():
                try:
                    idx = names.index(str(k))
                except ValueError:
                    raise ValueError(
                        f"Unknown parameter name: {k!r}") from None
                new_params[idx] = v
            self.param_vector.set_values(list(new_params))

        else:
            raise RuntimeError("Could not detect params format")

    def predict(self, data, noise=0.0):
        """
        Predict the most probable alternative for every row of ``data``.

        If a fitted scaler is stored in ``self.scaler`` (``fit`` does this
        when ``scale_attrs=True``), it is applied to the attributes first.
        This way the parameters are evaluated on the scale they were
        estimated on.

        :param data: pandas DataFrame with one column per attribute in
            ``attribute_names`` (an attribute named "choice" is skipped)
        :param noise: standard deviation of Gaussian noise added to each
            alternative's probability before the arg-max (default 0)
        :return: numpy array with the predicted alternative key per row
        :raises ValueError: if no parameter vector is set
        """
        if self.param_vector is None:
            raise ValueError(
                "Model parameters must be set before making predictions.")

        columns = [c for c in self.attribute_names if c != "choice"]
        attributes = data[columns].astype(float)
        values = attributes.values  # shape (n_obs, n_attr)
        scaler = getattr(self, "scaler", None)
        if scaler is not None:
            values = scaler.transform(values)
        v = values.T  # shape (n_attr, n_obs): one row per attribute argument
        params = np.array(self.param_vector.to_serial())  # shape (n_params, )

        n = v.shape[1]
        prob_df = pd.DataFrame(index=attributes.index,
                               columns=list(self.probabilities_func))
        for alt, prob_func in self.probabilities_func.items():
            if noise > 0:
                y_noise = np.random.normal(0.0, noise, n)
                prob_df[alt] = prob_func(*params, *v) + y_noise
            else:
                prob_df[alt] = prob_func(*params, *v)

        # normalize each row to sum to one
        s_prob = np.sum(prob_df, axis=1)
        for alt in self.probabilities_func:
            prob_df[alt] /= s_prob
        return np.array(prob_df.idxmax(axis=1).to_list())

    def print_params(self, params=None):
        """
        Print one ``name = value`` line per parameter, names left-aligned.

        :param params: values to print in ``param_vector.names`` order
            (default: the current parameter vector). Numeric values use 3
            decimals. Other values, e.g. still-symbolic parameters, are
            printed with ``str()``. Printing stops at the shorter of names
            and values.
        """
        names = self.param_vector.names
        if params is None:
            params = self.param_vector.to_serial()
        width = max((len(name) for name in names), default=0)
        for name, value in zip(names, params):
            try:
                formatted = f"{value:.3f}"
            except (TypeError, ValueError):
                formatted = str(value)
            print(f"{name:<{width}} = {formatted}")

    def fit(self, df, regularization='l2', alpha=0.001, alpha_2=0.001, num_samples_per_param=3, bounds=None,
            sampling_method='brute', parallel=False, n_cpus=None, pred_noise=0.0, scale_attrs=True, rel_batch_size=0.2, tol=None, x0=None,
            optimize_globally=False, opt_method="L-BFGS-B", disp=False):
        """
        Fit the model to the data.

        With ``optimize_globally=True``, a brute-force grid or latin-hypercube
        pre-search over the parameter bounds runs first. The best sample found
        (otherwise the current parameters) is then refined locally with
        ``scipy.optimize.minimize`` (default L-BFGS-B), retrying up to three
        times. If the local optimizer still fails, Powell's method is tried as
        a fallback.

        Args:
            df (pandas DataFrame): data with one column per attribute and a
                column 'choice'
            regularization (str, optional): regularization technique. Defaults to 'l2'.
            alpha, alpha_2 (float, optional): regularization strengths passed to
                the likelihood (defaults 0.001; 0 means no regularization).
                NOTE(review): which one applies to 'l1' / 'l2' / 'elastic_net'
                is decided in negative_log_likelihood -- confirm there.
            num_samples_per_param (int, optional): grid points per parameter for
                'brute' (at least 3), or the total number of samples for 'lhs'
                (at least 1). Defaults to 3.
            bounds (dict or list, optional): parameter bounds, either a dict
                {parameter name: (low, high)} or a list of (low, high) tuples in
                ``param_vector.names`` order. Parameters without explicit
                bounds get (-2, 2); mu parameters get (0.01, 0.9999).
                Defaults to None.
            sampling_method (str, optional): 'brute' (brute force) or 'lhs' (latin hypercube). Defaults to 'brute'.
            parallel (bool, optional): parallel pre-search; not implemented
                (raises NotImplementedError together with optimize_globally).
            n_cpus (int, optional): NOT USED IN THIS VERSION
            pred_noise (float): predictor noise during the pre-search (default 0).
                Set greater than zero for better robustness.
            scale_attrs (bool): standardize every attribute over the sampled rows
                with StandardScaler. The fitted scaler is stored as
                ``self.scaler`` (None when False).
            rel_batch_size (float): fraction of the rows of ``df`` that is
                randomly sampled and used for fitting (default 0.2)
            tol (float or None): tolerance of the scipy.optimize local search
            x0 (array-like or None): initial guess of parameters (default None: take currently set parameters)
            optimize_globally (bool, default False): do a global optimization before local optimization
            opt_method (str): any method supported by scipy.optimize.minimize
            disp (bool): display output of local optimization (default False)

        Returns:
            result, stats: optimization result and statistics

        Raises:
            RuntimeError: invalid sampling method or bounds type, or no valid
                sample found in the pre-search
            ValueError: a bounds list of the wrong length
            TypeError: the starting parameters are not numeric
            NotImplementedError: parallel=True together with optimize_globally=True
        """
        # Lambdified expressions can overflow for extreme parameter values.
        # Silence the resulting numerical RuntimeWarnings for the duration of
        # the fit only. catch_warnings() restores the caller's filters on exit
        # (warnings.resetwarnings() used to wipe all global filters), and
        # other warnings, such as the optimizer fallback notice, stay visible.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)

            if x0 is not None:
                self.param_vector.set_values(x0)

            self.regularization = regularization
            self.alpha = alpha

            # random subset of the rows used for this fit
            minibatch = df.sample(frac=rel_batch_size)

            values = minibatch[self.attribute_names].values  # (n_obs, n_attr)
            self.scaler = None
            if scale_attrs:
                # StandardScaler standardizes columns, so scale while every
                # attribute is still a column and transpose afterwards
                self.scaler = StandardScaler()
                values = self.scaler.fit_transform(values)
            attributes = values.T  # (n_attr, n_obs): one row per attribute symbol
            choices = minibatch["choice"].values

            self.attributes = attributes
            self.choices = choices

            num_params = self.n_parameters
            bounds = self._resolve_fit_bounds(bounds, num_params)
            param_samples, total_iterations = self._fit_param_samples(
                sampling_method, bounds, num_samples_per_param)

            best_nll = float('inf')
            best_params = None
            progress = 0

            def print_progress(done, total, r2=None, nll=None, best=None):
                # single-line progress display, rewritten in place via '\r'
                percent = np.round(done / total * 100, 2)
                r2str = "R^2: %+6.2f" % r2 if r2 is not None else ""
                nllstr = "NLL: %7.2f" % nll if nll is not None else ""
                bnllstr = "best NLL: %7.2f" % best if best is not None else ""
                sys.stdout.write('\r')
                sys.stdout.write(
                    f' Progress: {done}/{total} samples {percent:.2f}%  '
                    + r2str + '    ' + nllstr + '    ' + bnllstr + '      \r')
                sys.stdout.flush()

            def track_progress(result, r2=None):
                # count the sample, report periodically, keep the best finite NLL
                nonlocal best_nll, best_params, progress
                nll_value, sample = result
                progress += 1
                if progress % 200 == 0 or progress == total_iterations:
                    print_progress(progress, total_iterations, r2=r2,
                                   nll=nll_value, best=best_nll)
                if (nll_value < float("inf") and not np.isnan(nll_value)
                        and r2 is not None and nll_value < best_nll):
                    best_nll = nll_value
                    best_params = sample

            best_global_params = self.param_vector.to_serial()
            print("global params", best_global_params)

            if optimize_globally:
                if parallel:
                    raise NotImplementedError(
                        "Parallel solving not yet supported :(")

                if sampling_method == "brute":
                    print("Pre-sampling with brute force...")
                else:
                    print("Pre-sampling with latin hypercubes")

                for param in param_samples:
                    result = self.worker_function(
                        param, attributes, choices, regularization, alpha, alpha_2, noise=pred_noise)
                    r2 = self.pseudo_r_squared(param, attributes, choices)
                    track_progress(result, r2)

                print("\n\nBest global params:", best_params,
                      "\nBest global NLL:", best_nll)

                if best_params is None or np.isnan(best_nll) or (not best_nll < float("inf")):
                    raise RuntimeError(
                        "Could not find global optimum in pre-solver. aborting...")
                best_global_params = best_params

            # optimize locally
            def numerical_grad(x, *args):
                epsilon = np.sqrt(np.finfo(float).eps)
                return approx_fprime(x, lambda x_: self.negative_log_likelihood(x_, *args), epsilon)

            objective_args = (attributes, choices, regularization, alpha, alpha_2)
            try:
                x_start = np.array(best_global_params, dtype=float)
            except TypeError as err:
                raise TypeError(
                    "Initial parameters are not numeric (were they ever set?). "
                    "Pass x0, call set_parameters or use optimize_globally=True.") from err

            x_init = x_start.copy()
            max_attempts = 3
            for attempt in range(1, max_attempts + 1):
                result = minimize(
                    self.negative_log_likelihood,
                    x0=x_init,
                    args=objective_args,
                    method=opt_method,
                    tol=tol,
                    jac=numerical_grad,
                    bounds=bounds,
                    options={'disp': disp}
                )
                if result.success:
                    break
                if attempt == max_attempts - 1:
                    x_init = np.zeros_like(x_init) + 1e-8  # last try: (almost) all-zero start
                else:
                    # slightly disturb the initial guess and try again
                    x_init = x_init * np.random.uniform(0.95, 1.05)

            try:
                print("\nNLL", result.fun)
                r2 = self.pseudo_r_squared(result.x, attributes, choices)
                print("R2:", np.round(r2, 3), "\n")
            except Exception:
                pass  # best-effort reporting only

            if not result.success or "bad search direction" in str(result.message):
                messg = f"{opt_method} failed or unstable. Trying non-gradient optimization..."
                warnings.warn(messg)
                result = minimize(
                    self.negative_log_likelihood,
                    x0=x_start,
                    args=objective_args,
                    method='Powell',
                    bounds=bounds,  # keep mu inside its admissible range
                    options={'disp': disp})

                try:
                    print("\nNLL (non-gradient)", result.fun, "\n")
                except Exception as e:
                    warnings.warn("Exception: %s" % str(e))

            # summarize fit statistics
            param_names = self.param_vector.names
            stats = self.print_stats(param_names, result, f=self.negative_log_likelihood,
                                     args=objective_args)

        return result, stats

    def _resolve_fit_bounds(self, bounds, num_params):
        """Return fit() bounds as a list of (low, high) tuples, one per parameter."""
        names = self.param_vector.names
        resolved = [(-2, 2)] * num_params
        # scale parameters default to (almost) the open unit interval
        for i, name in enumerate(names):
            if 'mu_' in name:
                resolved[i] = (.01, 0.9999)

        if bounds is None:
            return resolved
        if isinstance(bounds, dict):
            # user-given bounds take precedence over the defaults
            for k, v in bounds.items():
                if k in names:
                    resolved[names.index(k)] = v
                else:
                    warnings.warn(f"Ignoring bounds for unknown parameter {k!r}")
            return resolved
        if isinstance(bounds, (list, tuple)):
            if len(bounds) != num_params:
                raise ValueError(
                    f"Expected {num_params} bounds, got {len(bounds)}")
            return list(bounds)
        raise RuntimeError(
            "bounds must be None, a dict or a list of (low, high) tuples.")

    @staticmethod
    def _fit_param_samples(sampling_method, bounds, num_samples_per_param):
        """Return (lazy iterator over pre-search samples, number of samples) for fit()."""
        if sampling_method == "brute":
            n_grid = max(num_samples_per_param, 3)
            total = n_grid ** len(bounds)

            def generate():
                # validated lazily: only when the pre-search actually runs
                for bound in bounds:
                    if not (len(bound) == 2 and bound[0] < bound[1]):
                        raise ValueError(f"Invalid bounds given: {bound}")
                grids = [np.linspace(low, high, n_grid) for low, high in bounds]
                for combination in product(*grids):
                    yield np.array(combination)

            return generate(), total

        if sampling_method == "lhs":
            total = max(num_samples_per_param, 1)  # total samples, not per parameter

            def generate():
                if len(bounds) == 0:
                    raise ValueError("No parameters to sample.")
                unit_samples = lhs(len(bounds), samples=total)
                lows = np.array([b[0] for b in bounds], dtype=float)
                highs = np.array([b[1] for b in bounds], dtype=float)
                for row in lows + (highs - lows) * unit_samples:
                    yield np.array(row)

            return generate(), total

        raise RuntimeError(
            "Invalid sampling method. allowed are 'brute' and 'lhs'")

    def compute_log_likelihood(self, params, attributes, choices, noise=0):
        """
        Plain (unregularized) log-likelihood of the observed choices.

        Each probability function in ``self.probabilities_func`` is evaluated as
        ``f(*params, *attributes)``. The result is processed as follows:

        - clipped to [0, inf) and checked to lie in [-1e-8, 1 + 1e-8] where finite;
        - optionally multiplied by a noise factor;
        - clipped to [1e-6, 1 - 1e-6];
        - renormalised so that every observation (column) sums to one.

        The log-probabilities of the chosen alternatives are summed, ignoring NaNs.

        :param params: serialized parameter vector, unpacked positionally into every probability function
        :param attributes: array whose rows are unpacked into the probability functions;
            ``attributes.shape[1]`` is taken as the number of observations
        :param choices: observed choice per observation; each must be a key of ``self.probabilities_func``
        :param noise: if > 0, a factor drawn from N(1, noise) multiplies each alternative's
            probabilities (one draw per alternative)
        :raises AssertionError: if a finite probability lies outside [-1e-8, 1 + 1e-8]
        :raises RuntimeError: if a choice cannot be looked up among the alternatives
        :return: the summed log-likelihood
        """
        lookup = {alt: i for i, alt in enumerate(self.probabilities_func.keys())}
        num_obs = attributes.shape[1]
        num_alts = len(self.probabilities_func)
        probs = np.zeros((num_alts, num_obs))
        for j, f in enumerate(self.probabilities_func.values()):
            # multiplicative prediction noise, one draw per alternative
            g = np.random.normal(1.0, noise) if noise > 0 else 1

            p_j = np.clip(f(*params, *attributes), 0, np.inf)

            # ensure (finite) probabilities are in range; NaN/inf are tolerated here
            finite_mask = np.isfinite(p_j)
            bounded_mask = (p_j[finite_mask] >= -1e-8) & (p_j[finite_mask] <= 1 + 1e-8)
            assert np.all(bounded_mask), "Probabilities out of range!"

            prob = np.clip(p_j * g, 1.0 * 1e-6, 1.0 - 1e-6)
            # scalar or multi-dimensional outputs are flattened into the row
            probs[j, :] = prob if prob.ndim == 1 else prob.flatten()

        # renormalise per observation (columns)
        probs /= np.nansum(probs, axis=0)
        if np.any(probs > 1):
            print("warning probabilities greater than one")
        probs = np.clip(probs, 1e-10, 1.0)

        try:
            choice_indices = np.array([lookup[choice] for choice in choices])
        except (KeyError, TypeError) as e:
            # unknown (or unhashable) choice label; keep the cause for debugging
            raise RuntimeError(
                "Something went wrong then looking up choices %s in %s" % (choices, lookup)) from e
        return np.nansum(np.log(probs[choice_indices, np.arange(num_obs)]))

    def negative_log_likelihood(self, params, attributes, choices, regularization='l2', alpha=0.001, alpha_2=0.001, pred_noise=0.0,
                                verbose=False):
        """
        Negative log-likelihood plus an optional regularization penalty.

        :param params: parameter vector (array, list or tuple)
        :param attributes: passed through to ``compute_log_likelihood``
        :param choices: passed through to ``compute_log_likelihood``
        :param regularization: the penalty type.
            'l1' adds ``alpha * sum|params|``.
            'l2' adds ``alpha_2 * sum params**2``.
            'elastic_net' adds both.
            Any other value adds no penalty.
        :param alpha: weight of the L1 term
        :param alpha_2: weight of the L2 term
        :param pred_noise: forwarded as ``noise`` to ``compute_log_likelihood``
        :param verbose: print the log-likelihood and the penalty
        :return: ``-log_likelihood + penalty``, or ``np.inf`` if that is NaN
        """
        log_likelihood = self.compute_log_likelihood(
            params, attributes, choices, noise=pred_noise)

        # lists/tuples (e.g. grid-search combinations) do not support ``** 2``
        theta = np.asarray(params, dtype=float)
        regularization_term = 0
        if regularization == 'l1':
            regularization_term = alpha * np.nansum(np.abs(theta))
        elif regularization == 'l2':
            regularization_term = alpha_2 * np.nansum(theta ** 2)
        elif regularization == 'elastic_net':
            regularization_term = alpha * np.nansum(np.abs(theta)) + \
                alpha_2 * np.nansum(theta ** 2)

        if verbose:
            print("log likelihood", log_likelihood)
            print("regularization term", regularization_term)

        return_val = -log_likelihood + regularization_term
        # optimizers cope better with +inf than with NaN
        if np.isnan(return_val):
            return np.inf
        return return_val

    def pseudo_r_squared(self, params, attributes=None, choices=None):
        """
        McFadden-style pseudo R-squared of ``params`` against a null model.

        The null model sets every parameter to zero except two groups, which are
        copied from ``params``:

        - the alternative-specific constants (names starting with ``asc_``);
        - the trailing ``self.n_mus`` nest parameters.

        The returned value is ``1 - (LL - k) / LL_null`` with ``k = len(self.attribute_names)``.
        The model's parameters are restored afterwards, even if an evaluation raises.

        :param params: serialized parameter vector to evaluate
        :param attributes: attribute array (defaults to ``self.attributes``)
        :param choices: observed choices (defaults to ``self.choices``)
        :return: the pseudo R-squared
        """
        if attributes is None:
            attributes = self.attributes
        if choices is None:
            choices = self.choices
        assert attributes is not None
        assert choices is not None

        params = np.array(params)
        null_params = np.zeros(len(params))
        # Keep the mus (assumed to be the last n_mus entries); betas and ascs start at zero.
        # The guard matters: for n_mus == 0, params[-0:] would select the whole vector.
        if self.n_mus > 0:
            null_params[-self.n_mus:] = params[-self.n_mus:]
        # keep only the alternative specific constants among the other parameters
        for i, name in enumerate(self.param_vector.names):
            if name.startswith("asc_"):
                null_params[i] = float(params[i])

        original_params = list(self.param_vector.to_serial())
        try:
            # Compute the log-likelihood for the null model
            self.set_parameters(null_params)
            null_LL = self.compute_log_likelihood(null_params, attributes, choices)
        finally:
            self.set_parameters(original_params)

        # Compute the log-likelihood for the full model
        LL = self.compute_log_likelihood(params, attributes, choices)

        # NOTE(review): this subtracts the number of attributes; the usual adjusted
        # McFadden R^2 subtracts the number of estimated parameters -- confirm intent.
        k = len(self.attribute_names)
        return 1 - ((LL - k) / null_LL)

    def print_stats(self, param_names, result, alpha=0.05, f=None, args=None):
        """
        Summarize fit statistics via the module-level ``f_print_stats``.

        :param param_names: parameter names, passed through
        :param result: optimization result (``fit`` passes the object returned by ``minimize``)
        :param alpha: value in [0, 1], passed through (presumably the significance level)
        :param f: objective function, passed through
        :param args: extra arguments for ``f``, passed through
        :return: whatever ``f_print_stats`` returns. Previously this value was dropped,
            so the caller's ``stats`` was always None.
        """
        assert 0 <= alpha <= 1
        return f_print_stats(param_names, result, alpha, f=f, args=args)

    def worker_function(self, param_combination, attributes, choices, regularization, alpha, alpha2, noise=0):
        """
        Score a single candidate parameter combination.

        :param param_combination: candidate parameter vector
        :param attributes: forwarded to ``negative_log_likelihood``
        :param choices: forwarded to ``negative_log_likelihood``
        :param regularization: penalty type, forwarded
        :param alpha: L1 weight, forwarded
        :param alpha2: L2 weight, forwarded as ``alpha_2``
        :param noise: forwarded as ``pred_noise``
        :return: tuple ``(nll, param_combination)``. The regularized negative
            log-likelihood comes first; the input combination is echoed back unchanged.
        """
        score = self.negative_log_likelihood(
            param_combination,
            attributes,
            choices,
            regularization,
            alpha,
            alpha2,
            pred_noise=noise,
        )
        return score, param_combination

    def tweak_asc(self, data, start_params=None, n_iter=20,
                  learn_rate=0.05, rel_batch_size=0.2, pred_noise=0.0,
                  learning_speed=1.8, r2_tol=0.35, verbose=True, ignore_r2=False):
        """
        Tweak the alternative-specific constants (ASCs) towards the observed choice frequencies.

        Each iteration works as follows:

        1. Predict choices for a batch. The first and last iterations use all rows;
           the others use a random sample of ``rel_batch_size`` of the rows.
        2. For every observed option ``t``, compute the log share ratio
           ``log(obs_share) - log(pred_share)``.
        3. Smooth it exponentially with ``learn_rate`` and clip it to [-0.02, 0.02].
        4. Add ``learning_speed`` times the result to the parameter named ``asc_<t>``,
           if one exists.

        :param data: the data with column 'choice' and columns for features
        :param start_params: initial guess (default None, take currently set parameters)
        :param n_iter: number of iterations to do (defaults to 20)
        :param learn_rate: smoothing weight in [0, 1] for the correction (default 0.05)
        :param rel_batch_size: fraction of rows sampled in intermediate iterations (default 0.2)
        :param pred_noise: prediction noise passed to ``predict`` for robustness (default 0)
        :param learning_speed: step-size multiplier applied to the correction (default 1.8)
        :param r2_tol: stop when R-squared drops below this level (default 0.35)
        :param verbose: enables printing
        :param ignore_r2: no tolerance on r-squared
        :raises AssertionError: if learn_rate is outside [0, 1]
        :raises RuntimeError: if no start parameters are given and none are set on the model

        :return: tweaked_params, r2.
            If the fit dropped below ``r2_tol``, the model is reset to the original
            parameters and ``r2`` refers to those. ``tweaked_params`` still holds the
            last (rejected) tweak. For ``n_iter < 1`` the starting parameters are returned.
        """
        if verbose:
            print("*** Tweaking ASC values ***")
        assert 0 <= learn_rate <= 1, "learning rate must be in range [0,1]"

        attributes = data[[c for c in data.columns if c != "choice"]].T.values
        choices = data["choice"].values

        if start_params is None:
            if self.param_vector is None:
                raise RuntimeError(
                    "Cannot find starting parameters. either provide via start_params argument or fit the model first.")
            start_params = self.param_vector.to_serial()
            original_params = list(start_params)
        else:
            original_params = list(self.param_vector.to_serial())
            self.set_parameters(start_params)

        corr = {}
        # defined before the loop so that n_iter < 1 returns the starting parameters
        new_params = list(self.param_vector.to_serial())
        eps = 1e-4  # small tolerance for log to avoid log(0)
        broke = False
        for itr in range(n_iter):
            is_first = itr == 0
            is_last = itr == n_iter - 1
            minibatch = data if (is_first or is_last) else data.sample(frac=rel_batch_size)

            y_pred = self.predict(minibatch, noise=pred_noise)
            y_obs = minibatch["choice"]
            targets = list(set(y_obs))
            n_obs, n_pred = len(y_obs), len(y_pred)

            if verbose and (is_last or broke):
                print("\n *Prediction frequencies after tweaking")
            if is_first:
                if verbose:
                    print("\n *Prediction frequencies before tweaking")
                for t in targets:
                    corr[t] = 0.0

            report = verbose and (is_first or is_last or broke)
            obs_list, pred_list = list(y_obs), list(y_pred)
            for t in targets:
                # t comes from y_obs, so count_obs >= 1
                count_obs = obs_list.count(t)
                count_pred = pred_list.count(t)
                # correction factor after Reul (2024) and Train (2009)
                count_fact = np.log(max(eps, count_obs / n_obs)) - \
                    np.log(max(eps, count_pred / n_pred))
                corr[t] = np.clip((1 - learn_rate) * corr[t] + learn_rate * count_fact,
                                  -0.02, 0.02)
                if report:
                    print("   ", t, "obs: %.3f, pred: %.3f" %
                          (count_obs / n_obs, count_pred / n_pred))

            if verbose and (is_last or broke):
                print("\n\n")
            if broke:
                break

            # adapt the ASCs to the difference; map 'asc_<t>' back to the original
            # label t (which need not be a string)
            pnames = self.param_vector.names
            old_params = np.array(self.param_vector.to_serial())
            new_params = list(old_params)
            asc_to_target = {'asc_%s' % t: t for t in targets}
            for asc_name in set(asc_to_target).intersection(pnames):
                idx = pnames.index(asc_name)
                new_params[idx] = old_params[idx] + \
                    learning_speed * corr[asc_to_target[asc_name]]
            self.set_parameters(new_params)

            r2 = self.pseudo_r_squared(new_params, attributes, choices)
            nll = self.compute_log_likelihood(
                new_params, attributes, choices, noise=0.0)
            if verbose and itr % 500 == 0:
                print("   ...iteration", itr, "R2",
                      np.round(r2, 4), "NLL", nll)

            # stop if the overall fit gets too bad
            if (not ignore_r2) and r2 < r2_tol:
                broke = True

        tweaked_params = new_params
        if broke:
            self.set_parameters(original_params)
            print(
                "WARNING: Could not further improve fit by tweaking ASCs. Restoring original parameters...")
        else:
            self.set_parameters(tweaked_params)

        r2 = self.pseudo_r_squared(
            self.param_vector.to_serial(), attributes, choices)
        if verbose:
            print("R^2:", np.round(r2, 4))

        nll = self.compute_log_likelihood(
            self.param_vector.to_serial(), attributes, choices, noise=0.0)

        if verbose:
            print("NLL:", nll)
            print("---------------------------")

        return tweaked_params, r2
