import difflib
import functools as ft
import logging

from collections.abc import Callable, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Literal, cast

import numpy as np
import pandas as pd
import sympy as sp
import xarray as xr

from better_optimize import minimize, root
from preliz.distributions.distributions import Distribution
from scipy import linalg

from gEconpy.classes.containers import SteadyStateResults, SymbolDictionary
from gEconpy.classes.time_aware_symbol import TimeAwareSymbol
from gEconpy.exceptions import (
    GensysFailedException,
    ModelUnknownParameterError,
    PerturbationSolutionNotFoundException,
    SteadyStateNotFoundError,
)
from gEconpy.model.compile import BACKENDS
from gEconpy.model.perturbation import check_bk_condition as _check_bk_condition
from gEconpy.model.perturbation import (
    check_perturbation_solution,
    make_not_loglin_flags,
    override_dummy_wrapper,
    residual_norms,
    statespace_to_gEcon_representation,
)
from gEconpy.model.steady_state import system_to_steady_state
from gEconpy.parser.parse_distributions import CompositeDistribution
from gEconpy.solvers.cycle_reduction import solve_policy_function_with_cycle_reduction
from gEconpy.solvers.gensys import (
    interpret_gensys_output,
    solve_policy_function_with_gensys,
)
from gEconpy.utilities import get_name, postprocess_optimizer_res, safe_to_ss

# Alias for a model symbol: either a plain sympy ``Symbol`` or a ``TimeAwareSymbol``.
VariableType = sp.Symbol | TimeAwareSymbol
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)


def scipy_wrapper(
    f: Callable,
    variables: list[str],
    unknown_var_idxs: np.ndarray[int | bool],
    unknown_eq_idxs: np.ndarray[int | bool],
    f_ss: Callable | None = None,
    include_p=False,
) -> Callable:
    """
    Adapt a keyword-argument function to the positional calling convention used by the optimizers.

    The returned function takes a flat sequence of steady-state values (plus, optionally, a vector ``p``) and a
    dictionary of parameters. The values are paired with ``variables`` to build keyword arguments for ``f``.

    Parameters
    ----------
    f: Callable
        Function to wrap. It is called as ``f(**steady_state_kwargs, **param_dict)``, with ``p`` prepended as the
        first positional argument when ``include_p`` is True.
    variables: list of str
        Names paired, in order, with the values the optimizer passes in.
    unknown_var_idxs: np.ndarray
        Index or mask selecting the columns kept from a 2d result, and the positions of ``p`` in the padded vector.
        Only used when ``f_ss`` is given.
    unknown_eq_idxs: np.ndarray
        Index or mask selecting the entries kept from a 1d result and the rows kept from a 2d result. Its length also
        sets the length of the padded ``p`` vector. Only used when ``f_ss`` is given.
    f_ss: Callable, optional
        Function called as ``f_ss(**param_dict)``; the values it returns overwrite the ones built from the optimizer's
        input before ``f`` is called.
    include_p: bool, default False
        If True, the returned function has signature ``(ss_values, p, param_dict)`` instead of
        ``(ss_values, param_dict)``.

    Returns
    -------
    Callable
        The wrapped function, carrying the metadata of ``f``.
    """

    def _named(ss_values):
        # Pair each value with its variable name; zip stops at the shorter of the two sequences.
        return SymbolDictionary(zip(variables, ss_values, strict=False)).to_string()

    def _named_with_known(ss_values, param_dict):
        # Evaluate the known steady state first, then let it overwrite any entries built from the guess.
        known = f_ss(**param_dict)
        named = _named(ss_values)
        named.update(known)
        return named

    def _keep_unknown(out):
        # Scalars pass through untouched; arrays are cut down to the unknown equations / variables.
        if isinstance(out, float | int):
            return out
        if out.ndim == 1:
            return out[unknown_eq_idxs]
        if out.ndim == 2:
            return out[unknown_eq_idxs, :][:, unknown_var_idxs]
        return out

    if f_ss is None:
        # Nothing is known in advance, so the output of f is returned as-is.
        if include_p:

            @ft.wraps(f)
            def inner(ss_values, p, param_dict):
                return f(p, **_named(ss_values), **param_dict)
        else:

            @ft.wraps(f)
            def inner(ss_values, param_dict):
                return f(**_named(ss_values), **param_dict)

    elif include_p:

        @ft.wraps(f)
        def inner(ss_values, p, param_dict):
            ss_kwargs = _named_with_known(ss_values, param_dict)

            # Scatter p into a zero vector so that f receives a full-length vector.
            p_full = np.zeros(unknown_eq_idxs.shape[0])
            p_full[unknown_var_idxs] = p

            return _keep_unknown(f(p_full, **ss_kwargs, **param_dict))
    else:

        @ft.wraps(f)
        def inner(ss_values, param_dict):
            ss_kwargs = _named_with_known(ss_values, param_dict)
            return _keep_unknown(f(**ss_kwargs, **param_dict))

    return inner


def add_more_ss_values_wrapper(f_ss: Callable | None, known_variables: SymbolDictionary) -> Callable:
    """
    Inject user-provided constant steady state values to the return of the steady state function.

    Parameters
    ----------
    f_ss: Callable, Optional
        Compiled function that maps models parameters to numerical steady state values for variables. May be None,
        in which case the returned function simply returns ``known_variables``.

    known_variables: SymbolDictionary
        Numerical values for model variables in the steady state provided by the user. Keys are expected to be string
        variable names, and values floats.

    Returns
    -------
    Callable
        A function accepting parameters as keyword arguments whose return always includes the contents of
        known_variables. Entries of ``known_variables`` overwrite entries with the same key returned by ``f_ss``.
        When ``f_ss`` is given, the returned function carries its metadata.
    """

    def inner(**parameters):
        if f_ss is None:
            return known_variables

        # The dictionary returned by f_ss is updated in place and handed back.
        ss_dict = f_ss(**parameters)
        ss_dict.update(known_variables)
        return ss_dict

    if f_ss is None:
        # There is nothing to copy metadata from. Decorating with ft.wraps(None) would set ``__wrapped__`` to None
        # on the returned function.
        return inner

    return ft.wraps(f_ss)(inner)


def infer_variable_bounds(variable):
    """
    Derive default optimizer bounds from a symbol's sign assumptions.

    Parameters
    ----------
    variable
        Object exposing an ``assumptions0`` mapping (for example a sympy symbol).

    Returns
    -------
    tuple
        ``(lower, upper)``. ``lower`` is ``1e-8`` when the "positive" assumption is truthy and None otherwise;
        ``upper`` is ``-1e-8`` when the "negative" assumption is truthy and None otherwise.
    """
    flags = variable.assumptions0

    lower = None
    upper = None
    if flags.get("positive", False):
        lower = 1e-8
    if flags.get("negative", False):
        upper = -1e-8

    return lower, upper


def _initialize_x0(optimizer_kwargs, variables, jitter_x0):
    """
    Build the initial point handed to the steady-state optimizer.

    Parameters
    ----------
    optimizer_kwargs: dict
        Optimizer keyword arguments. If it contains ``"x0"``, that entry is removed from the dictionary and used as
        the starting point.
    variables: list
        Variables being solved for. Each must expose an ``assumptions0`` mapping; it is consulted only when the
        default starting point is used.
    jitter_x0: bool
        If True, add independent normal noise with scale 1e-4 to every entry of the starting point.

    Returns
    -------
    np.ndarray
        A new float array. The default is 0.8 for every variable, with the sign flipped for variables carrying the
        "negative" assumption. A user-supplied ``x0`` is copied, never modified in place.
    """
    n_variables = len(variables)

    if "x0" in optimizer_kwargs:
        # Copy into a fresh float array. Adding jitter in place would otherwise modify the caller's array, extend a
        # list instead of adding to it, or fail on an integer array.
        x0 = np.array(optimizer_kwargs.pop("x0"), dtype=float)
    else:
        x0 = np.full(n_variables, 0.8)
        negative_idx = np.array([bool(x.assumptions0.get("negative", False)) for x in variables], dtype=bool)
        x0[negative_idx] = -x0[negative_idx]

    if jitter_x0:
        rng = np.random.default_rng()
        x0 = x0 + rng.normal(scale=1e-4, size=n_variables)

    return x0


def validate_policy_function(A, B, C, D, T, R, tol: float = 1e-8, verbose: bool = True) -> None:
    """
    Compute the residual norms of a perturbation solution and optionally log them.

    The state-space matrices ``A``, ``T`` and ``R`` are converted with ``statespace_to_gEcon_representation``, and the
    result is passed together with ``B``, ``C`` and ``D`` to ``residual_norms``.

    Parameters
    ----------
    A, B, C, D
        Matrices of the linearized model. ``A`` is used in the conversion; ``B``, ``C`` and ``D`` are passed to
        ``residual_norms``.
    T, R
        Policy function matrices, used in the conversion.
    tol: float, default 1e-8
        Tolerance forwarded to ``statespace_to_gEcon_representation``.
    verbose: bool, default True
        If True, log both norms at INFO level.

    Returns
    -------
    None
        Nothing is returned and nothing is raised based on the size of the norms.
    """
    gEcon_matrices = statespace_to_gEcon_representation(A, T, R, tol)

    P, Q, _, _, A_prime, R_prime, S_prime = gEcon_matrices

    resid_norms = residual_norms(B, C, D, Q, P, A_prime, R_prime, S_prime)
    norm_deterministic, norm_stochastic = resid_norms

    if verbose:
        _log.info(f"Norm of deterministic part: {norm_deterministic:0.9f}")
        # Report the stochastic norm here; this line previously repeated the deterministic one.
        _log.info(f"Norm of stochastic part:    {norm_stochastic:0.9f}")


def get_known_equation_mask(
    steady_state_system: list[sp.Expr],
    ss_dict: SymbolDictionary[sp.Symbol, float],
    param_dict: SymbolDictionary[sp.Symbol, float],
    tol: float = 1e-8,
) -> np.ndarray:
    """
    Flag the steady-state equations that are already satisfied by the provided values.

    Parameters
    ----------
    steady_state_system: list of sp.Expr
        Steady-state equations, each expected to equal zero.
    ss_dict: SymbolDictionary
        Known steady-state values.
    param_dict: SymbolDictionary
        Parameter values. On a key collision these take precedence over ``ss_dict``.
    tol: float, default 1e-8
        An equation counts as satisfied when the absolute value of its residual is below this threshold.

    Returns
    -------
    np.ndarray
        One entry per equation, True where sympy can establish that the substituted residual is within ``tol`` of
        zero. Equations for which the comparison cannot be decided are marked False.
    """
    substitutions = (ss_dict.copy() | param_dict.copy()).to_sympy()

    # Comparing against True with ``==`` is deliberate: an undecided relational is not equal to True.
    satisfied = [
        (sp.Abs(equation.subs(substitutions)) < tol) == True  # noqa: E712
        for equation in steady_state_system
    ]

    return np.array(satisfied)


def validate_user_steady_state_simple(
    steady_state_system: list[sp.Expr],
    ss_dict: SymbolDictionary[sp.Symbol, float],
    param_dict: SymbolDictionary[sp.Symbol, float],
    tol: float = 1e-8,
) -> None:
    r"""
    Perform a "shallow" validation of user-provided steady-state values.

    The provided numeric values are substituted into the system of steady-state equations, and any equation that
    reduces to a numeric, non-zero residual is reported. The check is "shallow" because dependencies between equations
    are not followed (``sp.solve`` is never called). Partial steady states are allowed: equations that still contain
    free symbols after substitution are skipped. As a consequence, an incorrect value that would only later cause a
    numeric solver to fail is not detected.

    For example, with :math:`x_1 = 0.5` the following system is flagged as having an incorrect steady state:

    .. math::

        \begin{align}
            x_1 - 1 &= 0 \\
            x_2 - 3 &= 0
        \end{align}

    because the first equation reduces to :math:`-0.5` after substitution. On the other hand, this system is not
    flagged at :math:`x_1 = 0.5`:

    .. math::

        \begin{align}
            x_1 - x_2 &= 0 \\
            x_2 - x_3 &= 0 \\
            x_3 - 1 &= 0
        \end{align}

    It clearly reduces to :math:`x_1 = 1`, but no attempt is made to perform these substitutions, so the error is not
    detected. In general such substitutions are non-trivial, and attempting to solve for them is costly.

    Parameters
    ----------
    steady_state_system: list of sp.Expr
        System of model equations with all time indices set to the steady state.
    ss_dict: SymbolDictionary
        User-provided steady-state values. Expected to have TimeAwareSymbol variables as keys and numeric values as
        values.
    param_dict: SymbolDictionary
        Parameter values at which the steady state is evaluated. Expected to have Symbol variables as keys and numeric
        values as values.
    tol: float
        Radius around zero within which a residual is considered to be zero. Default is 1e-8.

    Raises
    ------
    ValueError
        If at least one equation has a residual that is known to be outside of ``tol`` after substitution. The
        message lists the offending equations.
    """
    substitutions = (ss_dict.copy() | param_dict.copy()).to_sympy()

    failed = []
    for equation in steady_state_system:
        residual = equation.subs(substitutions)
        # The comparison with False must use ``==``. When free symbols remain, sympy cannot decide |residual| < tol,
        # and the undecided relational is not equal to False, so the equation is skipped.
        if (sp.Abs(residual) < tol) == False:  # noqa: E712
            failed.append(str(equation))

    if failed:
        header = (
            "User-provide steady state is not valid. The following equations had non-zero residuals "
            "after subsitution:\n"
        )
        raise ValueError(header + "\n".join(failed))


