"""
Advanced analytical functions for time series analysis.

This module contains functions that extend the capabilities of the
``analysis3054`` package beyond plotting and basic machine‑learning
forecasts.  These functions leverage concepts from physics,
mathematics and finance to provide additional analytical tools that
can be applied to energy market data or other time series.

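Currently included are:

* :func:`harmonic_forecast` – A Fourier/harmonic regression model for
  forecasting periodic time series, suitable for capturing cyclical
  patterns such as seasonal demand or supply fluctuations.  It fits
  sine and cosine harmonics to each series and projects the pattern
  into the future.

* :func:`ewma_volatility` – Compute the exponentially weighted moving
  average (EWMA) volatility of each numeric series, a common
  technique in finance for estimating the volatility of asset returns.

* :func:`monte_carlo_simulation` – Simulate future price paths with
  geometric Brownian motion or an Ornstein–Uhlenbeck mean‑reverting
  process applied to log prices.

* :func:`value_at_risk` – Parametric or historical Value at Risk and
  Conditional Value at Risk (expected shortfall) for each series.

* :func:`cointegration_test` – Engle–Granger cointegration test between
  two series, returning the regression residual spread and its z‑score.

* :func:`spectral_density_plot` – Plot the power spectral density of
  one or more series to reveal dominant cycles.

Additional functions may be added in the future to support more
sophisticated analyses.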
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import pandas as pd


@dataclass
class HarmonicForecastResult:
    """Result container for :func:`harmonic_forecast`.

    Attributes
    ----------
    forecasts : pandas.DataFrame
        A DataFrame containing the forecast values for each numeric
        column in the input DataFrame.  The index corresponds to the
        forecast dates and the columns match those of the input (minus
        the date column).
    coefficients : Dict[str, np.ndarray]
        A mapping from column name to the array of fitted regression
        coefficients (intercept, optional trend and sine/cosine terms).
    """

    forecasts: pd.DataFrame
    coefficients: Dict[str, np.ndarray]


def harmonic_forecast(
    date: Union[str, pd.Series, Iterable],
    df: pd.DataFrame,
    *,
    periods: int = 12,
    n_harmonics: int = 3,
    include_trend: bool = True,
    freq: Optional[str] = None,
) -> HarmonicForecastResult:
    """Forecast future values of each numeric series using harmonic regression.

    Harmonic (Fourier) regression models a time series as a sum of
    sine and cosine terms at integer multiples of a base period, plus
    optional linear trend and intercept components.  This method is
    particularly well suited to periodic phenomena, such as seasonal
    demand cycles.  The frequency (period) of the series is either
    inferred from the input dates or specified directly via ``freq``.

    Parameters
    ----------
    date : Union[str, pandas.Series, Iterable]
        Column name, series, or iterable containing date information for
        each observation.  If a string is provided, it must be a column
        in ``df``.  Dates are converted to pandas ``datetime64``.
    df : pandas.DataFrame
        DataFrame containing the data columns.  Numeric columns will
        each be modelled independently.  Non‑numeric columns (other
        than ``date``) are ignored.
    periods : int, default 12
        Number of future periods to forecast for each series.  The
        forecast horizon is expressed in the same frequency as the
        input data.  For example, if the data are weekly, a value of
        12 produces forecasts 12 weeks ahead.
    n_harmonics : int, default 3
        Number of harmonic pairs (sine and cosine) to include in the
        regression model.  Increasing this value allows the model to
        capture more complex periodic behaviour, at the cost of more
        parameters.
    include_trend : bool, default True
        Whether to include a linear trend term in addition to the
        periodic components.  A trend term can account for slow upward
        or downward drift in the series.
    freq : str or None, default None
        Optional string specifying the frequency of the input data.  If
        ``None``, an attempt is made to infer the frequency using
        ``pandas.infer_freq``.  Recognised frequencies (e.g.
        ``'W'``, ``'M'``) are used to determine the base period for the
        harmonics and to generate forecast dates.  If the frequency
        cannot be inferred, the period is set to the length of the
        series and future dates are generated by adding the mode of
        observed date differences.

    Returns
    -------
    HarmonicForecastResult
        A dataclass containing the forecast DataFrame and the
        coefficients for each modelled series.

    Notes
    -----
    Harmonic regression is a linear technique: the fitted coefficients
    are obtained via least squares.  While flexible and interpretable,
    it assumes the underlying periodicity is constant over time.  In
    practice, make sure the data exhibit reasonably consistent cycles
    before relying heavily on the forecasts.
    """
    # Determine the date series
    if isinstance(date, str):
        if date not in df.columns:
            raise KeyError(f"Date column '{date}' not found in DataFrame")
        date_series = df[date]
    else:
        date_series = pd.Series(date)
    dt = pd.to_datetime(date_series)
    if dt.empty:
        raise ValueError("Date series is empty")

    # Identify numeric columns (exclude the date column)
    numeric_cols: List[str] = []
    for col in df.columns:
        if col == date:
            continue
        if pd.api.types.is_numeric_dtype(df[col]):
            numeric_cols.append(col)
    if not numeric_cols:
        raise ValueError("No numeric data columns found for harmonic forecast")

    # Attempt to infer frequency if not provided
    inferred_freq: Optional[str] = None
    if freq is None:
        try:
            inferred_freq = pd.infer_freq(dt)
        except Exception:
            inferred_freq = None
    else:
        inferred_freq = freq
    # Determine the base period P for the harmonics
    def _period_from_freq(f: Optional[str], series_len: int) -> int:
        # Map pandas frequency strings to number of observations per cycle
        if f is None:
            return series_len
        # Normalize to uppercase and remove any numerical prefixes (e.g. '2W')
        import re
        match = re.match(r"(?i)(\d+)?([A-Za-z]+)", f)
        freq_str = match.group(2).upper() if match else f.upper()
        if freq_str.startswith('W'):
            return 52
        if freq_str.startswith('M'):
            return 12
        if freq_str.startswith('Q'):
            return 4
        if freq_str.startswith('A') or freq_str.startswith('Y'):
            return 1
        # Default fallback
        return series_len
    series_len = len(dt)
    P = _period_from_freq(inferred_freq, series_len)

    # Prepare future dates for forecasting
    # Determine effective frequency for date_range: if inferred_freq is None,
    # attempt to use the mode of date differences
    if inferred_freq is not None:
        try:
            # The start of forecast range is the next period after the last observed date
            future_index = pd.date_range(start=dt.iloc[-1], periods=periods + 1, freq=inferred_freq)[1:]
        except Exception:
            inferred_freq = None
            future_index = None
    if inferred_freq is None:
        # Fallback: compute the most common difference between successive dates
        diffs = dt.diff().dropna()
        if not diffs.empty:
            delta = diffs.mode()[0]
        else:
            delta = pd.Timedelta(days=1)
        future_index = pd.to_datetime([dt.iloc[-1] + delta * (i + 1) for i in range(periods)])

    # Storage for forecasts and coefficients
    forecasts_data: Dict[str, np.ndarray] = {}
    coefficients: Dict[str, np.ndarray] = {}

    # Build the design matrix for the observed data
    t = np.arange(series_len, dtype=float)
    # Always include intercept
    base_cols: List[np.ndarray] = [np.ones(series_len)]
    if include_trend:
        base_cols.append(t)
    for k in range(1, n_harmonics + 1):
        base_cols.append(np.sin(2 * np.pi * k * t / P))
        base_cols.append(np.cos(2 * np.pi * k * t / P))
    X = np.column_stack(base_cols)
    # Construct design matrix for future periods
    t_future = np.arange(series_len, series_len + periods, dtype=float)
    base_future_cols: List[np.ndarray] = [np.ones_like(t_future)]
    if include_trend:
        base_future_cols.append(t_future)
    for k in range(1, n_harmonics + 1):
        base_future_cols.append(np.sin(2 * np.pi * k * t_future / P))
        base_future_cols.append(np.cos(2 * np.pi * k * t_future / P))
    X_future = np.column_stack(base_future_cols)

    for col in numeric_cols:
        y = df[col].astype(float).values
        # Handle missing values via simple forward fill/backfill
        if np.any(pd.isna(y)):
            # Forward/backward fill missing values to avoid NaNs in regression
            y_series = pd.Series(y).ffill().bfill()
            y = y_series.values
        # Solve for coefficients via least squares
        beta, *_ = np.linalg.lstsq(X, y, rcond=None)
        coefficients[col] = beta
        # Forecast future values
        y_future = X_future @ beta
        forecasts_data[col] = y_future

    forecasts_df = pd.DataFrame(forecasts_data, index=future_index)
    return HarmonicForecastResult(forecasts=forecasts_df, coefficients=coefficients)


@dataclass
class EwmaVolatilityResult:
    """Result container for :func:`ewma_volatility`.

    Attributes
    ----------
    volatility : pandas.DataFrame
        DataFrame of the EWMA volatility for each numeric input column.
        The index matches the input dates; initial rows may be NaN
        because the first return is undefined after differencing.
    returns : pandas.DataFrame
        DataFrame of the returns used to compute the volatility.  The
        first row for each column is NaN because of differencing.
    """

    volatility: pd.DataFrame
    returns: pd.DataFrame


def ewma_volatility(
    date: Union[str, pd.Series, Iterable],
    df: pd.DataFrame,
    *,
    span: int = 20,
    return_type: str = 'pct',
    annualize_factor: Optional[float] = None,
) -> EwmaVolatilityResult:
    """Compute the exponentially weighted moving average (EWMA) volatility.

    This function calculates the EWMA volatility of returns for each
    numeric column in the input DataFrame.  Returns can be either
    percentage changes or log differences.  The result is useful for
    estimating market risk or variability in energy series, analogous
    to volatility calculations in finance.

    Parameters
    ----------
    date : Union[str, pandas.Series, Iterable]
        Column name, series, or iterable containing date information.
        Used solely to index the resulting DataFrames.  If a string is
        provided, it must be a column in ``df``.  Otherwise it must have
        one entry per row of ``df``.
    df : pandas.DataFrame
        DataFrame containing the data columns.  Numeric columns will
        each be processed independently.  Non‑numeric columns (other
        than ``date``) are ignored.
    span : int, default 20
        Span parameter for the exponential weighting.  Higher values
        assign more weight to older observations and produce smoother
        volatility estimates.  See ``pandas.Series.ewm`` documentation
        for details.
    return_type : {'pct', 'log'}, default 'pct'
        Type of return calculation: 'pct' uses percentage change
        (``(x_t / x_{t-1}) - 1``); 'log' uses the difference of natural
        logarithms (``log(x_t) - log(x_{t-1})``).  Log returns are
        often preferred when values span several orders of magnitude.
    annualize_factor : float or None, default None
        If provided, the resulting volatility will be multiplied by
        ``sqrt(annualize_factor)`` to convert it to an annualised
        measure.  For example, use 252 for daily trading data or 52
        for weekly data.  If ``None``, the volatility is reported at
        the same frequency as the input data.

    Returns
    -------
    EwmaVolatilityResult
        A dataclass containing the volatility and returns DataFrames.

    Raises
    ------
    KeyError
        If ``date`` is a string that is not a column of ``df``.
    ValueError
        If the dates are empty or differ in length from ``df``, if
        ``df`` has no numeric columns, or if ``return_type`` is invalid.
    """
    # Resolve the date information into a datetime Series.
    if isinstance(date, str):
        if date not in df.columns:
            raise KeyError(f"Date column '{date}' not found in DataFrame")
        date_series = df[date]
    else:
        date_series = pd.Series(date)
    dt = pd.to_datetime(date_series)
    if dt.empty:
        raise ValueError("Date series is empty")
    if len(dt) != len(df):
        raise ValueError(
            f"Date information has {len(dt)} entries but df has {len(df)} rows"
        )

    # Only a named date column can be excluded by label; comparing labels
    # against a Series/array would be element-wise and ambiguous in `if`.
    numeric_cols: List[str] = [
        col
        for col in df.columns
        if not (isinstance(date, str) and col == date)
        and pd.api.types.is_numeric_dtype(df[col])
    ]
    if not numeric_cols:
        raise ValueError("No numeric data columns found for EWMA volatility")
    if return_type not in ('pct', 'log'):
        raise ValueError("return_type must be 'pct' or 'log'")

    returns_data: Dict[str, np.ndarray] = {}
    volatility_data: Dict[str, np.ndarray] = {}
    for col in numeric_cols:
        # Fill gaps first so a missing value does not propagate NaN returns.
        prices = df[col].astype(float).ffill().bfill()
        ret = np.log(prices).diff() if return_type == 'log' else prices.pct_change()
        vol = ret.ewm(span=span, adjust=False).std()
        if annualize_factor is not None:
            vol = vol * np.sqrt(annualize_factor)
        returns_data[col] = ret.to_numpy()
        volatility_data[col] = vol.to_numpy()

    returns_df = pd.DataFrame(returns_data, index=dt)
    volatility_df = pd.DataFrame(volatility_data, index=dt)
    return EwmaVolatilityResult(volatility=volatility_df, returns=returns_df)


# -----------------------------------------------------------------------------
# Monte Carlo simulation for price paths
# -----------------------------------------------------------------------------

@dataclass
class MonteCarloSimulationResult:
    """Result container for :func:`monte_carlo_simulation`.

    Attributes
    ----------
    paths : Dict[str, pd.DataFrame]
        A mapping from each numeric input series to a DataFrame of
        simulated price paths.  The index corresponds to the forecast
        dates and each column corresponds to one simulation, labelled
        ``sim_0`` to ``sim_{n_simulations - 1}``.
    """

    paths: Dict[str, pd.DataFrame]


def monte_carlo_simulation(
    date: Union[str, pd.Series, Iterable],
    df: pd.DataFrame,
    *,
    periods: int = 20,
    n_simulations: int = 100,
    process: str = 'gbm',
    drift: Optional[float] = None,
    volatility: Optional[float] = None,
    mean_reversion_speed: float = 0.3,
    mean_reversion_level: Optional[float] = None,
    random_state: Optional[int] = None,
) -> MonteCarloSimulationResult:
    """Simulate future price paths using Monte Carlo techniques.

    This function generates random future price paths for each numeric
    series in ``df`` based on either a geometric Brownian motion (GBM)
    model or an Ornstein–Uhlenbeck (OU) mean‑reverting process.  Both
    models assume continuous compounding of returns and independent
    normally distributed shocks.  When drift and volatility are not
    supplied, they are estimated from historical log returns of each
    series.

    Parameters
    ----------
    date : Union[str, pandas.Series, Iterable]
        Column name, series, or iterable containing date information.
        Used to determine the end date of the historical data and to
        generate a sequence of future dates.  If a string is provided
        it must be a column in ``df``.
    df : pandas.DataFrame
        DataFrame containing the data columns.  Numeric columns will
        each be simulated independently.  Non‑numeric columns (other
        than ``date``) are ignored.
    periods : int, default 20
        Number of time steps to simulate into the future.  The time
        increment between steps is one observation interval of the
        input data.
    n_simulations : int, default 100
        Number of independent simulation paths to generate for each
        series.
    process : {'gbm', 'ou'}, default 'gbm'
        Type of stochastic process to use (case-insensitive).  ``'gbm'``
        denotes geometric Brownian motion.  ``'ou'`` denotes an
        Ornstein–Uhlenbeck mean‑reverting process applied to the log
        prices.
    drift : float or None, default None
        Per-period drift ``mu`` (expected arithmetic return) of the GBM
        process.  If ``None``, it is estimated as the mean historical log
        return plus ``0.5 * sigma**2``.  This way the simulated log
        increments have the same mean as the historical log returns.
        Not used by the OU process.
    volatility : float or None, default None
        Per-period volatility ``sigma`` of the log price.  If ``None``,
        it is estimated as the sample standard deviation of historical
        log returns for each series.
    mean_reversion_speed : float, default 0.3
        Speed of mean reversion (kappa) for the OU process.  Higher
        values force the series to revert more quickly toward the long‑run
        mean.
    mean_reversion_level : float or None, default None
        Long‑run mean level (theta) for the OU process, expressed in
        **log-price** units.  If ``None``, it is set to the mean of the
        historical log prices for each series.  Only used when
        ``process='ou'``.
    random_state : int or None, default None
        Optional seed for the random number generator to ensure
        reproducibility.

    Returns
    -------
    MonteCarloSimulationResult
        A dataclass containing the simulated paths for each numeric
        series.  For each series the result is a DataFrame indexed by
        future dates, with one column per simulation.

    Raises
    ------
    KeyError
        If ``date`` is a string that is not a column of ``df``.
    ValueError
        If the dates are empty, ``df`` has no numeric columns, or
        ``process`` is not ``'gbm'``/``'ou'``.

    Notes
    -----
    Time is measured in observation intervals, i.e. ``dt = 1``.

    *Geometric Brownian motion*: prices update as
    ``S_{t+1} = S_t * exp((mu - 0.5 * sigma^2) + sigma * Z)``,
    where ``Z`` is a standard normal random variable.

    *Ornstein–Uhlenbeck*: the log price ``X_t = log(S_t)`` updates as
    ``X_{t+1} = X_t + kappa*(theta - X_t) + sigma*Z``.  The simulated log
    prices are exponentiated to obtain prices.
    """
    # Resolve the date information into a datetime Series.
    if isinstance(date, str):
        if date not in df.columns:
            raise KeyError(f"Date column '{date}' not found in DataFrame")
        date_series = df[date]
    else:
        date_series = pd.Series(date)
    dt_idx = pd.to_datetime(date_series)
    if dt_idx.empty:
        raise ValueError("Date series is empty")

    numeric_cols: List[str] = [
        col
        for col in df.columns
        if not (isinstance(date, str) and col == date)
        and pd.api.types.is_numeric_dtype(df[col])
    ]
    if not numeric_cols:
        raise ValueError("No numeric data columns found for simulation")

    process_key = process.lower()
    if process_key not in ('gbm', 'ou'):
        raise ValueError("process must be 'gbm' or 'ou'")

    # Future dates: use the inferred frequency, else repeat the median gap.
    try:
        freq = pd.infer_freq(dt_idx)
    except (ValueError, TypeError):
        freq = None
    future_dates = None
    if freq is not None:
        try:
            future_dates = pd.date_range(
                start=dt_idx.iloc[-1], periods=periods + 1, freq=freq
            )[1:]
        except (ValueError, TypeError):
            future_dates = None
    if future_dates is None:
        diffs = dt_idx.diff().dropna()
        delta = diffs.median() if not diffs.empty else pd.Timedelta(days=1)
        future_dates = pd.to_datetime(
            [dt_idx.iloc[-1] + delta * (i + 1) for i in range(periods)]
        )

    rng = np.random.default_rng(random_state)
    sim_labels = [f"sim_{i}" for i in range(n_simulations)]
    paths: Dict[str, pd.DataFrame] = {}

    for col in numeric_cols:
        prices = df[col].astype(float).ffill().bfill()
        last_price = prices.iloc[-1]
        log_prices = np.log(prices)
        log_returns = log_prices.diff().dropna()
        sigma_est = log_returns.std() if volatility is None else volatility
        if drift is None:
            # Under GBM the mean log return equals mu - sigma^2/2; add the
            # Ito term back so it is not subtracted twice below.
            mu_est = log_returns.mean() + 0.5 * sigma_est ** 2
        else:
            mu_est = drift
        theta_est = log_prices.mean() if mean_reversion_level is None else mean_reversion_level

        # Draw in (simulation, step) order -- the same stream order as
        # simulating one path at a time -- then lay out as (step, simulation).
        shocks = rng.standard_normal((n_simulations, periods)).T

        if process_key == 'gbm':
            increments = (mu_est - 0.5 * sigma_est ** 2) + sigma_est * shocks
            sims = last_price * np.exp(np.cumsum(increments, axis=0))
        else:
            log_paths = np.empty((periods, n_simulations), dtype=float)
            current = np.full(n_simulations, np.log(last_price))
            for step in range(periods):
                current = current + mean_reversion_speed * (theta_est - current) + sigma_est * shocks[step]
                log_paths[step] = current
            sims = np.exp(log_paths)

        paths[col] = pd.DataFrame(sims, index=future_dates, columns=sim_labels)
    return MonteCarloSimulationResult(paths=paths)


# -----------------------------------------------------------------------------
# Value at Risk (VaR) and Conditional VaR (CVaR)
# -----------------------------------------------------------------------------

@dataclass
class VaRResult:
    """Result container for :func:`value_at_risk`.

    Attributes
    ----------
    var : pandas.DataFrame
        DataFrame with the Value at Risk for each input series.  Rows
        correspond to the confidence level(s) provided.  Values are the
        negated return quantile, so a loss is reported as a positive
        number.
    cvar : pandas.DataFrame
        DataFrame with the Conditional Value at Risk (expected shortfall)
        for each input series, using the same sign convention as ``var``.
        Rows correspond to the confidence level(s) provided.
    method : str
        The method used to compute VaR and CVaR ('parametric' or
        'historical'), as passed by the caller.
    """

    var: pd.DataFrame
    cvar: pd.DataFrame
    method: str


def value_at_risk(
    date: Union[str, pd.Series, Iterable],
    df: pd.DataFrame,
    *,
    alpha: Union[float, List[float]] = 0.95,
    method: str = 'parametric',
    return_type: str = 'pct',
    horizon: int = 1,
) -> VaRResult:
    """Compute Value at Risk (VaR) and Conditional VaR (CVaR) for each series.

    VaR is a measure of the maximum expected loss over a given time
    horizon at a specified confidence level.  CVaR (also called
    Expected Shortfall) is the expected loss conditional on the loss
    exceeding the VaR threshold.  Both metrics are widely used by
    commodity traders and risk managers to quantify potential losses.

    Parameters
    ----------
    date : Union[str, pandas.Series, Iterable]
        Column name, series or iterable of dates.  Only used to exclude
        a named date column from the analysis; otherwise ignored.
    df : pandas.DataFrame
        DataFrame containing the numeric data columns.  Returns are
        computed for each column separately.  Non‑numeric columns (other
        than ``date``) are ignored.
    alpha : float or list of floats, default 0.95
        Confidence level(s) for the VaR calculation.  A value of 0.95
        corresponds to the 95 % confidence level.  Multiple levels can
        be provided for a table of results.
    method : {'parametric', 'historical'}, default 'parametric'
        Method to compute VaR and CVaR (case-insensitive).
        ``'parametric'`` assumes the return distribution is normal with
        mean and standard deviation estimated from the historical
        returns.  ``'historical'`` uses the empirical distribution of
        returns.
    return_type : {'pct', 'log'}, default 'pct'
        Type of return calculation: percentage change (``'pct'``) or
        log return (``'log'``).  Returns are assumed to be daily (or
        per period).  In the parametric case they are scaled by
        ``sqrt(horizon)``.
    horizon : int, default 1
        Number of periods over which to compute the VaR.  For example,
        with daily returns and ``horizon=5``, the function computes the
        five‑day VaR by scaling the mean by ``horizon`` and the standard
        deviation by ``sqrt(horizon)``.  Only used in the parametric
        method.

    Returns
    -------
    VaRResult
        A dataclass containing DataFrames of VaR and CVaR values for
        each input series and each confidence level.  Series with fewer
        than two observations yield NaN.

    Raises
    ------
    ValueError
        If ``df`` has no numeric columns, or ``method``/``return_type``
        is invalid.

    Notes
    -----
    In the parametric method, with ``z = Phi^{-1}(1 - alpha)``,
    ``VaR = -(mu * horizon + sigma * sqrt(horizon) * z)`` and
    ``CVaR = -(mu * horizon - sigma * sqrt(horizon) * phi(z) / (1 - alpha))``,
    where ``phi`` is the standard normal PDF.  In the historical method,
    VaR is the negated ``ceil((1 - alpha) * n)``-th smallest return, and
    CVaR is the negated average of returns less than or equal to that
    return.
    """
    import numbers

    numeric_cols: List[str] = [
        col
        for col in df.columns
        if not (isinstance(date, str) and col == date)
        and pd.api.types.is_numeric_dtype(df[col])
    ]
    if not numeric_cols:
        raise ValueError("No numeric data columns found for VaR")
    if return_type not in ('pct', 'log'):
        raise ValueError("return_type must be 'pct' or 'log'")
    method_key = method.lower()
    if method_key not in ('parametric', 'historical'):
        raise ValueError("method must be 'parametric' or 'historical'")

    # numbers.Real also covers NumPy scalar floats such as np.float32.
    alphas = [alpha] if isinstance(alpha, numbers.Real) else list(alpha)
    var_matrix = pd.DataFrame(index=alphas, columns=numeric_cols, dtype=float)
    cvar_matrix = pd.DataFrame(index=alphas, columns=numeric_cols, dtype=float)

    if method_key == 'parametric':
        from scipy.stats import norm

    for col in numeric_cols:
        prices = df[col].astype(float).ffill().bfill()
        if return_type == 'log':
            returns = np.log(prices).diff().dropna()
        else:
            returns = prices.pct_change().dropna()

        if method_key == 'parametric':
            mu_h = returns.mean() * horizon
            sigma_h = returns.std() * np.sqrt(horizon)
            for a in alphas:
                z = norm.ppf(1 - a)  # lower-tail quantile, negative for a > 0.5
                var_matrix.loc[a, col] = -(mu_h + sigma_h * z)
                # E[R | R <= mu + sigma*z] = mu - sigma*phi(z)/(1 - a); negate for a loss.
                cvar_matrix.loc[a, col] = -(mu_h - sigma_h * norm.pdf(z) / (1 - a))
        else:
            if returns.empty:
                continue  # no returns: leave the NaN placeholders
            sorted_returns = returns.sort_values()
            n = len(sorted_returns)
            for a in alphas:
                # Round before ceil so float noise such as
                # (1 - 0.95) * 100 == 5.000000000000004 does not skip an
                # order statistic.
                k = int(np.ceil(round((1 - a) * n, 9)))
                idx = max(0, min(k - 1, n - 1))
                threshold = sorted_returns.iloc[idx]
                tail = sorted_returns[sorted_returns <= threshold]
                var_matrix.loc[a, col] = -threshold
                cvar_matrix.loc[a, col] = -tail.mean()
    return VaRResult(var=var_matrix, cvar=cvar_matrix, method=method)


# -----------------------------------------------------------------------------
# Cointegration testing and spread analysis
# -----------------------------------------------------------------------------

@dataclass
class CointegrationTestResult:
    """Result container for :func:`cointegration_test`.

    Attributes
    ----------
    test_statistic : float
        The Engle–Granger test statistic for the null hypothesis of no
        cointegration.
    p_value : float
        The p‑value corresponding to the test statistic.
    critical_values : dict
        Critical values at the 1 %, 5 % and 10 % significance levels.
    residuals : pandas.Series
        The residuals from the cointegration regression (series1 on
        series2).
    z_score : pandas.Series
        The z‑score of the residuals, computed as
        ``(residual - residual.mean()) / residual.std()``.  Z‑scores can
        be used to identify deviations from the long‑run relationship for
        potential trading signals.
    """
    test_statistic: float
    p_value: float
    critical_values: Dict[str, float]
    residuals: pd.Series
    z_score: pd.Series


def cointegration_test(
    date: Union[str, pd.Series, Iterable],
    df: pd.DataFrame,
    series1: str,
    series2: str,
    *,
    significance: float = 0.05,
) -> CointegrationTestResult:
    """Perform an Engle–Granger cointegration test between two series.

    Cointegration analysis is fundamental for identifying pairs of assets
    (such as commodities) that share a stable long‑run relationship.  A
    statistically significant test indicates that deviations between the
    two series may be mean‑reverting, which can be exploited in
    spread‑trading strategies.

    Parameters
    ----------
    date : Union[str, pandas.Series, Iterable]
        Column name, series or iterable of dates.  Used to align the
        input series.  If a string is provided it must be a column in
        ``df``.
    df : pandas.DataFrame
        DataFrame containing the two series to test.  Both series must
        be numeric.
    series1 : str
        Name of the first series (dependent variable in the regression).
    series2 : str
        Name of the second series (independent variable).
    significance : float, default 0.05
        Significance level for interpreting the p‑value.  Results are
        returned regardless of this value; it is provided for
        convenience in determining whether the null hypothesis of no
        cointegration can be rejected.

    Returns
    -------
    CointegrationTestResult
        A dataclass containing the test statistic, p‑value, critical
        values, residuals and z‑score of the residuals.

    Notes
    -----
    The Engle–Granger test performs a regression of ``series1`` on
    ``series2`` and examines the residuals for stationarity using the
    augmented Dickey–Fuller (ADF) test.  If the residuals are
    stationary, the series are cointegrated.
    """
    if series1 not in df.columns or series2 not in df.columns:
        raise KeyError(f"Specified series '{series1}' or '{series2}' not found in DataFrame")
    if isinstance(date, str):
        if date not in df.columns:
            raise KeyError(f"Date column '{date}' not found in DataFrame")
        date_series = df[date]
    else:
        date_series = pd.Series(date)
    dt = pd.to_datetime(date_series)
    s1 = df[series1].astype(float)
    s2 = df[series2].astype(float)
    valid_mask = ~(s1.isna() | s2.isna())
    s1 = s1[valid_mask]
    s2 = s2[valid_mask]
    dt = dt[valid_mask]
    try:
        from statsmodels.tsa.stattools import coint
    except ImportError as e:
        raise ImportError("statsmodels is required for cointegration test") from e
    test_stat, p_value, crit_values = coint(s1, s2)
    import statsmodels.api as sm
    X = sm.add_constant(s2.values)
    model = sm.OLS(s1.values, X).fit()
    resid = pd.Series(model.resid, index=dt)
    z_score = (resid - resid.mean()) / resid.std(ddof=0)
    crit_map = {'1%': crit_values[0], '5%': crit_values[1], '10%': crit_values[2]}
    return CointegrationTestResult(
        test_statistic=test_stat,
        p_value=p_value,
        critical_values=crit_map,
        residuals=resid,
        z_score=z_score,
    )


# -----------------------------------------------------------------------------
# Spectral density plot using Fourier transform
# -----------------------------------------------------------------------------

def spectral_density_plot(
    date: Union[str, pd.Series, Iterable],
    df: pd.DataFrame,
    *,
    columns: Optional[List[str]] = None,
    sampling_rate: float = 1.0,
) -> 'go.Figure':
    """Compute and plot the spectral density of one or more time series.

    This function uses the Fourier transform (``scipy.signal.periodogram``)
    to estimate the power spectral density (PSD) of each selected series.
    The PSD reveals dominant cycles and periodicities in the data.  A
    Plotly figure is returned with one subplot per series.

    Parameters
    ----------
    date : Union[str, pandas.Series, Iterable]
        Column name, series or iterable of dates.  Dates are not used in
        the FFT calculation.  When ``date`` is a column name, that column
        is validated and excluded from the automatically selected columns.
    df : pandas.DataFrame
        DataFrame containing the data.  If ``columns`` is ``None``, all
        numeric columns (except the date column) are processed.
    columns : list of str or None, default None
        Specific columns to process.  If ``None``, all numeric columns
        are used.
    sampling_rate : float, default 1.0
        Sampling rate of the time series (number of observations per
        unit time).  Use 1 for unit interval data.  If your data are
        daily, ``sampling_rate=1`` yields frequencies in cycles per day;
        for weekly data use 1/7 to obtain cycles per day.

    Returns
    -------
    plotly.graph_objects.Figure
        An interactive figure with PSD plots.  Each subplot displays
        frequency (x-axis) against power spectral density (y-axis).

    Raises
    ------
    KeyError
        If ``date`` names a column that is missing, or ``columns`` lists
        columns that are not in ``df``.
    ValueError
        If there are no columns to process.
    """
    # Only a string ``date`` names a column that must be excluded; a Series
    # or other iterable of dates carries no column label.
    if isinstance(date, str):
        if date not in df.columns:
            raise KeyError(f"Date column '{date}' not found in DataFrame")
        date_col = date
    else:
        date_col = None

    if columns is None:
        columns = [
            c for c in df.columns
            if c != date_col and pd.api.types.is_numeric_dtype(df[c])
        ]
    if not columns:
        raise ValueError("No numeric columns found for spectral density plot")
    missing = [c for c in columns if c not in df.columns]
    if missing:
        raise KeyError(f"Columns not found in DataFrame: {missing}")

    # Optional dependencies are imported only once the inputs are known to be valid.
    from scipy.signal import periodogram
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    n_plots = len(columns)
    # Plotly rejects vertical_spacing > 1 / (rows - 1), so shrink the gap
    # when many series are plotted.
    spacing = 0.08 if n_plots < 2 else min(0.08, 1.0 / (n_plots - 1))
    fig = make_subplots(rows=n_plots, cols=1, shared_xaxes=False, vertical_spacing=spacing)
    for idx, col in enumerate(columns, start=1):
        y = df[col].astype(float).ffill().bfill().to_numpy()
        # Remove mean to avoid the zero-frequency (DC) component dominating.
        y = y - y.mean()
        freqs, psd = periodogram(y, fs=sampling_rate)
        fig.add_trace(
            go.Scatter(
                x=freqs,
                y=psd,
                mode='lines',
                line=dict(color='blue'),
                name=f"PSD {col}",
                # A single legend entry is enough; y-axis titles identify subplots.
                showlegend=idx == 1,
            ),
            row=idx,
            col=1,
        )
        fig.update_yaxes(title_text=f"PSD of {col}", row=idx, col=1)
    fig.update_xaxes(title_text="Frequency", row=n_plots, col=1)
    fig.update_layout(height=350 * n_plots, title="Power Spectral Density")
    return fig


# -----------------------------------------------------------------------------
# Wavelet spectrogram for time–frequency analysis
# -----------------------------------------------------------------------------

def _ricker_wavelet(points: float, a: float) -> np.ndarray:
    """Sample a Ricker (Mexican hat) wavelet of width ``a`` at ``points`` positions.

    Same formula as ``scipy.signal.ricker``, which SciPy deprecated in 1.12
    and removed in 1.15.
    """
    amplitude = 2.0 / (np.sqrt(3.0 * a) * np.pi ** 0.25)
    offsets = np.arange(0, points) - (points - 1.0) / 2
    xsq = offsets ** 2
    wsq = a ** 2
    return amplitude * (1.0 - xsq / wsq) * np.exp(-xsq / (2.0 * wsq))


def _morlet2_wavelet(points: float, s: float, w: float = 5.0) -> np.ndarray:
    """Sample a complex Morlet wavelet with scale ``s`` and centre frequency ``w``.

    Same formula as ``scipy.signal.morlet2`` (removed in SciPy 1.15), the
    Morlet variant whose second argument is a scale and is therefore suited
    to a CWT over widths.
    """
    x = (np.arange(0, points) - (points - 1.0) / 2) / s
    return np.sqrt(1.0 / s) * np.pi ** -0.25 * np.exp(1j * w * x) * np.exp(-0.5 * x ** 2)


def _continuous_wavelet_transform(data: np.ndarray, wavelet_func, widths: np.ndarray) -> np.ndarray:
    """Return the CWT of ``data`` with shape ``(len(widths), len(data))``.

    Uses the same direct-convolution algorithm as the removed
    ``scipy.signal.cwt``.  ``widths`` must be non-empty.
    """
    rows = []
    for width in widths:
        # Kernel support of 10 widths, truncated to the signal length (as in SciPy).
        n_points = min(10 * width, len(data))
        kernel = np.conj(wavelet_func(n_points, width)[::-1])
        rows.append(np.convolve(data, kernel, mode='same'))
    return np.vstack(rows)


def wavelet_spectrogram(
    date: Union[str, pd.Series, Iterable],
    df: pd.DataFrame,
    *,
    column: str,
    widths: Optional[np.ndarray] = None,
    wavelet: str = 'ricker',
    sampling_rate: float = 1.0,
) -> 'go.Figure':
    """Compute and plot a continuous wavelet transform (CWT) spectrogram.

    The CWT provides a time-frequency representation of a signal using
    localized wavelets.  This function convolves the selected series with
    either a Ricker (Mexican hat) or a complex Morlet wavelet at each
    width and visualizes the squared magnitude of the transform as a
    heatmap.  A spectrogram can be useful for detecting transient events
    or frequency shifts in commodity price data.

    Parameters
    ----------
    date : Union[str, pandas.Series, Iterable]
        Column name, series or iterable of dates.  Currently unused; kept
        for API consistency with the other functions in this module.
    df : pandas.DataFrame
        DataFrame containing the data.  Must include ``column``.
    column : str
        Name of the column to transform.  Missing values are forward- and
        then back-filled.
    widths : array-like or None, default None
        Scales (widths) at which to compute the CWT; must be strictly
        positive.  If ``None``, the integer widths ``1 .. min(64, n // 2) - 1``
        are used, where ``n`` is the series length.  Larger widths
        correspond to lower frequencies (longer periods).
    wavelet : {'ricker', 'morlet'}, default 'ricker'
        Type of mother wavelet to use.  ``'morlet'`` is the complex Morlet
        wavelet of ``scipy.signal.morlet2`` with centre frequency ``w = 5``,
        each width being used as the wavelet scale.  It provides better
        frequency localization but is more oscillatory.
    sampling_rate : float, default 1.0
        Sampling rate of the time series (observations per unit time).
        Used to label the y-axis in terms of approximate periods
        (``widths / sampling_rate``).

    Returns
    -------
    plotly.graph_objects.Figure
        A figure containing the spectrogram heatmap.

    Raises
    ------
    KeyError
        If ``column`` is not in ``df``.
    ValueError
        If ``wavelet`` is unknown, no widths are available (the default
        requires at least 4 observations) or a width is not positive.
    """
    if column not in df.columns:
        raise KeyError(f"Column '{column}' not found in DataFrame")
    kind = wavelet.lower()
    if kind == 'ricker':
        wavelet_func = _ricker_wavelet
    elif kind == 'morlet':
        wavelet_func = _morlet2_wavelet
    else:
        raise ValueError("wavelet must be 'ricker' or 'morlet'")

    y = df[column].astype(float).ffill().bfill().to_numpy()
    n = len(y)
    if widths is None:
        widths = np.arange(1, min(64, n // 2))
    widths = np.atleast_1d(np.asarray(widths, dtype=float))
    if widths.size == 0:
        raise ValueError(
            "No wavelet widths to evaluate; pass `widths` or use a series "
            "with at least 4 observations"
        )
    if np.any(widths <= 0):
        raise ValueError("widths must be strictly positive")

    power = np.abs(_continuous_wavelet_transform(y, wavelet_func, widths)) ** 2
    # Convert widths to pseudo periods in units of 1 / sampling_rate.
    periods = widths / sampling_rate

    import plotly.graph_objects as go

    fig = go.Figure(data=go.Heatmap(
        z=power,
        x=np.arange(n),
        y=periods,
        colorscale='Viridis',
    ))
    fig.update_layout(
        title=f"Wavelet Spectrogram for {column}",
        xaxis_title="Time (index)",
        yaxis_title="Scale (approx. period)",
        height=400,
    )
    return fig


# -----------------------------------------------------------------------------
# Kalman smoother for noise reduction
# -----------------------------------------------------------------------------

def kalman_smoother(
    date: Union[str, pd.Series, Iterable],
    df: pd.DataFrame,
    *,
    process_variance: Optional[float] = None,
    measurement_variance: Optional[float] = None,
    initial_variance: float = 1.0,
) -> pd.DataFrame:
    """Apply a scalar Kalman filter and Rauch-Tung-Striebel smoother to numeric series.

    Each numeric series is modelled as a random walk observed through
    measurement noise::

        x[t] = x[t-1] + w[t],   w ~ N(0, Q)
        y[t] = x[t]   + v[t],   v ~ N(0, R)

    A forward Kalman filter is run first and its estimates are refined by
    a backward RTS pass, so every smoothed value uses the whole series.
    This can be useful for de-noising volatile price data.

    Parameters
    ----------
    date : Union[str, pandas.Series, Iterable]
        Column name, series or iterable of dates.  Used as the index of
        the output DataFrame.
    df : pandas.DataFrame
        DataFrame containing the data.  Numeric columns are smoothed
        independently; missing values are forward- and then back-filled.
        Non-numeric columns (and the date column) are ignored.
    process_variance : float or None, default None
        Variance of the process noise (Q).  If ``None``, it is set to
        1e-5 times the variance of the input series.  Larger values
        allow the state estimate to follow the observed data more
        closely.
    measurement_variance : float or None, default None
        Variance of the measurement noise (R).  If ``None``, it is set
        to the variance of the input series.  Smaller values place
        greater trust in the observations.
    initial_variance : float, default 1.0
        Variance P_0 of the initial state estimate, which is centred on
        the first observation.

    Returns
    -------
    pandas.DataFrame
        DataFrame of the smoothed series, indexed by the input dates and
        with one column per numeric series.

    Raises
    ------
    KeyError
        If ``date`` names a column that is not in ``df``.
    ValueError
        If ``df`` has no numeric columns.
    """
    if isinstance(date, str):
        if date not in df.columns:
            raise KeyError(f"Date column '{date}' not found in DataFrame")
        date_col = date
        date_idx = pd.to_datetime(df[date])
    else:
        # A Series/array of dates names no column; comparing column labels
        # against it would be element-wise and ambiguous in a boolean context.
        date_col = None
        date_idx = pd.to_datetime(pd.Series(date))
    numeric_cols = [
        c for c in df.columns
        if c != date_col and pd.api.types.is_numeric_dtype(df[c])
    ]
    if not numeric_cols:
        raise ValueError("No numeric columns found for Kalman smoother")
    result = pd.DataFrame(index=date_idx, columns=numeric_cols, dtype=float)
    for col in numeric_cols:
        y = df[col].astype(float).ffill().bfill().to_numpy()
        n = len(y)
        if n == 0:
            continue
        var_y = np.var(y)
        q = process_variance if process_variance is not None else var_y * 1e-5
        r = measurement_variance if measurement_variance is not None else var_y

        # Forward pass: keep both the filtered means and their variances,
        # because the RTS gain at step i needs the filtered variance at i.
        x_filtered = np.empty(n)
        p_filtered = np.empty(n)
        x_est = y[0]
        p = initial_variance
        for i in range(n):
            p_pred = p + q  # random walk: x_pred = x_est
            denom = p_pred + r
            # denom == 0 (e.g. constant series with default Q = R = 0): the
            # observation is exact, so take it as the estimate instead of 0/0.
            k = p_pred / denom if denom > 0 else 1.0
            x_est = x_est + k * (y[i] - x_est)
            p = (1.0 - k) * p_pred
            x_filtered[i] = x_est
            p_filtered[i] = p

        # Backward Rauch-Tung-Striebel pass with gain G_i = P_f[i] / (P_f[i] + Q).
        x_smooth = x_filtered.copy()
        for i in range(n - 2, -1, -1):
            p_next_pred = p_filtered[i] + q
            # A zero predicted variance means the state is known exactly.
            gain = p_filtered[i] / p_next_pred if p_next_pred > 0 else 0.0
            x_smooth[i] = x_filtered[i] + gain * (x_smooth[i + 1] - x_filtered[i])
        result[col] = x_smooth
    return result


# -----------------------------------------------------------------------------
# Resample time series between daily, weekly and monthly frequencies
# -----------------------------------------------------------------------------

def _resample_month_end_alias() -> str:
    """Return pandas' month-end alias: ``'ME'`` on pandas >= 2.2, else ``'M'``."""
    from pandas.tseries.frequencies import to_offset

    try:
        to_offset('ME')
    except ValueError:
        return 'M'
    return 'ME'


def _resample_is_downsample(index: pd.Index, target_freq: str) -> bool:
    """Return True when the timestamps in ``index`` are at least as fine as ``target_freq``.

    The median gap between consecutive distinct timestamps is compared with
    the nominal target period (1, 7 or ~30.44 days).  The 1.5x tolerance
    keeps data that are already at the target frequency (e.g. 31-day gaps
    between month ends resampled to ``'M'``) on the aggregation path.
    Fewer than two distinct timestamps are treated as downsampling.
    """
    target_days = {'D': 1.0, 'W': 7.0, 'M': 30.44}[target_freq]
    stamps = np.asarray(index, dtype='datetime64[ns]')
    stamps = np.sort(stamps[~np.isnat(stamps)])
    gaps = np.diff(stamps) / np.timedelta64(1, 'D')
    gaps = gaps[gaps > 0]
    if gaps.size == 0:
        return True
    return float(np.median(gaps)) < 1.5 * target_days


def _resample_spline_upsample(data: pd.DataFrame, new_index: pd.DatetimeIndex) -> pd.DataFrame:
    """Interpolate each column of ``data`` (DatetimeIndex) onto ``new_index``.

    A cubic interpolating spline is used when a column has at least four
    valid observations; with fewer points, or if the spline cannot be
    built, linear interpolation is used.  Columns with no valid
    observation become NaN.
    """
    from scipy.interpolate import make_interp_spline

    if len(new_index) == 0:
        return pd.DataFrame(index=new_index, columns=list(data.columns), dtype=float)
    stamps = np.asarray(data.index, dtype='datetime64[ns]')
    new_stamps = np.asarray(new_index, dtype='datetime64[ns]')
    # Float nanoseconds relative to a nearby origin: avoids unit mismatches
    # between indexes and keeps the spline's x-values small.  NaT -> NaN.
    origin = new_stamps[0]
    one_ns = np.timedelta64(1, 'ns')
    x_all = (stamps - origin) / one_ns
    x_new = (new_stamps - origin) / one_ns
    out = pd.DataFrame(index=new_index)
    for col in data.columns:
        y = data[col].astype(float).to_numpy()
        valid = ~(np.isnan(y) | np.isnan(x_all))
        x_clean, y_clean = x_all[valid], y[valid]
        if x_clean.size == 0:
            out[col] = np.nan
        elif x_clean.size < 4:
            # Not enough points for a cubic spline.
            out[col] = np.interp(x_new, x_clean, y_clean)
        else:
            try:
                out[col] = make_interp_spline(x_clean, y_clean, k=3)(x_new)
            except Exception:
                # Deliberate best-effort fallback, e.g. duplicate timestamps
                # make the spline's x non-strictly increasing.
                out[col] = np.interp(x_new, x_clean, y_clean)
    return out


def resample_time_series(
    date: Union[str, pd.Series, Iterable],
    df: pd.DataFrame,
    *,
    target_freq: str,
    agg: str = 'mean',
    week_ending: str = 'Fri',
    return_type: str = 'daily',
) -> pd.DataFrame:
    """Resample a time series DataFrame to a different frequency.

    This utility can downsample data (e.g. daily to weekly or monthly)
    using aggregation functions such as mean, sum, min or max.  It can
    also upsample coarser data (e.g. weekly or monthly to daily) using
    cubic spline interpolation.  When resampling to weekly data, the user
    may choose which day of the week marks the end of the week (e.g.
    Friday).  Monthly data always end on the last day of the month.

    Parameters
    ----------
    date : Union[str, pandas.Series, Iterable]
        Column name, series or iterable of dates.  The dates are converted
        to datetime and used as the index for resampling; rows need not
        be sorted.
    df : pandas.DataFrame
        DataFrame containing the data to resample.  Numeric columns are
        processed; non-numeric columns (and the date column) are ignored.
    target_freq : str
        Target frequency code: 'D' for daily, 'W' for weekly or 'M' for
        monthly.  When the target frequency is at least as coarse as the
        observed spacing (e.g. daily -> weekly), aggregation is applied.
        When it is finer (e.g. weekly -> daily), cubic spline
        interpolation is used.
    agg : {'mean', 'sum', 'min', 'max', 'first', 'last', 'median'}, default 'mean'
        Aggregation method for downsampling (ignored when interpolating).
    week_ending : str, default 'Fri'
        Day of the week that marks the end of a week when resampling to
        weekly frequency.  Only the first three letters are used, case
        insensitively: any of {'Mon','Tue','Wed','Thu','Fri','Sat','Sun'}.
    return_type : {'daily','weekly','monthly'}, default 'daily'
        Descriptive string for the desired return type.  Kept for
        backward compatibility; it is not enforced and ``target_freq``
        alone determines the output frequency.

    Returns
    -------
    pandas.DataFrame
        A DataFrame at the desired frequency, with numeric columns
        resampled or interpolated accordingly.  Aggregated daily bins are
        labelled with their own date, weekly bins with the week-ending
        date and monthly bins with the month-end date.

    Raises
    ------
    KeyError
        If ``date`` names a column that is not in ``df``.
    ValueError
        If ``target_freq``, ``week_ending`` or (when aggregating) ``agg``
        is invalid, or ``df`` has no numeric columns.

    Notes
    -----
    The direction is chosen from the median spacing between consecutive
    dates: data spaced less than 1.5x the target period are aggregated,
    otherwise they are interpolated.  For interpolation, a cubic spline
    (via SciPy) is fitted through the observed data for each series.
    This yields smooth estimates at the higher frequency.  Users should
    ensure their data do not contain large gaps when using interpolation.
    """
    valid_freqs = {'D': 'daily', 'W': 'weekly', 'M': 'monthly'}
    if target_freq not in valid_freqs:
        raise ValueError("target_freq must be one of 'D', 'W' or 'M'")

    if isinstance(date, str):
        if date not in df.columns:
            raise KeyError(f"Date column '{date}' not found in DataFrame")
        date_col = date
        date_idx = pd.to_datetime(df[date])
    else:
        # A Series/array of dates names no column; comparing column labels
        # against it would be element-wise and ambiguous in a boolean context.
        date_col = None
        date_idx = pd.to_datetime(pd.Series(date))

    data = df.copy()
    data.index = date_idx
    numeric_cols = [
        c for c in data.columns
        if c != date_col and pd.api.types.is_numeric_dtype(data[c])
    ]
    if not numeric_cols:
        raise ValueError("No numeric columns found for resampling")
    # Sorting makes the spacing estimate and spline interpolation well defined.
    data = data[numeric_cols].sort_index()

    if target_freq == 'W':
        weekdays = ('MON', 'TUE', 'WED', 'THU', 'FRI', 'SAT', 'SUN')
        day = str(week_ending)[:3].upper()
        if day not in weekdays:
            raise ValueError(
                f"week_ending must be one of 'Mon'..'Sun', got {week_ending!r}"
            )
        freq_code = f"W-{day}"
    elif target_freq == 'M':
        freq_code = _resample_month_end_alias()
    else:
        freq_code = 'D'

    if _resample_is_downsample(data.index, target_freq):
        if agg not in ('mean', 'sum', 'min', 'max', 'first', 'last', 'median'):
            raise ValueError("Invalid aggregation function")
        # pandas defaults label weekly/monthly bins by their end date and
        # daily bins by their own date.
        return data.resample(freq_code).agg(agg)

    new_index = pd.date_range(start=data.index.min(), end=data.index.max(), freq=freq_code)
    return _resample_spline_upsample(data, new_index)


# -----------------------------------------------------------------------------
# STL decomposition and plotting
# -----------------------------------------------------------------------------

def stl_decompose_plot(
    date: Union[str, pd.Series, Iterable],
    df: pd.DataFrame,
    *,
    columns: Optional[List[str]] = None,
    period: Optional[int] = None,
) -> 'go.Figure':
    """Perform STL decomposition and plot trend, seasonal and residual components.

    Seasonal-Trend decomposition using Loess (STL) separates a time
    series into additive components: trend, seasonal and residual.  This
    function applies robust STL (statsmodels) to each specified series and
    produces a Plotly figure with four subplots per series showing the
    original series, trend, seasonal component and residual.  The period
    (seasonal length) can be specified or inferred from the data frequency.

    Parameters
    ----------
    date : Union[str, pandas.Series, Iterable]
        Column name, series or iterable of dates.  Dates are used as the
        x-axis and are converted to pandas ``datetime64``.
    df : pandas.DataFrame
        DataFrame containing the data.  Missing values are forward- and
        then back-filled before decomposition.
    columns : list of str or None, default None
        Specific columns to decompose.  If ``None``, all numeric
        columns except the date column are used.
    period : int or None, default None
        Seasonal period to use for STL.  If ``None``, it is inferred from
        the frequency of the date series: 365 for daily data, 52 for
        weekly data and 12 for monthly data; otherwise (including when no
        frequency can be inferred) ``max(2, n // 2)`` is used.

    Returns
    -------
    plotly.graph_objects.Figure
        A figure with the decomposition for each series.

    Raises
    ------
    KeyError
        If ``date`` names a missing column or ``columns`` lists columns
        that are not in ``df``.
    ValueError
        If there are no columns to decompose.
    """
    if isinstance(date, str):
        if date not in df.columns:
            raise KeyError(f"Date column '{date}' not found in DataFrame")
        date_col = date
        date_idx = pd.to_datetime(df[date])
    else:
        # A Series/array of dates names no column; comparing column labels
        # against it would be element-wise and ambiguous in a boolean context.
        date_col = None
        date_idx = pd.to_datetime(pd.Series(date))

    if columns is None:
        columns = [
            c for c in df.columns
            if c != date_col and pd.api.types.is_numeric_dtype(df[c])
        ]
    if not columns:
        raise ValueError("No numeric columns to decompose")
    missing = [c for c in columns if c not in df.columns]
    if missing:
        raise KeyError(f"Columns not found in DataFrame: {missing}")

    if period is None:
        try:
            freq = pd.infer_freq(date_idx)
        except (TypeError, ValueError):
            # Non-datetime values or fewer than three dates.
            freq = None
        # Prefixes are compared case-sensitively on purpose: pandas >= 2.2
        # reports minutes as 'min' and milliseconds as 'ms', which an
        # upper-cased comparison would mistake for monthly ('M', 'ME', 'MS').
        if freq is not None and freq.startswith('D'):
            period = 365
        elif freq is not None and freq.startswith('W'):
            period = 52
        elif freq is not None and freq.startswith('M'):
            period = 12
        else:
            period = max(2, len(date_idx) // 2)

    # Optional dependencies are imported only once the inputs are known to be valid.
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    from statsmodels.tsa.seasonal import STL

    n_series = len(columns)
    n_rows = n_series * 4
    # Plotly rejects vertical_spacing > 1 / (rows - 1); shrink it for many series.
    spacing = min(0.02, 1.0 / (n_rows - 1))
    fig = make_subplots(rows=n_rows, cols=1, shared_xaxes=True, vertical_spacing=spacing)
    row = 1
    for col in columns:
        series = df[col].astype(float).ffill().bfill().to_numpy()
        res = STL(series, period=period, robust=True).fit()
        components = {
            'Observed': series,
            'Trend': res.trend,
            'Seasonal': res.seasonal,
            'Residual': res.resid,
        }
        for comp_name, comp in components.items():
            label = f"{comp_name} ({col})"
            fig.add_trace(
                go.Scatter(
                    x=date_idx,
                    y=comp,
                    mode='lines',
                    name=label,
                    # A single legend entry is enough; y-axis titles identify subplots.
                    showlegend=row == 1,
                ),
                row=row,
                col=1,
            )
            fig.update_yaxes(title_text=label, row=row, col=1)
            row += 1
    fig.update_xaxes(title_text="Date", row=row - 1, col=1)
    fig.update_layout(height=250 * n_rows, title="STL Decomposition")
    return fig