from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass(slots=True)
class CoarseGrainSpec:
    """Static binning information for coarse-grained frequency data.

    Instances are built by ``compute_binning_structure`` and consumed by
    ``apply_coarse_graining_univar``.  Unless stated otherwise, masks and
    index arrays refer to the *selected* frequencies, i.e. the entries of the
    original grid marked by ``selection_mask``.
    """

    # Output frequency grid: the full-resolution (low-frequency) selected
    # frequencies followed by the mean frequency of each occupied
    # high-frequency bin.  Equals the selected frequencies when nothing is
    # binned.
    f_coarse: np.ndarray
    # Boolean mask over the *original* frequency grid marking the entries
    # inside [f_min, f_max].
    selection_mask: np.ndarray
    # Boolean mask over the selected frequencies: frequency <= f_transition
    # (kept at full resolution).
    mask_low: np.ndarray
    # Boolean mask over the selected frequencies: frequency > f_transition
    # (grouped into logarithmic bins).
    mask_high: np.ndarray
    # For each high-frequency point, in grid order, the contiguous 0-based
    # index of the occupied bin it belongs to.
    bin_indices: np.ndarray
    # Stable argsort of ``bin_indices``; groups high-frequency points by bin.
    sort_indices: np.ndarray
    # Number of high-frequency points in each occupied bin.
    bin_counts: np.ndarray
    # Width (upper edge minus lower edge) of each occupied log-spaced bin,
    # floored at ``fine_spacing``.
    bin_widths: np.ndarray
    # Number of full-resolution (low-frequency) points.
    n_low: int
    # Number of occupied high-frequency bins (0 when nothing is binned).
    n_bins_high: int
    # Median strictly-positive spacing of the selected frequencies; 1.0 when
    # fewer than two frequencies are selected.
    fine_spacing: float