class Model:
    """
    A Dynamic Stochastic General Equilibrium (DSGE) Model.

    A ``Model`` is a container class for a DSGE model. It has two primary functions: to store all model primitives
    (variables, parameters, shocks, equations, etc.), and to store compiled functions used to solve the model.

    """

    def __init__(
        self,
        variables: list[TimeAwareSymbol],
        shocks: list[TimeAwareSymbol],
        equations: list[sp.Expr],
        steady_state_relationships: list[sp.Eq],
        param_dict: SymbolDictionary,
        hyper_param_dict: SymbolDictionary,
        deterministic_dict: SymbolDictionary,
        calib_dict: SymbolDictionary,
        priors: dict[str, Distribution],
        f_params: Callable[[np.ndarray, ...], SymbolDictionary],
        f_ss_resid: Callable[[np.ndarray, ...], float],
        f_ss: Callable[[np.ndarray, ...], SymbolDictionary],
        f_ss_error: Callable[[np.ndarray, ...], np.ndarray],
        f_ss_jac: Callable[[np.ndarray, ...], np.ndarray],
        f_ss_error_grad: Callable[[np.ndarray, ...], np.ndarray],
        f_ss_error_hess: Callable[[np.ndarray, ...], np.ndarray],
        f_ss_error_hessp: Callable[[np.ndarray, ...], np.ndarray],
        f_linearize: Callable,
        backend: BACKENDS = "numpy",
        is_linear: bool = False,
    ) -> None:
        """
        Container class for DSGE model primitives and compiled functions.

        In general, users should not need to instantiate this class directly. Instead, use
         :func:`gEconpy.model.build.model_from_gcn` to create a model from a GCN file.

        Parameters
        ----------
        variables: list[TimeAwareSymbol]
            List of variables in the model
        shocks: list[TimeAwareSymbol]
            List of shocks in the model
        equations: list[sp.Expr]
            List of equations in the model
        param_dict: SymbolDictionary
            Dictionary of parameters in the model
        hyper_param_dict: SymbolDictionary
            Dictionary of parameters used by shock distributions
        deterministic_dict: SymbolDictionary
            Dictionary of parameters defined as deterministic functions of other parameters, mapping the
            deterministic parameter Symbols to the expressions defining them.
        calib_dict: SymbolDictionary
            Dictionary of parameters defined as functions of steady-state variables, mapping the calibrated parameter
            Symbols to the expressions defining them.
        priors: dict[str, Distribution]
            Dictionary of prior distributions for the model parameters
        f_params: Callable
            Function that returns a dictionary of parameter values given a dictionary of parameter values
        f_ss_resid: Callable
            Function that takes a dictionary of parameter values theta and steady-state variable values x_ss and
            evaluates the system of model equations f(x_ss, theta) = 0.
        f_ss: Callable
            Function that takes current parameter values and returns a dictionary of steady-state values.
        f_ss_error: Callable, optional
            Function that takes a dictionary of parameter values theta and steady-state variable values x_ss and returns
            a scalar error measure of x_ss given theta.
            If None, the sum of squared residuals returned by f_ss_resid is used.
        f_ss_error_grad: Callable, optional
            Function that takes a dictionary of parameter values theta and steady-state variable values x_ss and returns
            the gradients of the error function f_ss_error with respect to the steady-state variable values x_ss

            If f_ss_error is not provided, an error will be raised if a gradient function is passed.
        f_ss_error_hess: Callable, optional
            Function that takes a dictionary of parameter values theta and steady-state variable values x_ss and returns
            the Hessian of the error function f_ss_error with respect to the steady-state variable values x_ss

            If f_ss_error is not provided, an error will be raised if a gradient function is passed.

        f_ss_error_hessp: Callable, optional
            Function that takes a dictionary of parameter values theta and steady-state variable values x_ss and returns
            the Hessian-vector product of the error function f_ss_error with respect to the steady-state variable
            values x_ss.

        f_ss_jac: Callable, optional
            Function that takes a dictionary of parameter values theta and steady-state variable values x_ss and returns
            the Jacobian of the system of model equations f(x_ss, theta) = 0 with respect to the steady-state variable
            values x_ss.

        f_linearize: Callable, optional
            Function that takes a dictionary of parameter values theta and steady-state variable values x_ss and returns
            the first-order approximation of the model around the steady state.
        """
        self._variables = variables
        self._shocks = shocks
        self._equations = equations
        self._params = list(param_dict.to_sympy().keys())
        self.is_linear = is_linear

        self._hyper_params = list(hyper_param_dict.to_sympy().keys())
        self._deterministic_params = list(deterministic_dict.to_sympy().keys())
        self._calibrated_params = list(calib_dict.to_sympy().keys())

        self._steady_state_relationships = steady_state_relationships

        self._all_names_to_symbols = {
            get_name(x, base_name=True): x
            for x in (self.variables + self.params + self.calibrated_params + self.deterministic_params + self.shocks)
        }

        self._priors = priors

        self._default_params = param_dict.copy()
        self.f_params = f_params
        self.f_ss_resid = f_ss_resid

        self.f_ss_error = f_ss_error
        self.f_ss_error_grad = f_ss_error_grad
        self.f_ss_error_hess = f_ss_error_hess
        self.f_ss_error_hessp = f_ss_error_hessp

        self.f_ss = f_ss
        self.f_ss_jac = f_ss_jac

        if backend == "numpy":
            f_linearize = override_dummy_wrapper(f_linearize, "not_loglin_variable")
        self.f_linearize = f_linearize

    @property
    def variables(self) -> list[TimeAwareSymbol]:
        """
        List of variables in the model, stored as Sympy symbols.

        Variables are associated with the model's endogenous states, identified by the presence of a time subscript.

        The stored list itself is returned, not a copy.
        """
        return self._variables

    @property
    def shocks(self) -> list[TimeAwareSymbol]:
        """
        List of shocks in the model, stored as Sympy symbols.

        Shocks are exogenous variables in the model, and the source of stochasticity in the model.

        The stored list itself is returned, not a copy.
        """
        return self._shocks

    @property
    def equations(self) -> list[sp.Expr]:
        """
        List of equations in the model, stored as Sympy expressions.

        The stored list itself is returned, not a copy.
        """
        return self._equations

    @property
    def params(self) -> list[sp.Symbol]:
        """
        List of parameters in the model, stored as Sympy ``Symbol``s.

        Parameters are fixed values in the model, associated with the structural equations of the model. These are
        sometimes called "deep parameters" because of their (supposed) microeconomic foundations.

        The list holds the keys of the parameter dictionary given to the constructor, after conversion with
        ``to_sympy``. The stored list itself is returned, not a copy.
        """
        return self._params

    @property
    def hyper_params(self) -> list[sp.Symbol]:
        """
        List of hyperparameters in the model, stored as Sympy ``Symbol``s.

        Hyperparameters are parameters associated with the distribution of shocks in the model, for example the
        standard deviation of a normally distributed shock.

        The list holds the keys of the hyper-parameter dictionary given to the constructor, after conversion with
        ``to_sympy``. The stored list itself is returned, not a copy.
        """
        return self._hyper_params

    @property
    def deterministic_params(self) -> list[sp.Symbol]:
        """
        List of deterministic parameters in the model, stored as Sympy ``Symbol``s.

        Deterministic parameters are parameters defined as functions of other parameters in the model. They are
        not directly calibrated, but are instead derived deterministically from other parameters.

        The list holds the keys of the deterministic-parameter dictionary given to the constructor, after conversion
        with ``to_sympy``. The stored list itself is returned, not a copy.
        """
        return self._deterministic_params

    @property
    def param_priors(self) -> dict[str, Distribution]:
        """
        Dictionary of prior distributions for the model parameters.

        The dictionary keys are parameter names, and the values are instances of a subclass of ``preliz.Distribution``.
        """
        # NOTE(review): this reads element 0 of the ``priors`` object given to the constructor, so ``priors`` is
        # presumably a (param_priors, shock_priors) pair rather than the single dict its annotation suggests --
        # confirm against the code that builds the model.
        return self._priors[0]

    @property
    def shock_priors(self) -> dict[str, CompositeDistribution]:
        """
        Dictionary of prior distributions for the model shocks.

        The dictionary keys are shock names. Per the return annotation, the values are ``CompositeDistribution``
        instances.
        """
        # NOTE(review): this reads element 1 of the ``priors`` object given to the constructor, so ``priors`` is
        # presumably a (param_priors, shock_priors) pair rather than the single dict its annotation suggests --
        # confirm against the code that builds the model.
        return self._priors[1]

    @property
    def calibrated_params(self) -> list[sp.Symbol]:
        """
        List of calibrated parameters in the model, stored as Sympy ``Symbol``s.

        Calibrated parameters are pseudo-parameters whose values are an implicit function of the model parameters.
        Each calibrated parameter must be associated with a function of steady-state variables. This function is added
        to the model equations when solving for the steady state, and the calibrated parameter is then solved for
        numerically.

        The list holds the keys of the calibration dictionary given to the constructor, after conversion with
        ``to_sympy``. The stored list itself is returned, not a copy.
        """
        return self._calibrated_params

    @property
    def steady_state_relationships(self) -> list[sp.Eq]:
        """
        List of model equations, evaluated at the deterministic steady state.

        This is the object given to the constructor as ``steady_state_relationships``; it is returned as stored, not
        copied.
        """
        return self._steady_state_relationships

    def parameters(self, **updates: float) -> SymbolDictionary[str, float]:
        """
        Compute the full set of free parameters for the model, including deterministic parameters.

        Calibrated parameters are not returned by this function. These are computed as part of the steady-state
        solution.

        If a parameter is not provided in the updates, the default value (as defined in the model GCN file) is used.

        Parameters
        ----------
        updates: float
            Parameters to update. These are passed as keyword arguments, with the parameter name as the keyword and the
            new value as the value.

        Returns
        -------
        SymbolDictionary
            Dictionary of parameter names and values.
        """
        # Remove deterministic parameters for updates. These can appear **self.parameters() into a fitting function
        deterministic_names = [x.name for x in self.deterministic_params]
        updates = {k: v for k, v in updates.items() if k not in deterministic_names}

        # Check for unknown updates (typos, etc)
        param_dict = self._default_params.copy()
        unknown_updates = set(updates.keys()) - set(param_dict.keys())
        if unknown_updates:
            raise ModelUnknownParameterError(list(unknown_updates))
        param_dict.update(updates)

        return self.f_params(**param_dict).to_string()

    def get(self, name: str) -> sp.Symbol:
        """
        Get a model variable or parameter by name.

        Variables are returned as TimeAwareSymbols, and parameters are returned as regular Sympy Symbols. If the name
        ends with "_ss", the steady-state version of the variable is returned.

        Parameters
        ----------
        name: str
            Name of the variable or parameter to retrieve

        Returns
        -------
        sp.Symbol
            The requested variable or parameter.
        """
        ss_requested = name.endswith("_ss")
        name = name.removesuffix("_ss")

        result = self._all_names_to_symbols.get(name)
        if result is None:
            close_match = difflib.get_close_matches(name, [get_name(x) for x in self._all_names_to_symbols], n=1)[0]
            raise IndexError(f"Did not find {name} among model objects. Did you mean {close_match}?")
        if ss_requested:
            return result.to_ss()
        return result

    def _validate_provided_steady_state_variables(self, user_fixed_variables: Sequence[str]):
        # User is allowed to pass the variable name either with or without the _ss suffix. Begin by normalizing the
        # inputs
        fixed_variables_normed = [x.removesuffix("_ss") for x in user_fixed_variables]

        # Check for duplicated values. This should only be possible if the user passed both `x` and `x_ss`.
        counts = [fixed_variables_normed.count(x) for x in fixed_variables_normed]
        duplicates = [x for x, c in zip(fixed_variables_normed, counts, strict=False) if c > 1]
        if len(duplicates) > 0:
            raise ValueError(
                "The following variables were provided twice (once with a _ss prefix and once without):\n"
                f"{', '.join(duplicates)}"
            )

        # Check that all variables are in the model
        model_variable_names = [x.base_name for x in self.variables]
        unknown_fixed = set(fixed_variables_normed) - set(model_variable_names)

        if len(unknown_fixed) > 0:
            raise ValueError(
                f"The following variables or calibrated parameters were given fixed steady state values but are "
                f"unknown to the model: {', '.join(unknown_fixed)}"
            )

    def steady_state(
        self,
        how: Literal["analytic", "root", "minimize"] = "analytic",
        use_jac=True,
        use_hess=True,
        use_hessp=False,
        progressbar=True,
        optimizer_kwargs: dict | None = None,
        verbose=True,
        bounds: dict[str, tuple[float, float]] | None = None,
        fixed_values: dict[str, float] | None = None,
        jitter_x0: bool = False,
        **updates: float,
    ) -> SteadyStateResults:
        r"""
        Solve for the deterministic steady state of the DSGE model.

        A steady state is defined as the fixed point in the system of  nonlinear equations that describe the model's
        equilibrium. Given a system of model equations :math:`F(x_{t+1}, x_t, x_{t-1}, \varepsilon_t)`, the steady state
        is defined as a state vector :math:`\bar{x}` such that

        .. math::

            F(\bar{x}, \bar{x}, \bar{x}, 0) = 0

        where :math:`0` is the zero vector. At the point :math:`\bar{x}`, the system will not change, absent an
        exogenous shock.

        The steady state is a key concept in DSGE modeling, as it is the point around which the model is linearized.

        Parameters
        ----------
        how: str, one of ['analytic', 'root', 'minimize'], default: 'analytic'
            Method to use to solve for the steady state. If ``'analytic'``, the model is solved analytically using
            user-provided steady-state equations. This is only possible if the steady-state equations are fully
            defined. If ``'root'``, the steady state is solved using a root-finding algorithm. If ``'minimize'``, the
            steady state is solved by minimizing a squared error loss function.

        use_jac: bool, default: True
            Flag indicating whether to use the Jacobian of the error function when solving for the steady state. Ignored
            if ``how`` is 'analytic'.

        use_hess: bool, default: False
            Flag indicating whether to use the Hessian of the error function when solving for the steady state. Ignored
            if ``how`` is not 'minimize'

        use_hessp: bool, default: True
            Flag indicating whether to use the Hessian-vector product of the error function when solving for the
            steady state. This should be preferred over ``use_hess`` if your chosen method supports it. For larger
            problems it is substantially more performant.
            Ignored if ``how`` not "minimize".

        progressbar: bool, default: True
            Flag indicating whether to display a progress bar when solving for the steady state.

        optimizer_kwargs: dict, optional
            Keyword arguments passed to either scipy.optimize.root or scipy.optimize.minimize, depending on the value of
            ``how``. Common argments include:

            - 'method': str,
                The optimization method to use. Default is ``'hybr'`` for ``how = 'root'`` and ``trust-krylov`` for
                ``how = 'minimize'``
            - 'maxiter': int,
                The maximum number of iterations to use. Default is 5000. This argument will be automatically renamed
                to match the argument expected by different optimizers (for example, the ``'hybr'`` method uses
                ``maxfev``).

        verbose: bool, default True
            If true, print a message about convergence (or not) to the console .

        bounds: dict, optional
            Dictionary of bounds for the steady-state variables. The keys are the variable names and the values are
            tuples of the form (lower_bound, upper_bound). These are passed to the scipy.optimize.minimize function,
            see that docstring for more information.

        fixed_values: dict, optional
            Dictionary of fixed values for the steady-state variables. The keys are the variable names and the values
            are the fixed values. These are not check for validity, and passing an inaccurate value may result in the
            system becoming unsolvable.

        jitter_x0: bool
            Whether to apply some small N(0, 1e-4) jitter to the initial point

        **updates: float, optional
            Parameter values at which to solve the steady state. Passed to self.parameters. If not provided, the default
            parameter values (those originally defined during model construction) are used.

        Returns
        -------
        steady_state: SteadyStateResults
            Dictionary of steady-state values

        """
        if optimizer_kwargs is None:
            optimizer_kwargs = {}

        if fixed_values is None:
            f_ss = self.f_ss

        else:
            self._validate_provided_steady_state_variables(list(fixed_values.keys()))
            fixed_symbols = [safe_to_ss(self.get(x)) for x in fixed_values]

            fixed_dict = SymbolDictionary(
                dict(zip(fixed_symbols, fixed_values.values(), strict=False)),
            ).to_string()

            f_ss = add_more_ss_values_wrapper(self.f_ss, fixed_dict)

        # This logic could be made a lot of complex by looking into solver-specific arguments passed via
        # "options"
        tol = optimizer_kwargs.get("tol", 1e-8)

        param_dict = self.parameters(**updates)
        ss_dict = SteadyStateResults()
        ss_system = system_to_steady_state(self.equations, self.shocks)
        unknown_eq_idx = np.full(len(ss_system), True)

        # The default value is analytic, because that's best if the user gave everything we need to proceed. If he gave
        # nothing though, use minimize as a fallback default.
        if how == "analytic" and f_ss is None:
            how = "minimize"
        else:
            # If we have at least some user information, check if its is complete. If it's not, we will minimize
            # with the user-provided values fixed.
            ss_dict = f_ss(**param_dict) if f_ss is not None else ss_dict
            if self.is_linear:
                # TODO: This is a hack, but if we're a linear model, we need to set all the steady state values
                #  to zero. But we don't want to modify the underlying f_ss function, so modify it here.
                ss_dict = SteadyStateResults({x.to_ss(): 0 for x in self.variables}).to_string()

            if len(ss_dict) != 0 and len(ss_dict) != len(self.variables):
                if self.is_linear:
                    raise ValueError(
                        "If a model is declared linear, the steady state must be provided for all variables."
                    )
                if how == "root":
                    zero_eq_mask = get_known_equation_mask(
                        steady_state_system=ss_system,
                        ss_dict=ss_dict,
                        param_dict=param_dict,
                        tol=tol,
                    )
                    if sum(zero_eq_mask) != len(ss_dict):
                        n_eliminated = sum(zero_eq_mask)
                        raise ValueError(
                            'Solving a partially provided steady state with how = "root" is only allowed if applying '
                            f"the given values results in a new square system.\n"
                            f"Found: {len(ss_dict)} provided steady state value{'s' if len(ss_dict) != 1 else ''}\n"
                            f"Eliminated: {n_eliminated} equation{'s' if n_eliminated != 1 else ''}."
                        )
                    unknown_eq_idx = ~zero_eq_mask
                else:
                    how = "minimize"

            # Or, if we have everything, we're done.
            elif len(ss_dict) == len(self.variables):
                resid = self.f_ss_resid(**param_dict, **ss_dict)
                success = np.allclose(resid, 0.0, atol=1e-8)
                ss_dict.success = success
                if not success:
                    _log.warning(f"Steady State was not found. Sum of square residuals: {np.square(resid).sum()}")
                return ss_dict

        # Quick and dirty check of user-provided steady-state validity. This is NOT robust at all.
        validate_user_steady_state_simple(
            steady_state_system=ss_system,
            ss_dict=ss_dict,
            param_dict=param_dict,
            tol=tol,
        )

        ss_variables = [x.to_ss() for x in self.variables] + list(self.calibrated_params)

        known_variables = [] if f_ss is None else list(f_ss(**self.parameters()).to_sympy().keys())

        vars_to_solve = [var for var in ss_variables if var not in known_variables]
        unknown_var_idx = np.array([x in vars_to_solve for x in ss_variables], dtype="bool")

        if how == "root":
            res = self._solve_steady_state_with_root(
                f_ss=f_ss,
                use_jac=use_jac,
                vars_to_solve=vars_to_solve,
                unknown_var_idx=unknown_var_idx,
                unknown_eq_idx=unknown_eq_idx,
                progressbar=progressbar,
                optimizer_kwargs=optimizer_kwargs,
                jitter_x0=jitter_x0,
                **updates,
            )

        elif how == "minimize":
            res = self._solve_steady_state_with_minimize(
                f_ss=f_ss,
                use_jac=use_jac,
                use_hess=use_hess,
                use_hessp=use_hessp,
                vars_to_solve=vars_to_solve,
                unknown_var_idx=unknown_var_idx,
                unknown_eq_idx=unknown_var_idx,
                progressbar=progressbar,
                bounds=bounds,
                optimizer_kwargs=optimizer_kwargs,
                jitter_x0=jitter_x0,
                **updates,
            )
        else:
            raise NotImplementedError()

        provided_ss_values = f_ss(**param_dict).to_sympy() if f_ss is not None else {}
        optimizer_results = SymbolDictionary({var: res.x[i] for i, var in enumerate(vars_to_solve)})
        res_dict = optimizer_results | provided_ss_values
        res_dict = SteadyStateResults({x: res_dict[x] for x in ss_variables}).to_string()

        return postprocess_optimizer_res(
            res=res,
            res_dict=res_dict,
            f_resid=ft.partial(self.f_ss_resid, **param_dict),
            f_jac=ft.partial(self.f_ss_error_grad, **param_dict),
            tol=tol,
            verbose=verbose,
        )

    def _evaluate_steady_state(self, **updates: float):
        param_dict = self.parameters(**updates)
        ss_dict = self.f_ss(**param_dict)

        return self.f_ss_resid(**param_dict, **ss_dict)

    def _solve_steady_state_with_root(
        self,
        f_ss,
        use_jac: bool = True,
        vars_to_solve: list[TimeAwareSymbol] | None = None,
        unknown_var_idx: np.ndarray | None = None,
        unknown_eq_idx: np.ndarray | None = None,
        progressbar: bool = True,
        optimizer_kwargs: dict | None = None,
        jitter_x0: bool = False,
        **param_updates,
    ):
        """
        Solve the steady-state residual system with a root-finding routine.

        Parameters
        ----------
        f_ss: callable or None
            Function returning user-provided steady-state values; forwarded to ``scipy_wrapper``.
        use_jac: bool, default: True
            If True, ``self.f_ss_jac`` is wrapped and passed to the solver as the Jacobian.
        vars_to_solve: list, optional
            Variables the solver searches over.
        unknown_var_idx: np.ndarray, optional
            Forwarded to ``scipy_wrapper`` as ``unknown_var_idxs``.
        unknown_eq_idx: np.ndarray, optional
            Forwarded to ``scipy_wrapper`` as ``unknown_eq_idxs``.
        progressbar: bool, default: True
            Forwarded to ``root``.
        optimizer_kwargs: dict, optional
            Extra keyword arguments for ``root``. The keys ``maxiter`` (default 5000) and ``method`` (default
            ``"hybr"``) are extracted here. A deep copy is taken, so the caller's dictionary is not modified.
        jitter_x0: bool, default: False
            Forwarded to ``_initialize_x0``.
        **param_updates
            Parameter values forwarded to ``self.parameters``.

        Returns
        -------
        The result object returned by ``root``.
        """
        solver_kwargs = deepcopy({} if optimizer_kwargs is None else optimizer_kwargs)

        iteration_cap = solver_kwargs.pop("maxiter", 5000)
        method = solver_kwargs.pop("method", "hybr")

        # "hybr" and "df-sane" receive the cap as ``maxfev``; every other method receives it as ``maxiter``
        cap_name = "maxfev" if method in ["hybr", "df-sane"] else "maxiter"
        solver_options = solver_kwargs.setdefault("options", {})
        solver_options.update({cap_name: iteration_cap})

        x0 = _initialize_x0(solver_kwargs, vars_to_solve, jitter_x0)

        param_dict = self.parameters(**param_updates)
        wrap = ft.partial(
            scipy_wrapper,
            variables=vars_to_solve,
            unknown_var_idxs=unknown_var_idx,
            unknown_eq_idxs=unknown_eq_idx,
            f_ss=f_ss,
        )

        residual_fn = wrap(self.f_ss_resid)
        jacobian_fn = None
        if use_jac:
            jacobian_fn = wrap(self.f_ss_jac)

        # Floating point warnings raised while the solver explores the parameter space are silenced
        with np.errstate(all="ignore"):
            result = root(
                f=residual_fn,
                x0=x0,
                args=(param_dict,),
                jac=jacobian_fn,
                method=method,
                progressbar=progressbar,
                **solver_kwargs,
            )
            return result

    def _solve_steady_state_with_minimize(
        self,
        f_ss,
        use_jac: bool = True,
        use_hess: bool = False,
        use_hessp: bool = True,
        vars_to_solve: list[str] | None = None,
        unknown_var_idx: np.ndarray | None = None,
        unknown_eq_idx: np.ndarray | None = None,
        progressbar: bool = True,
        optimizer_kwargs: dict | None = None,
        jitter_x0: bool = False,
        bounds: dict[str, tuple[float, float]] | None = None,
        **param_updates,
    ):
        """
        Solve for the steady state by minimizing ``self.f_ss_error``.

        Parameters
        ----------
        f_ss: callable or None
            Function returning user-provided steady-state values; forwarded to ``scipy_wrapper``.
        use_jac: bool, default: True
            If True, ``self.f_ss_error_grad`` is wrapped and passed as the gradient.
        use_hess: bool, default: False
            If True, ``self.f_ss_error_hess`` is wrapped and passed as the Hessian. Ignored (with a warning) when
            ``use_hessp`` is also True.
        use_hessp: bool, default: True
            If True, ``self.f_ss_error_hessp`` is wrapped and passed as the Hessian-vector product.
        vars_to_solve: list, optional
            Variables the optimizer searches over. Each element must expose a ``name`` attribute.
        unknown_var_idx: np.ndarray, optional
            Forwarded to ``scipy_wrapper`` as ``unknown_var_idxs``.
        unknown_eq_idx: np.ndarray, optional
            Forwarded to ``scipy_wrapper`` as ``unknown_eq_idxs``.
        progressbar: bool, default: True
            Forwarded to ``minimize``.
        optimizer_kwargs: dict, optional
            Extra keyword arguments for ``minimize``. The keys ``tol`` (default 1e-30), ``method`` and ``maxiter``
            (default 5000) are extracted here. A deep copy is taken, so the caller's dictionary is not modified.
        jitter_x0: bool, default: False
            Forwarded to ``_initialize_x0``.
        bounds: dict, optional
            Bounds keyed by variable name. These override the values produced by ``infer_variable_bounds``.
        **param_updates
            Parameter values forwarded to ``self.parameters``.

        Returns
        -------
        The result object returned by ``minimize``.
        """
        solver_kwargs = deepcopy({} if optimizer_kwargs is None else optimizer_kwargs)

        x0 = _initialize_x0(solver_kwargs, vars_to_solve, jitter_x0)
        tol = solver_kwargs.pop("tol", 1e-30)

        # Inferred bounds first, then user-provided bounds take precedence
        bounds_by_name = {var.name: infer_variable_bounds(var) for var in vars_to_solve}
        if bounds is not None:
            bounds_by_name.update(bounds)

        bounds = [bounds_by_name[var.name] for var in vars_to_solve]
        has_bounds = any(pair != (None, None) for pair in bounds)

        default_method = "trust-constr" if has_bounds else "trust-ncg"
        method = solver_kwargs.pop("method", default_method)

        # Bounds are only handed to the optimizer for the methods listed here
        has_bounds = has_bounds and method in ["trust-constr", "L-BFGS-B", "powell"]

        iteration_cap = solver_kwargs.pop("maxiter", 5000)
        solver_options = solver_kwargs.setdefault("options", {})
        solver_options.update({"maxiter": iteration_cap})
        if method == "L-BFGS-B":
            solver_options.update({"maxfun": iteration_cap})

        param_dict = self.parameters(**param_updates)

        wrap = ft.partial(
            scipy_wrapper,
            variables=vars_to_solve,
            unknown_var_idxs=unknown_var_idx,
            unknown_eq_idxs=unknown_eq_idx,
            f_ss=f_ss,
        )

        if use_hess and use_hessp:
            _log.warning("Both use_hess and use_hessp are set to True. use_hessp will be used.")
            use_hess = False

        objective = wrap(self.f_ss_error)
        gradient = wrap(self.f_ss_error_grad) if use_jac else None
        hessian = wrap(self.f_ss_error_hess) if use_hess else None
        hessian_product = wrap(self.f_ss_error_hessp, include_p=True) if use_hessp else None

        return minimize(
            f=objective,
            x0=x0,
            args=(param_dict,),
            jac=gradient,
            hess=hessian,
            hessp=hessian_product,
            method=method,
            bounds=bounds if has_bounds else None,
            tol=tol,
            progressbar=progressbar,
            **solver_kwargs,
        )

    def linearize_model(
        self,
        order: Literal[1] = 1,
        log_linearize: bool = True,
        not_loglin_variables: list[str] | None = None,
        steady_state: dict | None = None,
        loglin_negative_ss: bool = False,
        steady_state_kwargs: dict | None = None,
        verbose: bool = True,
        **parameter_updates,
    ):
        r"""
        Linearize the model around the deterministic steady state.

        Parameters
        ----------
        order: int, default: 1
            Order of the Taylor expansion to use. Currently only first order linearization is supported.
        log_linearize: bool, default: True
            If True, all variables are log-linearized. If False, all variables are left in levels. Forced to False
            when the model is already linear (``self.is_linear``).
        not_loglin_variables: list of strings, optional
            List of variables to not log-linearize. If provided, these variables will be left in levels, while all
            others will be log-linearized. Ignored if log_linearize is False.
        steady_state: dict, optional
            Dictionary of steady-state values. If provided, these values will be used to linearize the model. If not
            provided, the steady state will be solved for using the ``steady_state`` method.
        loglin_negative_ss: bool, default: False
            If True, variables with negative steady-state values will be log-linearized. While technically possible,
            this is not recommended, as it can lead to incorrect results. Ignored if log_linearize is False.
        steady_state_kwargs: dict, optional
            Keyword arguments passed to the ``steady_state`` method. Ignored if a steady-state solution is provided.
            If the key ``"verbose"`` is absent, the value of the ``verbose`` argument is used. The dictionary passed
            by the caller is not modified.
        verbose: bool, default: True
            Flag indicating whether to print the linearization results to the terminal.
        parameter_updates: dict
            New parameter values at which to linearize the model. Unspecified values will be taken from the initial
            values set in the GCN file.

            .. warning::

                If a steady state is provided, these values will *not* be used to update that solution! This can lead
                to an inconsistent linearization. The user is responsible for ensuring consistency in this case.

        Returns
        -------
        A: np.ndarray
            Jacobian matrix of the model with respect to :math:`x_{t+1}` evaluated at the steady state, right-multiplied
            by the diagonal matrix :math:`T`.
        B: np.ndarray
            Jacobian matrix of the model with respect to :math:`x_t` evaluated at the steady state, right-multiplied
            by the diagonal matrix :math:`T`.
        C: np.ndarray
            Jacobian matrix of the model with respect to :math:`x_{t-1}` evaluated at the steady state, right-multiplied
            by the diagonal matrix :math:`T`.
        D: np.ndarray
            Jacobian matrix of the model with respect to :math:`\varepsilon_t` evaluated at the steady state.

        The four matrices are returned together as a list ``[A, B, C, D]`` of C-contiguous arrays sharing the dtype
        of ``A``.

        Raises
        ------
        NotImplementedError
            If ``order`` is not 1.

        Examples
        --------
        Given a DSGE model of the form:

        .. math::

            F(x_{t+1}, x_t, x_{t-1}, \varepsilon_t) = 0

        The "solution" to the model would be a policy function :math:`g(x_t, \varepsilon_t)`, such that:

        .. math::

            x_{t+1} = g(x_t, \varepsilon_t)

        With the exception of toy models, this policy function is not available in closed form. Instead, the model is
        linearized around the deterministic steady state, which is a fixed point in the system of equations. The linear
        approximation to the model is then used to approximate the policy function. Let :math:`\bar{x}` denote the
        deterministic steady state such that:

        .. math::

            F(\bar{x}, \bar{x}, \bar{x}, 0) = 0.

        A first-order Taylor expansion about (:math:`\bar{x}`, :math:`\bar{x}`, :math:`\bar{x}`, 0) yields

        .. math::

            A (x_{t+1} - \bar{x}) + B (x_t - \bar{x}) + C (x_{t-1} - \bar{x}) + D \varepsilon_t = 0,

        where the Jacobian matrices evaluated at the steady state are

        .. math::

            A = \left. \frac{\partial F}{\partial x_{t+1}} \right|_{(\bar{x},\bar{x},\bar{x},0)}, \quad
            B = \left. \frac{\partial F}{\partial x_t} \right|_{(\bar{x},\bar{x},\bar{x},0)}, \quad
            C = \left. \frac{\partial F}{\partial x_{t-1}} \right|_{(\bar{x},\bar{x},\bar{x},0)}, \quad
            D = \left. \frac{\partial F}{\partial \varepsilon_t} \right|_{(\bar{x},\bar{x},\bar{x},0)}

        It is common to perform a change of variables to log-linearize the model. Define a log-state vector,
        :math:`\tilde{x}_t = \log(x_t)`, with steady state :math:`\tilde{x}_{ss} = \log(\bar{x})`. We get back to the
        original variables by exponentiating the log-state vector.

        .. math::

            F(\exp(\tilde{x}_{t+1}), \exp(\tilde{x}_t), \exp(\tilde{x}_{t-1}), \varepsilon_t) = 0

        Taking derivatives with respect to :math:`\tilde{x}_t`, the linearized model is then:

        .. math::
            :nowrap:

            \[
            A \exp(\tilde{x}_{ss}) (\tilde{x}_{t+1} - \tilde{x}_{ss}) + B \exp(\tilde{x}_{ss}) (\tilde{x}_t -
            \tilde{x}_{ss}) + C \exp(\tilde{x}_{ss}) (\tilde{x}_{t-1} - \tilde{x}_{ss}) + D \varepsilon_t = 0
            \]

        Note that :math:`\tilde{x} - \tilde{x}_{ss} = \log(x) - \log(\bar{x}) = \log \left (\frac{x}{\bar{x}} \right )`
        is the approximate percent deviation of the variable from its steady state.

        The above derivation holds on a variable-by-variable basis. Some variables can be logged and others left in
        levels, all that is required is right-multiplication by a diagonal matrix of the form:

        .. math::

            T = \text{Diagonal}(\{h(x_1), h(x_2), \ldots, h(x_n)\})

        Where :math:`h(x_i) = 1` if the variable is left in levels, and :math:`h(x_i) = \exp(\tilde{x}_{ss})` if the
        variable is logged. This function returns the matrices :math:`AT`, :math:`BT`, :math:`CT`, and :math:`D`.
        """
        if order != 1:
            raise NotImplementedError("Only first order linearization is currently supported.")

        # Work on a copy so the caller's dictionary is never mutated, and only fall back to this method's
        # ``verbose`` flag when the user has not chosen one for the steady-state solver.
        steady_state_kwargs = {} if steady_state_kwargs is None else dict(steady_state_kwargs)
        steady_state_kwargs.setdefault("verbose", verbose)

        if self.is_linear:
            # If the model is linear, the linearization is already done; don't do it again
            log_linearize = False

        param_dict = self.parameters(**parameter_updates)

        if steady_state is None:
            if self.is_linear:
                steady_state = self.f_ss(**self.parameters(**param_dict))
            else:
                steady_state = self.steady_state(
                    **self.parameters(**param_dict),
                    **steady_state_kwargs,
                )

        not_loglin_flags = make_not_loglin_flags(
            variables=self.variables,
            calibrated_params=self.calibrated_params,
            steady_state=steady_state,
            log_linearize=log_linearize,
            not_loglin_variables=not_loglin_variables,
            loglin_negative_ss=loglin_negative_ss,
            verbose=verbose,
        )

        A, B, C, D = self.f_linearize(**param_dict, **steady_state, not_loglin_variable=not_loglin_flags)

        # Using A.dtype to avoid hard-coding float64 (we might be using float32)
        # The reason for casting is mostly D, which sometimes comes out as an int32/64 array

        return [np.ascontiguousarray(x, dtype=A.dtype) for x in [A, B, C, D]]

    def solve_model(
        self,
        solver="cycle_reduction",
        log_linearize: bool = True,
        not_loglin_variables: list[str] | None = None,
        order: Literal[1] = 1,
        loglin_negative_ss: bool = False,
        steady_state: dict | None = None,
        steady_state_kwargs: dict | None = None,
        tol: float = 1e-8,
        max_iter: int = 1000,
        verbose: bool = True,
        on_failure="error",
        **parameter_updates,
    ) -> tuple[np.ndarray | None, np.ndarray | None]:
        r"""
        Solve for the linear approximation to the policy function via perturbation.

        Parameters
        ----------
        solver: str, default: 'cycle_reduction'
            Name of the algorithm to solve the linear solution. Currently "cycle_reduction" and "gensys" are supported.
            Following Dynare, cycle_reduction is the default, but note that gEcon uses gensys.
        log_linearize: bool, default: True
            Whether to log-linearize the model. If False, the model will be solved in levels.
        not_loglin_variables: list of strings, optional
            Variables to not log linearize when solving the model. Variables with steady state values close to zero
            (or negative) will be automatically selected to not log linearize. Ignored if log_linearize is False.
        order: int, default: 1
            Order of taylor expansion to use to solve the model. Currently only 1st order approximation is supported.
        steady_state: dict, optional
            Dictionary of steady-state solutions. If not provided, the steady state will be solved for using the
            ``steady_state`` method.
        steady_state_kwargs: dict, optional
            Keyword arguments passed to the `steady_state` method. Ignored if a steady-state solution is provided
            via the steady_state argument. Default is None.
        loglin_negative_ss: bool, default is False
            Whether to force log-linearization of variable with negative steady-state. This is impossible in principle
            (how can :math:`exp(x_ss)` be negative?), but can still be done; see the docstring for
            :func:`perturbation.linearize_model` for details. Use with caution, as results will not be correct.
            Ignored if log_linearize is False.
        tol: float, default 1e-8
            Desired level of floating point accuracy in the solution
        max_iter: int, default: 1000
            Maximum number of cycle_reduction iterations. Not used if solver is 'gensys'.
        verbose: bool, default: True
            Flag indicating whether to print solver results to the terminal
        on_failure: str, one of ['error', 'ignore'], default: 'error'
            Instructions on what to do if the algorithm fails to find a linearized policy matrix. "error" will raise
            an error, while "ignore" will return None. "ignore" is useful when repeatedly solving the model, e.g. when
            sampling.
        parameter_updates: dict
            New parameter values at which to solve the model. Unspecified values will be taken from the initial values
            set in the GCN file.

        Returns
        -------
        T: np.ndarray, optional
            Transition matrix, approximated to the requested order. Represents the policy function, governing agent's
            optimal state-conditional actions. If the solver fails, None is returned instead.

        R: np.ndarray, optional
            Selection matrix, approximated to the requested order. Represents the state- and agent-conditional
            transmission of stochastic shocks through the economy. If the solver fails, None is returned instead.

        Raises
        ------
        ValueError
            If ``on_failure`` is not one of "error" or "ignore".
        NotImplementedError
            If ``solver`` is not one of "cycle_reduction" or "gensys".
        GensysFailedException
            If the chosen solver fails and ``on_failure`` is "error".

        Examples
        --------
        This method solves the model by linearizing it around the deterministic steady state, and then solving for the
        policy function using a perturbation method. We begin with a model defined as a function of the form:

        .. math::
           :nowrap:

           \[
           \mathbb{E} \left [ F(x_{t+1}, x_t, x_{t-1}, \varepsilon_t) \right ] = 0
           \]

        The linear approximation is then given by the matrices :math:`A`, :math:`B`, :math:`C`, and :math:`D`, as:

        .. math::
           :nowrap:

           \[
           A \hat{x}_{t+1} + B \hat{x}_t + C \hat{x}_{t-1} + D \varepsilon_t = 0
           \]

        where :math:`\hat{x}_t = x_t - \bar{x}` is the deviation of the state vector from its steady state (again,
        potentially in logs). A solution to the model seeks a function:

        .. math::
           :nowrap:

           \[
           x_t = g(x_{t-1}, \varepsilon_t)
           \]

        This implies that :math:`x_{t+1} = g(x_t, \varepsilon_{t+1})`, allowing us to write the model as:

        .. math::
           :nowrap:

           \[
           F_g(x_{t-1}, \varepsilon_t, \varepsilon_{t+1}) =
           f(g(g(x_{t-1}, \varepsilon_t), \varepsilon_{t+1}),
             g(x_{t-1}, \varepsilon_t), x_{t-1}, \varepsilon_t) = 0
           \]

        To lighten notation, define:

        .. math::
           :nowrap:

           \[
           u = \varepsilon_t, \quad
           u_+ = \varepsilon_{t+1}, \quad
           \hat{x} = x_{t-1} - \bar{x} \\
           f_{x_+} = \left. \frac{\partial F_g}{\partial x_{t+1}} \right |_{\bar{x}, \bar{x}, \bar{x}, 0}, \quad
           f_x = \left. \frac{\partial F_g}{\partial x_t}  \right |_{\bar{x}, \bar{x}, \bar{x}, 0}, \\
           f_{x_-} = \left. \frac{\partial F_g}{\partial x_{t-1}}  \right |_{\bar{x}, \bar{x}, \bar{x}, 0}, \quad
           f_u = \left. \frac{\partial F_g}{\partial u}  \right |_{\bar{x}, \bar{x}, \bar{x}, 0} \\
           g_x = \left. \frac{\partial g}{\partial x_{t-1}}  \right |_{\bar{x}, \bar{x}, \bar{x}, 0}, \quad
           g_u = \left. \frac{\partial g}{\partial \varepsilon_t}  \right |_{\bar{x}, \bar{x}, \bar{x}, 0}
           \]

        Under this new notation, the system is:

        .. math::
           :nowrap:

           \[
           F_g(x_-, u, u_+) = f(g(g(x_-, u), u_+), g(x_-, u), x_-, u) = 0
           \]

        The function :math:`g` is unknown, but is implicitly defined by this expression, and can be approximated by a
        first order Taylor expansion around the steady state. The linearized system is then:

        .. math::
           :nowrap:

           \[
           0 \approx F_g(x_-, u, u_+) =
           f_{x_+} (g_x (g_x \hat{x} + g_u u) + g_u u_+) +
           f_x (g_x \hat{x} + g_u u) +
           f_{x_-} \hat{x} + f_u u
           \]

        The Jacobian matrices :math:`f_{x_+}`, :math:`f_x`, :math:`f_{x_-}`, and :math:`f_u` are the matrices :math:`A`,
        :math:`B`, :math:`C`, and :math:`D` respectively, evaluated at the steady state, and are thus known. The task
        is then to solve for unknown matrices :math:`g_x` and :math:`g_u`, which will give a linear approximation to the
        optimal policy function.

        Take expectations, and impose that :math:`\mathbb{E}_t[u_+] = 0`:

        .. math::
           :nowrap:

           \begin{align}
           0 \approx {} &
           f_{x_+} (g_x(g_x \hat{x} + g_u u) + g_u \mathbb{E}_t[u_+]) +
           f_x (g_x \hat{x} + g_u u) + f_{x_-} \hat{x} + f_u u \\
           \approx {} &
           (f_{x_+} g_x g_x + f_x g_x + f_{x_-})\hat{x} +
           (f_{x_+} g_x g_u + f_x g_u + f_u) u
           \end{align}

        For the system to be equal to zero, both coefficient matrices must be zero, which gives us two equations
        in the unknowns :math:`g_x` and :math:`g_u`:

        .. math::
           :nowrap:

           \begin{align}
           (f_{x_+} g_x g_x + f_x g_x + f_{x_-}) \hat{x} &= 0 \\
           (f_{x_+} g_x g_u + f_x g_u + f_u) u &= 0
           \end{align}

        Assuming :math:`g_x` has been solved for, the coefficient in the second equation can be directly solved for,
        giving:

        .. math::
           :nowrap:

           \[
           g_u = -(f_{x_+} g_x + f_x)^{-1} f_u
           \]

        The first equation, on the other hand, is a quadratic in :math:`g_x`, and cannot be solved for directly.
        Instead, we employ trickery. The equation can be re-written as a linear system in two states:

        .. math::
           :nowrap:

           \begin{align}
           \begin{bmatrix} 0 & f_{x_+} \\ I & 0 \end{bmatrix}
           \begin{bmatrix} g_x \\ g_x g_x \end{bmatrix} \hat{x}
           &=
           \begin{bmatrix} -f_x & -f_{x_-} \\ I & 0 \end{bmatrix}
           \begin{bmatrix} g_x \\ I \end{bmatrix} \hat{x} \\
           D \begin{bmatrix} I \\ g_x \end{bmatrix} g_x \hat{x}
           &=
           E \begin{bmatrix} g_x \\ I \end{bmatrix} \hat{x} \\
           QTZ \begin{bmatrix} I \\ g_x \end{bmatrix} g_x \hat{x}
           &=
           QSZ \begin{bmatrix} g_x \\ I \end{bmatrix} \hat{x} \\
           TZ \begin{bmatrix} I \\ g_x \end{bmatrix} g_x \hat{x}
           &=
           SZ \begin{bmatrix} g_x \\ I \end{bmatrix} \hat{x}
           \end{align}

        The last two lines use the QZ decomposition of the pencil :math:`<D, E>` into upper triangular matrix :math:`T`
        and quasi-upper triangular matrix :math:`S`, and the orthogonal matrices :math:`Z` and :math:`Q`. :math:`T` and
        :math:`S` have structure that can be exploited. In particular, they are arranged so that the eigenvalues of the
        pencil :math:`<D, E>` are sorted in modulus from smallest (stable) to largest (unstable).

        Partitioning the rows of the matrices by eigenvalue stability, and the columns by the size of :math:`g_x`, we
        get:

        .. math::
           :nowrap:

           \[
           \begin{bmatrix} T_{11} & T_{12} \\ 0 & T_{22} \end{bmatrix}
           \begin{bmatrix} Z_{11} & Z_{12} \\ Z_{21} & Z_{22} \end{bmatrix}
           \begin{bmatrix} I \\ g_x \end{bmatrix} g_x \hat{x} =
           \begin{bmatrix} S_{11} & S_{12} \\ 0 & S_{22} \end{bmatrix}
           \begin{bmatrix} Z_{11} & Z_{12} \\ Z_{21} & Z_{22} \end{bmatrix}
           \begin{bmatrix} g_x \\ I \end{bmatrix} \hat{x}
           \]

        For the system to be stable, we require that:

        .. math::
           :nowrap:

           \[
           Z_{21} + Z_{22} g_x = 0
           \]

        And thus:

        .. math::
           :nowrap:

           \[
           g_x = -Z_{22}^{-1} Z_{21}
           \]

        This requires that :math:`Z_{22}` is square and invertible, which are known as the *rank* and *stability*
        conditions of Blanchard and Kahn (1980). If these conditions are not met, the model is indeterminate, and a
        solution is not possible.

        """
        if on_failure not in ["error", "ignore"]:
            raise ValueError(f'Parameter on_failure must be one of "error" or "ignore", found {on_failure}')
        if steady_state_kwargs is None:
            steady_state_kwargs = {}

        # Either validate the user-provided steady state, or solve for one at the requested parameter values
        ss_dict = _maybe_solve_steady_state(self, steady_state, steady_state_kwargs, parameter_updates)
        n_variables = len(self.variables)

        A, B, C, D = self.linearize_model(
            order=order,
            log_linearize=log_linearize,
            not_loglin_variables=not_loglin_variables,
            steady_state=ss_dict.to_string(),
            loglin_negative_ss=loglin_negative_ss,
            verbose=verbose,
            **parameter_updates,
        )

        # linearize_model returns C-contiguous arrays; this is a sanity check before calling the solvers
        assert all(x.flags["C_CONTIGUOUS"] for x in [A, B, C, D])

        if solver == "gensys":
            gensys_results = solve_policy_function_with_gensys(A, B, C, D, tol)
            G_1, constant, impact, f_mat, f_wt, y_wt, gev, eu, loose = gensys_results

            # The solve is treated as successful only when the first two entries of ``eu`` are both 1
            success = all(x == 1 for x in eu[:2])

            if not success:
                if on_failure == "error":
                    raise GensysFailedException(eu)
                if on_failure == "ignore":
                    if verbose:
                        message = interpret_gensys_output(eu)
                        _log.info(message)

                    return None, None

            if verbose:
                message = interpret_gensys_output(eu)
                _log.info(message)
                # NOTE(review): this message refers to attributes model.P, model.Q, model.R and model.S, but
                # nothing in this method assigns them -- confirm whether the message is stale.
                _log.info("Policy matrices have been stored in attributes model.P, model.Q, model.R, and model.S")

            # Keep only the leading n_variables rows/columns of the gensys output
            T = G_1[:n_variables, :][:, :n_variables]
            R = impact[:n_variables, :]

        elif solver == "cycle_reduction":
            (
                T,
                R,
                result,
                log_norm,
            ) = solve_policy_function_with_cycle_reduction(A, B, C, D, max_iter, tol, verbose)
            # A ``None`` transition matrix signals that cycle reduction failed
            if T is None:
                if on_failure == "error":
                    raise GensysFailedException(result)
                if on_failure == "ignore":
                    if verbose:
                        _log.info(result)
                    return None, None
        else:
            raise NotImplementedError('Only "cycle_reduction" and "gensys" are valid values for solver')

        if verbose:
            check_perturbation_solution(A, B, C, D, T, R, tol=tol)

        return np.ascontiguousarray(T), np.ascontiguousarray(R)