def compute_binning_structure(
    freqs: np.ndarray,
    *,
    f_transition: float,
    n_log_bins: int,
    f_min: Optional[float] = None,
    f_max: Optional[float] = None,
) -> CoarseGrainSpec:
    """Compute coarse-graining bins for a monotonically increasing frequency grid.

    Frequencies in ``[f_min, f_transition]`` are kept at full resolution.
    Frequencies in ``(f_transition, f_max]`` are assigned to ``n_log_bins``
    logarithmically spaced bins.  Only bins that receive at least one
    frequency are kept, and they are renumbered ``0 .. n_bins_high - 1``.

    Parameters
    ----------
    freqs:
        1-D, non-decreasing frequency grid.  Any array-like is accepted.
    f_transition:
        Positive boundary between the full-resolution and the binned region.
        A frequency equal to ``f_transition`` is treated as low-frequency.
    n_log_bins:
        Number of logarithmic bins spanning the high-frequency region.
    f_min, f_max:
        Optional selection limits.  They default to the first and last entry
        of ``freqs`` and are clipped to that range.  If ``f_max`` ends up
        below ``f_min`` it is raised to ``f_min``.

    Returns
    -------
    CoarseGrainSpec
        Masks, bin assignment, counts, widths and the coarse frequency grid.

    Raises
    ------
    ValueError
        If ``freqs`` is not 1-D, is empty, or is not non-decreasing.  Also if
        ``n_log_bins`` or ``f_transition`` is not positive, if no frequency
        lies in the selected range, or if the selected frequencies contain no
        strictly increasing step.
    """

    # Accept array-likes; this is a no-op for an existing ndarray.
    freqs = np.asarray(freqs)
    if freqs.ndim != 1:
        raise ValueError("freqs must be a 1-D array")
    # Without this check an empty grid would pass all validation and then
    # crash with an IndexError on freqs[0] below.
    if freqs.size == 0:
        raise ValueError("freqs must not be empty")
    if not np.all(np.diff(freqs) >= 0):
        raise ValueError("freqs must be monotonically increasing")
    if n_log_bins <= 0:
        raise ValueError("n_log_bins must be positive")
    if f_transition <= 0:
        raise ValueError("f_transition must be positive")

    # Resolve the selection window and clip it to the available grid.
    freq_min = float(freqs[0])
    freq_max = float(freqs[-1])
    f_min = freq_min if f_min is None else float(f_min)
    f_max = freq_max if f_max is None else float(f_max)
    f_min = min(max(f_min, freq_min), freq_max)
    f_max = min(max(f_max, freq_min), freq_max)
    if f_max < f_min:
        f_max = f_min

    in_range = (freqs >= f_min) & (freqs <= f_max)
    if not np.any(in_range):
        raise ValueError("No frequencies fall within the requested range")

    selected_freqs = freqs[in_range]
    # f_transition itself belongs to the full-resolution region.
    mask_low_full = (freqs >= f_min) & (freqs <= f_transition)
    mask_high_full = (freqs > f_transition) & (freqs <= f_max)

    mask_low = mask_low_full[in_range]
    mask_high = mask_high_full[in_range]

    n_low = int(mask_low.sum())

    if selected_freqs.size < 2:
        fine_spacing = 1.0
    else:
        freq_diffs = np.diff(selected_freqs)
        # Only strictly positive steps count, so repeated frequencies cannot
        # drive the spacing to zero.
        positive_diffs = freq_diffs[freq_diffs > 0]
        if positive_diffs.size == 0:
            raise ValueError(
                "Selected frequencies must contain increasing values."
            )
        fine_spacing = float(np.median(positive_diffs))

    high_freqs = selected_freqs[mask_high]
    if high_freqs.size == 0:
        # Nothing above f_transition: every selected frequency is kept as-is.
        return CoarseGrainSpec(
            f_coarse=selected_freqs,
            selection_mask=in_range,
            mask_low=mask_low,
            mask_high=mask_high,
            bin_indices=np.array([], dtype=np.int32),
            sort_indices=np.array([], dtype=np.int32),
            bin_counts=np.array([], dtype=np.int32),
            bin_widths=np.array([], dtype=np.float64),
            n_low=n_low,
            n_bins_high=0,
            fine_spacing=fine_spacing,
        )

    # Logarithmic bin edges spanning the high-frequency region.
    high_min = max(f_transition, high_freqs[0])
    edges = np.logspace(
        np.log10(high_min),
        np.log10(f_max),
        num=n_log_bins + 1,
        base=10.0,
    )
    # Guard the outer edges against logspace round-off so that the bins
    # fully cover [high_min, f_max].
    edges[0] = min(edges[0], high_min)
    edges[-1] = max(edges[-1], f_max)

    # Raw bin numbers 0 .. n_log_bins-1 come from digitizing against the
    # interior edges only.  Points below edges[1] land in bin 0 and points
    # at or above edges[-2] land in the last bin, so no clamping is needed.
    raw_indices = np.digitize(high_freqs, edges[1:-1], right=False)
    # Drop empty bins and renumber the occupied ones contiguously.
    unique_bins, reindexed = np.unique(raw_indices, return_inverse=True)
    n_bins_high = int(unique_bins.size)

    sort_indices = np.argsort(reindexed, kind="stable")
    sorted_bins = reindexed[sort_indices]

    # Points per occupied bin (every entry is >= 1 by construction).
    bin_counts = np.bincount(sorted_bins, minlength=n_bins_high)

    # Representative frequency per bin: the mean of its members.
    bin_sums = np.bincount(
        sorted_bins,
        weights=high_freqs[sort_indices],
        minlength=n_bins_high,
    )
    bin_means = bin_sums / np.where(bin_counts > 0, bin_counts, 1)
    # Linear-frequency width of each occupied bin, floored at the fine grid
    # spacing so that no bin is narrower than one fine step.
    raw_widths = edges[unique_bins + 1] - edges[unique_bins]
    min_width = np.array(fine_spacing, dtype=np.float64)
    bin_widths = np.maximum(raw_widths, min_width)

    # The grid is non-decreasing, so the low-frequency points form a prefix
    # of selected_freqs.
    f_coarse = np.concatenate((selected_freqs[:n_low], bin_means))

    return CoarseGrainSpec(
        f_coarse=f_coarse,
        selection_mask=in_range,
        mask_low=mask_low,
        mask_high=mask_high,
        bin_indices=reindexed,
        sort_indices=sort_indices,
        bin_counts=bin_counts,
        bin_widths=bin_widths,
        n_low=n_low,
        n_bins_high=n_bins_high,
        fine_spacing=fine_spacing,
    )