def _maybe_solve_steady_state(
    model: Model,
    steady_state: dict | None,
    steady_state_kwargs: dict | None,
    parameter_updates: dict | None,
):
    if steady_state is None:
        if model.is_linear:
            return model.f_ss(**model.parameters(**parameter_updates))

        return model.steady_state(**model.parameters(**parameter_updates), **steady_state_kwargs)

    ss_resid = model.f_ss_resid(**steady_state, **model.parameters(**parameter_updates))
    FLOAT_ZERO = 1e-8
    unsatisfied_flags = np.abs(ss_resid) > FLOAT_ZERO
    unsatisfied_eqs = [f"Equation {i}" for i, flag in enumerate(unsatisfied_flags) if flag]

    if np.any(unsatisfied_flags):
        raise SteadyStateNotFoundError(unsatisfied_eqs)
    steady_state.success = True

    return steady_state


def _maybe_linearize_model(
    model: Model,
    A: np.ndarray | None,
    B: np.ndarray | None,
    C: np.ndarray | None,
    D: np.ndarray | None,
    **linearize_model_kwargs,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Linearize a model if required, or return the provided matrices.

    Parameters
    ----------
    model: Model
        DSGE model
    A: np.ndarray, optional
        Matrix of partial derivatives of model equations with respect to variables at time t-1, evaluated at the
        steady-state
    B: np.ndarray, optional
        Matrix of partial derivatives of model equations with respect to variables at time t, evaluated at the
        steady-state
    C: np.ndarray, optional
        Matrix of partial derivatives of model equations with respect to variables at time t+1, evaluated at the
        steady-state
    D: np.ndarray, optional
        Matrix of partial derivatives of model equations with respect to stochastic innovations, evaluated at the
        steady-state
    linearize_model_kwargs
        Arguments forwarded to the ``model.linearize_model`` method. Ignored if all of A, B, C, and D are provided.

    Returns
    -------
    linear_system: np.ndarray, np.ndarray, np.ndarray, np.ndarray
    """
    verbose = linearize_model_kwargs.get("verbose", True)
    n_matrices = sum(x is not None for x in [A, B, C, D])

    if n_matrices < 4 and n_matrices > 0 and verbose:
        _log.warning(
            f"Passing an incomplete subset of A, B, C, and D (you passed {n_matrices}) will still trigger "
            f"``model.linearize_model`` (which might be expensive). Pass all to avoid this, or None to silence "
            f"this warning."
        )
        A = None
        B = None
        C = None
        D = None

    if all(x is None for x in [A, B, C, D]):
        A, B, C, D = model.linearize_model(**linearize_model_kwargs)

    return A, B, C, D


def _maybe_solve_model(
    model: Model, T: np.ndarray | None, R: np.ndarray | None, **solve_model_kwargs
) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
    """
    Solve for the linearized policy matrix of a model if required, or return the provided T and R.

    Parameters
    ----------
    model: Model
        DSGE Model assoicated with T and R
    T: np.ndarray, optional
        Transition matrix of the solved system. If None, this will be computed using the model's ``solve_model``
        method.
    R: np.ndarray
        Selection matrix of the solved system. If None, this will be computed using the model's ``solve_model`` method.
    **solve_model_kwargs
        Arguments forwarded to the ``solve_model`` method. Ignored if T and R are provided.

    Returns
    -------
    T: np.ndarray, optional
        Transition matrix, approximated to the requested order. Represents the policy function, governing agent's
        optimal state-conditional actions. If the solver fails, None is returned instead.

    R: np.ndarray, optional
        Selection matrix, approximated to the requested order. Represents the state- and agent-conditional
        transmission of stochastic shocks through the economy. If the solver fails, None is returned instead.
    """
    n_matrices = sum(x is not None for x in [T, R])
    if n_matrices == 1:
        _log.warning(
            "Passing only one of T or R will still trigger ``model.solve_model`` (which might be expensive). "
            "Pass both to avoid this, or None to silence this warning."
        )
        T = None
        R = None

    if T is None and R is None:
        T, R = model.solve_model(**solve_model_kwargs)

    return T, R


def _validate_shock_options(
    shock_std_dict: dict[str, float] | None,
    shock_cov_matrix: np.ndarray | None,
    shock_std: float | np.ndarray | list | None,
    shocks: list[TimeAwareSymbol],
):
    n_shocks = len(shocks)
    n_provided = sum(x is not None for x in [shock_std_dict, shock_cov_matrix, shock_std])
    if n_provided > 1 or n_provided == 0:
        raise ValueError(
            "Exactly one of shock_std_dict, shock_cov_matrix, or shock_std should be provided. You passed "
            f"{n_provided}."
        )

    if shock_cov_matrix is not None and any(s != n_shocks for s in shock_cov_matrix.shape):
        raise ValueError(
            f"Incorrect covariance matrix shape. Expected ({n_shocks}, {n_shocks}), found {shock_cov_matrix.shape}"
        )

    if shock_std_dict is not None:
        shock_names = [x.base_name for x in shocks]
        missing = [x for x in shock_std_dict if x not in shock_names]
        extra = [x for x in shock_names if x not in shock_std_dict]
        if len(missing) > 0:
            raise ValueError(
                f"If shock_std_dict is specified, it must give values for all shocks. The following shocks were not "
                f"found among the provided keys: {', '.join(missing)}"
            )
        if len(extra) > 0:
            raise ValueError(
                f"Unexpected shocks in shock_std_dict. The following names were not found among the model shocks: "
                f"{', '.join(extra)}"
            )

    if shock_std is not None:
        if isinstance(shock_std, np.ndarray | list):
            shock_std = cast(np.ndarray | list, shock_std)
            if len(shock_std) != n_shocks:
                raise ValueError(
                    f"Length of shock_std ({len(shock_std)}) does not match the number of shocks ({n_shocks})"
                )
            if not np.all(shock_std > 0):
                raise ValueError("Shock standard deviations must be positive")
        elif isinstance(shock_std, int | float):
            if shock_std < 0:
                raise ValueError("Shock standard deviation must be positive")


def _validate_simulation_options(shock_size, shock_cov, shock_trajectory) -> None:
    options = [shock_size, shock_cov, shock_trajectory]
    n_options = sum(x is not None for x in options)

    if n_options != 1:
        raise ValueError("Specify exactly 1 of shock_size, shock_cov, or shock_trajectory")


def build_Q_matrix(
    model_shocks: list[TimeAwareSymbol],
    shock_std_dict: dict[str, float] | None = None,
    shock_cov_matrix: np.ndarray | None = None,
    shock_std: np.ndarray | list | float | None = None,
) -> np.ndarray:
    """
    Take different options for user input and reconcile them into a shock covariance matrix.

    The inputs are first checked with ``_validate_shock_options``. The covariance matrix Q is then built as follows:

    - If ``shock_cov_matrix`` is provided, it is Q and is returned unchanged.
    - If ``shock_std_dict`` is provided, Q is diagonal, with the squared standard deviation of each shock placed at the
      position of that shock in ``model_shocks``.
    - Otherwise ``shock_std`` is used: Q is diagonal, with ``shock_std ** 2`` on the diagonal. A scalar is applied to
      every shock; a sequence supplies one standard deviation per shock, in the order of ``model_shocks``.

    Note that the only way to get off-diagonal elements is to explicitly pass the entire covariance matrix.

    Parameters
    ----------
    model_shocks: list of TimeAwareSymbol
        Model shocks. Their ``base_name`` and order determine positions in the covariance matrix.
    shock_std_dict: dict, optional
        Dictionary of shock names and standard deviations to be used to build Q
    shock_cov_matrix: array, optional
        An (n_shocks, n_shocks) covariance matrix describing the exogenous shocks
    shock_std: float or sequence of float, optional
        Standard deviation of all model shocks. If float, the same value will be used for all shocks. If sequence, the
        length must match the number of shocks.

    Raises
    ------
    ValueError
        Raised by ``_validate_shock_options`` when the provided options are inconsistent, or when a key of
        ``shock_std_dict`` is not the name of a model shock.

    Returns
    -------
    Q: ndarray
        Shock variance-covariance matrix
    """
    _validate_shock_options(
        shock_std_dict=shock_std_dict,
        shock_cov_matrix=shock_cov_matrix,
        shock_std=shock_std,
        shocks=model_shocks,
    )

    if shock_cov_matrix is not None:
        return shock_cov_matrix

    n_shocks = len(model_shocks)

    if shock_std_dict is not None:
        shock_names = [x.base_name for x in model_shocks]
        Q = np.zeros((n_shocks, n_shocks))
        for name, std in shock_std_dict.items():
            position = shock_names.index(name)
            Q[position, position] = std**2
        return Q

    # Convert first so that plain Python lists can be squared; a list does not support ``** 2`` on its own.
    std_values = np.asarray(shock_std, dtype=float)

    # Broadcasting against the identity puts std_values[j] ** 2 at position (j, j) and leaves zeros elsewhere.
    return np.eye(n_shocks) * std_values**2


def stationary_covariance_matrix(
    model: Model,
    T: np.ndarray | None = None,
    R: np.ndarray | None = None,
    shock_std_dict: dict[str, float] | None = None,
    shock_cov_matrix: np.ndarray | None = None,
    shock_std: np.ndarray | list | float | None = None,
    return_df: bool = True,
    **solve_model_kwargs,
) -> np.ndarray | pd.DataFrame:
    """
    Compute the stationary covariance matrix of the solved, linearized model.

    The shock covariance matrix Q is built from the user input, mapped into the space of the model variables as
    ``R @ Q @ R.T``, and passed together with ``T`` to ``scipy.linalg.solve_discrete_lyapunov``.

    Exactly one of ``shock_std_dict``, ``shock_cov_matrix``, or ``shock_std`` should be provided.

    Parameters
    ----------
    model: Model
        DSGE Model associated with T and R
    T: np.ndarray, optional
        Transition matrix of the solved system. If None, this will be computed using the model's ``solve_model``
        method.
    R: np.ndarray, optional
        Selection matrix of the solved system. If None, this will be computed using the model's ``solve_model`` method.
    shock_std_dict: dict, optional
        A dictionary of shock names and standard deviations used to build the shock covariance matrix.
    shock_cov_matrix: array, optional
        An (n_shocks, n_shocks) covariance matrix describing the exogenous shocks
    shock_std: float or sequence of float, optional
        Standard deviation of all model shocks.
    return_df: bool
        If True, return the covariance matrix as a DataFrame labelled by variable name on both axes.
    **solve_model_kwargs
        Arguments forwarded to the ``solve_model`` method. Ignored if T and R are provided.

    Returns
    -------
    Sigma: np.ndarray | pd.DataFrame
        Stationary covariance matrix of the linearized model. The datatype depends on the value of the ``return_df``
        argument.
    """
    model_shocks = model.shocks
    shock_options = {
        "shock_std_dict": shock_std_dict,
        "shock_cov_matrix": shock_cov_matrix,
        "shock_std": shock_std,
    }

    # Validate before solving so that bad shock input fails before any model solution is attempted
    _validate_shock_options(shocks=model_shocks, **shock_options)

    T, R = _maybe_solve_model(model, T, R, **solve_model_kwargs)
    Q = build_Q_matrix(model_shocks=model_shocks, **shock_options)

    state_shock_cov = np.linalg.multi_dot([R, Q, R.T])
    Sigma = linalg.solve_discrete_lyapunov(T, state_shock_cov)

    if not return_df:
        return Sigma

    labels = [x.base_name for x in model.variables]
    return pd.DataFrame(Sigma, index=labels, columns=labels)


def check_bk_condition(
    model: Model,
    *,
    A: np.ndarray | None = None,
    B: np.ndarray | None = None,
    C: np.ndarray | None = None,
    D: np.ndarray | None = None,
    tol=1e-8,
    on_failure: Literal["raise", "ignore"] = "ignore",
    return_value: Literal["dataframe", "bool", None] = "dataframe",
    **linearize_model_kwargs,
) -> bool | pd.DataFrame | None:
    """
    Check the Blanchard-Kahn condition using the generalized eigenvalues of the linearized system.

    The number of unstable eigenvalues (:math:`|v| > 1`) should not be greater than the number of forward-looking
    variables. Failing this test suggests timing problems in the definition of the model.

    Parameters
    ----------
    model: Model
        DSGE model
    A: np.ndarray, optional
        Jacobian matrix of the DSGE system, evaluated at the steady state, taken with respect to past variables
        values that are known when decision-making: those with t-1 subscripts.
    B: np.ndarray, optional
        Jacobian matrix of the DSGE system, evaluated at the steady state, taken with respect to variables that
        are observed when decision-making: those with t subscripts.
    C: np.ndarray, optional
        Jacobian matrix of the DSGE system, evaluated at the steady state, taken with respect to variables that
        enter in expectation when decision-making: those with t+1 subscripts.
    D: np.ndarray, optional
        Jacobian matrix of the DSGE system, evaluated at the steady state, taken with respect to exogenous shocks.
    tol: float, default: 1e-8
        Tolerance below which numerical values are considered zero
    on_failure: str, default: 'ignore'
        Action to take if the Blanchard-Kahn condition is not satisfied. Valid values are 'ignore' and 'raise'.
    return_value: str or None, default: 'dataframe'
        Controls what is returned by the function. Valid values are 'dataframe', 'bool', and None.
    **linearize_model_kwargs
        Arguments forwarded to ``_maybe_linearize_model`` together with A, B, C and D. A ``verbose`` entry, if present,
        is additionally passed on to the underlying check; it is treated as True when absent.

    Returns
    -------
    bk_result, bool or pd.DataFrame, optional.
        Return value requested. Datatype corresponds to what was requested in the ``return_value`` argument:

        - None, if return_value is None
        - condition_satisfied, bool; if return_value is 'bool', returns True if the Blanchard-Kahn condition is
          satisfied, False otherwise.
        - Eigenvalues, pd.DataFrame; if return_value is 'dataframe', returns a dataframe containing the real and
          imaginary components of the system's eigenvalues, along with their modulus.
    """
    # ``verbose`` is not a named argument of this function: it is read from the keyword arguments without removing
    # it, so the same entry is also forwarded to ``_maybe_linearize_model`` below.
    verbose = linearize_model_kwargs.get("verbose", True)

    jacobians = _maybe_linearize_model(model, A, B, C, D, **linearize_model_kwargs)
    A, B, C, D = jacobians

    bk_options = {
        "tol": tol,
        "verbose": verbose,
        "on_failure": on_failure,
        "return_value": return_value,
    }
    return _check_bk_condition(A, B, C, D, **bk_options)


# @nb.njit(cache=True)
def _compute_autocovariance_matrix(T, Sigma, n_lags=5, correlation=True):
    """Compute the autocorrelation matrix for the given state-space model.

    Parameters
    ----------
    T: np.ndarray, optional
        Transition matrix of the solved system.
    Sigma: np.ndarray
        Stationary covariance matrix of the linearized model
    n_lags : int, optional
        The number of lags for which to compute the autocorrelation matrices.
    correlation: bool
        If True, return the autocorrelation matrices instead of the autocovariance matrices.

    Returns
    -------
    acov : ndarray
        An array of shape (n_lags, n_variables, n_variables) whose (i, j, k)-th entry gives the autocorrelation
        (or autocovaraince) between variables j and k at lag i.
    """
    n_vars = T.shape[0]
    auto_coors = np.empty((n_lags, n_vars, n_vars))
    std_vec = np.sqrt(np.diag(Sigma))

    normalization_factor = np.outer(std_vec, std_vec) if correlation else np.ones_like(Sigma)

    for i in range(n_lags):
        auto_coors[i] = np.linalg.matrix_power(T, i) @ Sigma / normalization_factor

    return auto_coors


def autocovariance_matrix(
    model: Model,
    T: np.ndarray | None = None,
    R: np.ndarray | None = None,
    shock_std_dict: dict[str, float] | None = None,
    shock_cov_matrix: np.ndarray | None = None,
    shock_std: np.ndarray | list | float | None = None,
    n_lags: int = 10,
    correlation=False,
    return_xr=True,
    **solve_model_kwargs,
):
    """
    Compute the model's autocovariance matrix using the stationary covariance matrix.

    Alternatively, the autocorrelation matrix can be returned by specifying ``correlation = True``.

    In order to construct the shock covariance matrix, exactly one of ``shock_std_dict``, ``shock_cov_matrix``, or
    ``shock_std`` should be provided.

    Parameters
    ----------
    model: Model
        DSGE Model associated with T and R
    T: np.ndarray, optional
        Transition matrix of the solved system. If None, this will be computed using the model's ``solve_model``
        method.
    R: np.ndarray, optional
        Selection matrix of the solved system. If None, this will be computed using the model's ``solve_model`` method.
    shock_std_dict: dict, optional
        A dictionary of shock names and standard deviations used to compute the stationary covariance matrix.
    shock_cov_matrix: array, optional
        An (n_shocks, n_shocks) covariance matrix describing the exogenous shocks
    shock_std: float or sequence of float, optional
        Standard deviation of all model shocks.
    n_lags: int
        Number of lags of auto-covariance and cross-covariance to compute, starting from lag 0. Default is 10.
    correlation: bool
        If True, return the autocorrelation matrices instead of the autocovariance matrices.
    return_xr: bool
        If True, return the result as a DataArray with dimensions ["lag", "variable", "variable_aux"].
        Otherwise returns a 3d numpy array with shape (lag, variable, variable).
    **solve_model_kwargs
        Arguments forwarded to the ``solve_model`` method. Ignored if T and R are provided.

    Returns
    -------
    acov: xr.DataArray | np.ndarray
        Autocovariance (or autocorrelation) matrices for lags 0 to ``n_lags - 1``. The datatype depends on the value
        of the ``return_xr`` argument.
    """
    T, R = _maybe_solve_model(model, T, R, **solve_model_kwargs)

    # The keyword must be spelled ``shock_std_dict``. Any other spelling is silently absorbed by the
    # ``**solve_model_kwargs`` of ``stationary_covariance_matrix`` and the user's dictionary is never seen.
    Sigma = stationary_covariance_matrix(
        model,
        T=T,
        R=R,
        shock_std_dict=shock_std_dict,
        shock_cov_matrix=shock_cov_matrix,
        shock_std=shock_std,
        return_df=False,
        **solve_model_kwargs,
    )
    result = _compute_autocovariance_matrix(T, Sigma, n_lags=n_lags, correlation=correlation)

    if return_xr:
        variables = [x.base_name for x in model.variables]
        result = xr.DataArray(
            result,
            dims=["lag", "variable", "variable_aux"],
            coords={
                "lag": range(n_lags),
                "variable": variables,
                "variable_aux": variables,
            },
        )

    return result


def summarize_perturbation_solution(
    linear_system: Sequence[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    perturbation_solution: Sequence[np.ndarray | None, np.ndarray | None],
    model: Model,
):
    """
    Collect the linearized system and its perturbation solution into a single labelled Dataset.

    Parameters
    ----------
    linear_system: sequence of np.ndarray
        The four matrices (A, B, C, D) of the linearized model.
    perturbation_solution: sequence of np.ndarray
        The pair (T, R) describing the solved system.
    model: Model
        DSGE model; the names of its variables and shocks are used as coordinates.

    Returns
    -------
    xr.Dataset
        Dataset with data variables "A", "B", "C", "T" on dimensions ("equation", "variable") and "D", "R" on
        dimensions ("equation", "shock"). Equations are labelled by integer position.

    Raises
    ------
    PerturbationSolutionNotFoundException
        If either T or R is None.
    """
    A, B, C, D = linear_system
    T, R = perturbation_solution
    if T is None or R is None:
        raise PerturbationSolutionNotFoundException()

    by_variable = ("equation", "variable")
    by_shock = ("equation", "shock")
    labelled_matrices = [
        ("A", by_variable, A),
        ("B", by_variable, B),
        ("C", by_variable, C),
        ("D", by_shock, D),
        ("T", by_variable, T),
        ("R", by_shock, R),
    ]

    return xr.Dataset(
        data_vars={name: (dims, values) for name, dims, values in labelled_matrices},
        coords={
            "equation": np.arange(A.shape[0], dtype=int),
            "variable": [x.base_name for x in model.variables],
            "shock": [x.base_name for x in model.shocks],
        },
    )


# Convenience alias: ``autocovariance_matrix`` with ``correlation=True`` pre-bound. A ``functools.partial`` object does
# not pick up the docstring of the function it wraps, so the docstring is copied over explicitly.
autocorrelation_matrix = ft.partial(autocovariance_matrix, correlation=True)
autocorrelation_matrix.__doc__ = autocovariance_matrix.__doc__


@dataclass(frozen=True)
class ShockSpec:
    """
    Representation of a shock input used to generate an impulse response function.

    Instances are immutable (``frozen=True``). ``mode`` records which of ``trajectory``, ``cov`` or ``size`` is the
    active specification; the inactive fields are carried along unchanged.
    """

    # Active specification: one of "trajectory", "cov", or "size"
    mode: str
    # User-supplied shock trajectory, read when mode == "trajectory"
    trajectory: np.ndarray | None
    # User-supplied shock covariance matrix, read when mode == "cov"
    cov: np.ndarray | None
    # Shock size (scalar, per-shock array, or {shock name: size} dict), read when mode == "size"
    size: float | np.ndarray | dict[str, float] | None
    # Value of the user's ``orthogonalize_shocks`` flag
    orthogonalize: bool


def _validate_irf_shock_arguments(*values_with_names: tuple[str, Any]) -> None:
    """Ensure at most one of the provided options is non-None."""
    provided_names, provided_values = zip(*[(n, v) for n, v in values_with_names if v is not None], strict=False)
    if len(provided_names) > 1:
        names = ", ".join(n for n, _ in provided_names)
        raise ValueError(f"Only one of {names} may be specified, got {len(provided_names)}.")


def _is_diagonal(M: np.ndarray) -> bool:
    return np.allclose(M, np.diag(np.diag(M)))


def _get_selected_shock_names(spec: ShockSpec, shock_names: list[str]) -> list[str]:
    """
    If user passed a dict for shock_size, return only those shock names (in model order).

    Otherwise, return None to indicate all shocks should be used.
    """
    if spec.mode == "size" and isinstance(spec.size, dict):
        if len(spec.size) == 0:
            raise ValueError("Shock size cannot be empty.")

        unknown_shocks = set(spec.size) - set(shock_names)
        if unknown_shocks:
            raise ValueError(f"shock_size dict contains unknown shock names: {unknown_shocks}")

        return [name for name in shock_names if name in spec.size]

    return shock_names


def _infer_shocks_are_individual(
    requested: bool | None,
    shock_spec: ShockSpec,
    n_shocks: int,
) -> bool:
    # If the user specifically asked for individual shocks (or not), respect it
    if requested is not None:
        return requested

    if shock_spec.mode == "size":
        # scalar, dict, or diagonal vector are all treated as per-shock steps
        if isinstance(shock_spec.size, int | float | dict):
            return True

        arr = np.asarray(shock_spec.size)
        return arr.ndim == 0 or arr.shape == (n_shocks,)

    if shock_spec.mode == "cov":
        return _is_diagonal(np.asarray(shock_spec.cov))

    if shock_spec.mode == "trajectory":
        # If a full trajectory is given, default to a single combined IRF unless user asks otherwise
        return False

    return False


def _make_shock_spec(
    shock_size: float | np.ndarray | dict[str, float] | None,
    shock_cov: np.ndarray | None,
    shock_trajectory: np.ndarray | None,
    orthogonalize_shocks: bool,
) -> ShockSpec:
    """
    Validate the user's shock arguments and bundle them into a ShockSpec.

    The mode is "trajectory" if a trajectory was given, otherwise "cov" if a covariance matrix was given, otherwise
    "size" (which therefore also covers the case where ``shock_size`` is None).
    """
    _validate_irf_shock_arguments(
        ("shock_size", shock_size),
        ("shock_cov", shock_cov),
        ("shock_trajectory", shock_trajectory),
    )

    if shock_trajectory is not None:
        mode = "trajectory"
    elif shock_cov is not None:
        mode = "cov"
    else:
        mode = "size"

    return ShockSpec(
        mode=mode,
        trajectory=shock_trajectory,
        cov=shock_cov,
        size=shock_size,
        orthogonalize=orthogonalize_shocks,
    )


def _shock_vector_from_spec(
    size: float | np.ndarray | dict[str, float] | None,
    shock_names: Sequence[str],
) -> np.ndarray:
    """Return an (n_shocks,) vector for a one-period step size."""
    n = len(shock_names)
    if size is None:
        return np.ones(n)
    if isinstance(size, int | float):
        return np.full(n, float(size))
    if isinstance(size, dict):
        return np.array([float(size.get(name, 0.0)) for name in shock_names], dtype=float)
    arr = np.asarray(size, dtype=float)
    if arr.shape != (n,):
        raise ValueError(f"shock_size array must have shape ({n},); got {arr.shape}.")
    return arr


def _orthogonal_factor(cov: np.ndarray, make_unit_variance: bool = False) -> np.ndarray:
    """
    Compute L such that z ~ N(0, I), e = L @ z has Cov(e)=cov.

    If make_unit_variance=True, return L' whose columns are scaled to unit variance (orthonormal shocks).
    """
    L = np.linalg.cholesky(cov)  # lower-triangular
    if not make_unit_variance:
        return L

    col_norms = np.linalg.norm(L, axis=0)
    col_norms[col_norms == 0] = 1.0
    return L / col_norms


def _build_trajectory(
    spec: ShockSpec,
    simulation_length: int,
    n_shocks: int,
    shock_names: Sequence[str],
    rng: np.random.Generator,
) -> np.ndarray:
    """Convert a ShockSpec into a (simulation_length, n_shocks) shock trajectory."""
    match spec.mode:
        case "trajectory":
            traj = np.asarray(spec.trajectory, dtype=float)
            if traj.ndim != 2 or traj.shape[1] != n_shocks:
                raise ValueError(f"shock_trajectory must have shape (T, {n_shocks}); got {traj.shape}.")

        case "cov":
            traj = np.zeros((simulation_length, n_shocks), dtype=float)

            Q = np.asarray(spec.cov, dtype=float)
            if Q.shape != (n_shocks, n_shocks):
                raise ValueError(f"shock_cov must be ({n_shocks}, {n_shocks}); got {Q.shape}.")
            L = _orthogonal_factor(Q, make_unit_variance=False) if spec.orthogonalize else np.linalg.cholesky(Q)
            e0 = rng.standard_normal(n_shocks)
            traj[0] = L @ e0

        case "size":
            traj = np.zeros((simulation_length, n_shocks), dtype=float)
            shock_size = _shock_vector_from_spec(spec.size, shock_names)
            traj[0] = shock_size

        case _:
            raise RuntimeError(f"Unexpected ShockSpec mode: {spec.mode}. You shouldn't get here, please report a bug.")

    return traj


def _simulate_linear_system(T: np.ndarray, R: np.ndarray, shock_traj: np.ndarray) -> np.ndarray:
    """Simulate a linear system :math:`x_t = T x_{t-1} + R e_t`, given a shock trajectory :math:`e_t`."""
    T = np.asarray(T)
    R = np.asarray(R)
    T_len, n_shocks = shock_traj.shape
    n_vars = T.shape[0]

    out = np.zeros((T_len, n_vars), dtype=float)
    for t in range(1, T_len):
        out[t] = T @ out[t - 1] + R @ shock_traj[t - 1]
    return out


def _irf_to_xarray(
    data: np.ndarray,
    variable_names: list[str],
    shock_names: list[str] | None,
) -> xr.DataArray:
    """
    Wrap impulse response data in a labelled DataArray.

    If ``shock_names`` is None, ``data`` is labelled with dimensions ("time", "variable"). Otherwise a leading "shock"
    dimension is expected and the dimensions are ("shock", "time", "variable"). The time coordinate is an integer
    range along the time axis.
    """
    has_shock_axis = shock_names is not None
    time_axis = 1 if has_shock_axis else 0

    dims = ["time", "variable"]
    coords = {}
    if has_shock_axis:
        dims = ["shock", *dims]
        coords["shock"] = list(shock_names)

    coords["time"] = np.arange(data.shape[time_axis])
    coords["variable"] = list(variable_names)

    return xr.DataArray(data, dims=dims, coords=coords)


def impulse_response_function(
    model: Model,
    T: np.ndarray | None = None,
    R: np.ndarray | None = None,
    simulation_length: int = 40,
    shock_size: float | np.ndarray | dict[str, float] | None = None,
    shock_cov: np.ndarray | None = None,
    shock_trajectory: np.ndarray | None = None,
    return_individual_shocks: bool | None = None,
    orthogonalize_shocks: bool = False,
    random_seed: int | np.random.RandomState | None = None,
    **solve_model_kwargs,
) -> xr.DataArray:
    """
    Generate impulse response functions (IRF) from state space model dynamics.

    An impulse response function represents the dynamic response of the linearized model to a shock applied to the
    system. The shock is described by a shock trajectory, which is either given directly or built from a shock size or
    a shock covariance matrix, and is then propagated through the solved system.

    Parameters
    ----------
    model: Model
        DSGE Model object
    T: np.ndarray, optional
        Transition matrix of the solved system. If None, this will be computed using the model's ``solve_model``
        method.
    R: np.ndarray, optional
        Selection matrix of the solved system. If None, this will be computed using the model's ``solve_model`` method.
    simulation_length : int, optional
        The number of periods to compute the IRFs over. The default is 40. Ignored if ``shock_trajectory`` is given,
        in which case the length of the trajectory is used.
    shock_size : float, array, or dict; default=None
        The size of the one-period impulse given to the shocks in the first row of the shock trajectory:

            - If float, every shock receives an impulse of this size.
            - If array, one impulse size per shock. The length of the array must match the number of shocks in the
              model.
            - If dictionary, impulse sizes keyed by shock name. Shocks that are not specified will be set to zero,
              and when individual shocks are returned only the shocks named in the dictionary are included.

        Only one of `shock_cov`, `shock_size`, or `shock_trajectory` can be specified.
    shock_cov : Optional[np.ndarray], default=None
        A user-specified covariance matrix for the shocks. It should be a 2D numpy array with
        dimensions (n_shocks, n_shocks), where n_shocks is the number of shocks in the model. The impulse is a single
        random draw with this covariance.

        Only one of `shock_cov`, `shock_size`, or `shock_trajectory` can be specified.
    shock_trajectory : Optional[np.ndarray], default=None
        A pre-defined trajectory of shocks applied to the system. It should be a 2D numpy array
        with dimensions (n, n_shocks), where n is the number of time steps and n_shocks is the
        number of shocks in the model.

        Only one of `shock_cov`, `shock_size`, or `shock_trajectory` can be specified.
    return_individual_shocks: bool, optional
        If True, an IRF will be computed separately for each shock in the model. An additional dimension will be added
        to the output DataArray to show each shock.

        If not specified, this will be set to True if ``shock_size`` is a scalar, a dictionary, or a vector with one
        entry per shock, or if ``shock_cov`` is diagonal. It will be set to False if ``shock_trajectory`` is given.
    orthogonalize_shocks : bool, default=False
        Only consulted when ``shock_cov`` is used; stored on the shock specification and passed on to the trajectory
        builder.
    random_seed : int, RandomState or Generator, optional
        Seed for the random number generator. Passed to ``numpy.random.default_rng``.
    **solve_model_kwargs
        Arguments forwarded to the ``solve_model`` method. Ignored if T and R are provided.

    Returns
    -------
    xr.DataArray
        The IRFs for each variable in the model, with dimensions ("shock", "time", "variable") when shocks are applied
        individually, and ("time", "variable") otherwise.
    """
    rng = np.random.default_rng(random_seed)
    T, R = _maybe_solve_model(model, T, R, **solve_model_kwargs)

    spec = _make_shock_spec(shock_size, shock_cov, shock_trajectory, orthogonalize_shocks)

    variable_names = [x.base_name for x in model.variables]
    shock_names = [x.base_name for x in model.shocks]
    selected_shock_names = _get_selected_shock_names(spec, shock_names)

    n_vars = len(variable_names)
    n_shocks = len(shock_names)
    n_selected_shocks = len(selected_shock_names)

    shock_idxs = [i for i, name in enumerate(shock_names) if name in selected_shock_names]

    if spec.mode == "trajectory":
        # A user-supplied trajectory dictates the length of the simulation
        simulation_length = np.asarray(spec.trajectory).shape[0]

    apply_shocks_individually = _infer_shocks_are_individual(return_individual_shocks, spec, n_selected_shocks)

    if apply_shocks_individually:
        if spec.mode == "trajectory":
            base = np.asarray(spec.trajectory, dtype=float)
            if base.shape[1] != n_shocks:
                raise ValueError(f"shock_trajectory must have n_shocks={n_shocks}.")
        else:
            # Regardless of how many shocks were chosen by the user, we normalize to a full trajectory for everything
            # first, then subset only what was requested.
            base = _build_trajectory(spec, simulation_length, n_shocks, shock_names, rng)

        # Build and simulate one IRF per selected shock by zeroing out every other column of the trajectory
        data = np.zeros((n_selected_shocks, simulation_length, n_vars), dtype=float)
        for i, idx in enumerate(shock_idxs):
            traj = np.zeros_like(base)
            traj[:, idx] = base[:, idx]
            data[i] = _simulate_linear_system(T, R, traj)

        return _irf_to_xarray(data, variable_names, shock_names=selected_shock_names)

    # The trajectory has one column per model shock, so it must be built from the full list of shock names. Using
    # only the selected names would give an impulse vector of the wrong length whenever shock_size is a dict that
    # covers a subset of the shocks.
    traj = _build_trajectory(spec, simulation_length, n_shocks, shock_names, rng)
    data = _simulate_linear_system(T, R, traj)
    return _irf_to_xarray(data, variable_names, shock_names=None)


def simulate(
    model: Model,
    T: np.ndarray | None = None,
    R: np.ndarray | None = None,
    n_simulations: int = 1,
    simulation_length: int = 40,
    shock_std_dict: dict[str, float] | None = None,
    shock_cov_matrix: np.ndarray | None = None,
    shock_std: np.ndarray | list | float | np.ndarray = None,
    random_seed: int | np.random.RandomState | None = None,
    **solve_model_kwargs,
) -> xr.DataArray:
    """
    Simulate the linearized model over a certain number of time periods.

    Shocks are drawn from a multivariate normal distribution with mean zero and covariance Q, where Q is built from
    exactly one of ``shock_std_dict``, ``shock_cov_matrix``, or ``shock_std``. Every trajectory starts at zero in
    period 0, and the state in period ``t`` is computed from the state and the shock draw of period ``t - 1``.

    Parameters
    ----------
    model: Model
        DSGE Model object
    T: np.ndarray, optional
        Transition matrix of the solved system. If None, this will be computed using the model's ``solve_model``
        method.
    R: np.ndarray, optional
        Selection matrix of the solved system. If None, this will be computed using the model's ``solve_model`` method.
    n_simulations : int, optional
        Number of trajectories to simulate. Default is 1.
    simulation_length : int, optional
        Length of each simulated trajectory. Default is 40.
    shock_std_dict: dict, optional
        Dictionary of shock names and standard deviations to be used to build Q
    shock_cov_matrix: array, optional
        An (n_shocks, n_shocks) covariance matrix describing the exogenous shocks
    shock_std: float or sequence, optional
        Standard deviation of all model shocks.
    random_seed : int, RandomState or Generator, optional
        Seed for the random number generator. Passed to ``numpy.random.default_rng``.
    **solve_model_kwargs
        Arguments forwarded to the ``solve_model`` method. Ignored if T and R are provided.

    Returns
    -------
    xr.DataArray
        Simulated trajectories, with dimensions ("simulation", "time", "variable").
    """
    rng = np.random.default_rng(random_seed)
    model_shocks = model.shocks
    shock_options = {
        "shock_std_dict": shock_std_dict,
        "shock_cov_matrix": shock_cov_matrix,
        "shock_std": shock_std,
    }

    _validate_shock_options(shocks=model_shocks, **shock_options)
    Q = build_Q_matrix(model_shocks=model_shocks, **shock_options)

    # One shock vector per (simulation, period)
    shock_draws = rng.multivariate_normal(
        mean=np.zeros(len(model_shocks)),
        cov=Q,
        size=(n_simulations, simulation_length),
        method="svd",
    )

    trajectories = np.zeros((n_simulations, simulation_length, len(model.variables)))
    T, R = _maybe_solve_model(model, T, R, **solve_model_kwargs)

    # Each step advances every simulation at once; the einsum calls apply T and R to each simulation's vector
    for prev in range(simulation_length - 1):
        from_state = np.einsum("nm,sm->sn", T, trajectories[:, prev, :])
        from_shock = np.einsum("nk,sk->sn", R, shock_draws[:, prev, :])
        trajectories[:, prev + 1, :] = from_state + from_shock

    coords = {
        "variable": [x.base_name for x in model.variables],
        "simulation": np.arange(n_simulations),
        "time": np.arange(simulation_length),
    }
    return xr.DataArray(trajectories, dims=["simulation", "time", "variable"], coords=coords)


def matrix_to_dataframe(
    matrix,
    model,
    dim1: str | None = None,
    dim2: str | None = None,
    round: None | int = None,
) -> pd.DataFrame:
    """
    Convert a matrix to a DataFrame with variable names as columns and rows.

    Parameters
    ----------
    matrix: np.ndarray
        DSGE matrix to convert to a DataFrame. Each dimension should have shape n_variables or n_shocks.
    model: Model
        DSGE model object
    dim1: str, Optional
        Name of the first dimension of the matrix. Must be one of "variable", "equation",  or "shock". If None, the
        function will guess based on the shape of the matrix. In the event that the model has exactly as many
        variables as shocks, it will guess "variable", so be careful!
    dim2: str, Optional
        Name of the second dimension of the matrix. Must be one of "variable", "equation", or "shock". If None, the
        function will guess based on the shape of the matrix.
    round: int, Optional
        Number of decimal places to round the values in the DataFrame. If None, values will not be rounded.

    Returns
    -------
    pd.DataFrame
        DataFrame with variable names as columns and rows.
    """
    var_names = [x.base_name for x in model.variables]
    shock_names = [x.base_name for x in model.shocks]
    equation_names = [f"Equation {i}" for i in range(len(model.equations))]

    coords = {"variable": var_names, "shock": shock_names, "equation": equation_names}

    n_variables = len(var_names)
    n_shocks = len(shock_names)

    if matrix.ndim != 2:
        raise ValueError("Matrix must be 2-dimensional")

    for i, ordinal in enumerate(["First", "Secoond"]):
        if matrix.shape[i] not in [n_variables, n_shocks]:
            raise ValueError(
                f"{ordinal} dimension of the matrix must match the number of variables or shocks in the model"
            )

    if dim1 is None:
        dim1 = "variable" if matrix.shape[0] == n_variables else "shock"
    if dim2 is None:
        dim2 = "variable" if matrix.shape[1] == n_variables else "shock"

    df = pd.DataFrame(
        matrix,
        index=coords[dim1],
        columns=coords[dim2],
    )

    if round is not None:
        return df.round(round)

    return df


def check_steady_state(
    model: Model,
    stead_state: SteadyStateResults | None = None,
    steady_state_kwargs: dict | None = None,
    **parameter_updates,
) -> None:
    """
    Log whether a steady state solves the model and, if it does not, which equations have non-zero residuals.

    Parameters
    ----------
    model: Model
        DSGE model whose steady state is checked.
    stead_state: SteadyStateResults, optional
        Steady state to check. It is handed to ``_maybe_solve_steady_state``; presumably a steady state is computed
        from the model when this is None (confirm against that helper).
    steady_state_kwargs: dict, optional
        Options handed to ``_maybe_solve_steady_state``. An empty dict is used if None.
    **parameter_updates
        Parameter values handed to ``_maybe_solve_steady_state`` and to ``model.parameters`` when the residuals are
        evaluated.

    Returns
    -------
    None
        All results are reported through the module logger at WARNING level.
    """
    if steady_state_kwargs is None:
        steady_state_kwargs = {}

    ss_dict = _maybe_solve_steady_state(model, stead_state, steady_state_kwargs, parameter_updates)
    if ss_dict.success:
        _log.warning("Steady state successfully found!")
        return

    # The steady state was not flagged as successful: evaluate the residuals at the returned values
    parameters = model.parameters(**parameter_updates)
    residuals = model.f_ss_resid(**ss_dict, **parameters)
    _log.warning("Steady state NOT successful. The following equations have non-zero residuals:")

    # Residuals with absolute value at or below this threshold are treated as zero and not reported
    FLOAT_ZERO = 1e-8
    for resid, eq in zip(residuals, model.equations, strict=False):
        if np.abs(resid) > FLOAT_ZERO:
            _log.warning(eq)
            _log.warning(f"Residual: {resid:0.4f}")