def apply_coarse_graining_univar(
    power: np.ndarray,
    spec: CoarseGrainSpec,
    freqs: Optional[np.ndarray] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Coarse-grain a power array according to a ``CoarseGrainSpec``.

    Low-frequency entries pass through unchanged with unit weight.
    High-frequency entries are averaged within each bin of ``spec``.  Each
    bin gets a weight proportional to its width, and the high-frequency
    weights are then rescaled so that they sum to the total number of binned
    points.

    Parameters
    ----------
    power:
        1-D array with one value per selected frequency, i.e. the same length
        as ``spec.mask_low``.
    spec:
        Binning description, typically produced by
        ``compute_binning_structure``.
    freqs:
        Optional selected-frequency grid.  When it has at least one strictly
        positive step, its median positive step replaces
        ``spec.fine_spacing``.  The weights are renormalized afterwards, so
        this spacing cancels out of the returned weights (up to rounding).
        It only feeds the positivity check.

    Returns
    -------
    tuple of ndarray
        ``(power_coarse, weights)``.  ``power_coarse`` holds the low-frequency
        powers followed by the per-bin means.  ``weights`` holds 1.0 for each
        low-frequency entry followed by the renormalized bin weights.

    Raises
    ------
    ValueError
        On length or dimension mismatches between ``power``, ``freqs``,
        ``spec`` and ``spec.bin_widths``, or when the effective fine spacing
        is not positive.
    """

    if power.ndim != 1:
        raise ValueError("power must be 1-D")
    if power.size != spec.mask_low.size:
        raise ValueError("power length must match selected frequency count")

    low_power = power[spec.mask_low]
    spacing = spec.fine_spacing

    if freqs is not None:
        grid = np.asarray(freqs, dtype=np.float64)
        if grid.ndim != 1:
            raise ValueError("freqs must be 1-D when provided")
        if grid.size != spec.selection_mask.sum():
            raise ValueError(
                "freqs length must match selected frequency count"
            )
        # np.diff of a 0- or 1-element grid is empty, so the spacing then
        # stays at the spec's value.
        steps = np.diff(grid)
        steps = steps[steps > 0]
        if steps.size:
            spacing = float(np.median(steps))

    low_weights = np.ones_like(low_power, dtype=np.float64)
    if spec.n_bins_high == 0:
        return low_power, low_weights

    n_bins = spec.n_bins_high
    order = spec.sort_indices
    high_power = power[spec.mask_high][order]
    binned_sum = np.bincount(
        spec.bin_indices[order],
        weights=high_power,
        minlength=n_bins,
    )

    occupancy = spec.bin_counts.astype(np.float64)
    has_members = occupancy > 0
    # Empty bins (if any) get a mean of 0 instead of a division by zero.
    bin_power = np.divide(
        binned_sum,
        np.where(has_members, occupancy, 1),
        out=np.zeros_like(binned_sum),
        where=has_members,
    )

    widths = spec.bin_widths.astype(np.float64, copy=False)
    if widths.size != n_bins:
        raise ValueError("bin_widths length must equal n_bins_high")
    if spacing <= 0:
        raise ValueError("fine_spacing must be positive")

    high_weights = widths / spacing
    # Rescale so that the bin weights add up to the number of fine points
    # they replace.  Small non-monotonicity caused by edge quantization is
    # expected and intentionally left in place.
    expected_total = float(spec.bin_counts.sum())
    current_total = float(high_weights.sum()) if high_weights.size else 0.0
    if current_total > 0 and expected_total > 0:
        high_weights = high_weights * (expected_total / current_total)

    coarse_power = np.concatenate((low_power, bin_power))
    coarse_weights = np.concatenate((low_weights, high_weights))
    return coarse_power, coarse_weights
