"""Constant-time operation tests for Shamir's Secret Sharing.

These tests attempt to verify that operations do not leak timing information
about secret data. CPython cannot provide true constant-time guarantees, so
these tests can only help identify obvious timing side-channels.

Note: These tests are statistical and may occasionally fail due to system noise.
They are marked as slow and can be skipped during normal development.
"""

import statistics
import time
from random import Random

import pytest

from shamir import combine, split
from shamir.math import add, div, mul


@pytest.mark.slow
class TestConstantTimeOperations:
    """Test that operations don't leak timing information about secrets.

    Every test performs an untimed warm-up call before measuring, so one-off
    first-call costs are not charged to whichever input happens to be timed
    first. Only the operation under test sits between the two clock reads;
    share generation and RNG construction happen outside the timed region.
    """

    @staticmethod
    def _time_combine(parts) -> int:
        """Return how long one ``combine(parts)`` call takes, in nanoseconds."""
        start = time.perf_counter_ns()
        combine(parts)
        return time.perf_counter_ns() - start

    def test_combine_timing_independent_of_secret_value(self) -> None:
        """Test that combine() timing doesn't correlate with secret byte values.

        Key insight: If timing varies based on secret values (e.g., all zeros vs
        all ones), an attacker could potentially learn information about the secret
        through timing side-channels.
        """
        threshold = 3
        num_trials = 100

        # Test different bit patterns that might cause timing variations
        test_secrets = {
            "all_zero": bytes([0x00]) * 100,
            "all_one": bytes([0xFF]) * 100,
            "alternating_01": bytes([0x55]) * 100,  # 01010101
            "alternating_10": bytes([0xAA]) * 100,  # 10101010
        }

        timings_by_secret: dict[str, list[int]] = {}

        for secret_name, secret in test_secrets.items():
            # Untimed warm-up: without it, cold-start outliers land only in the
            # first secret's samples and inflate that secret's CV.
            combine(split(secret, 5, threshold, rng=Random(0))[:threshold])

            timings = []
            for trial in range(num_trials):
                parts = split(secret, 5, threshold, rng=Random(trial))
                timings.append(self._time_combine(parts[:threshold]))

            timings_by_secret[secret_name] = timings

        # Statistical analysis: Compare coefficient of variation (CV)
        # CV normalizes for different means, focusing on relative spread
        cvs = {
            name: statistics.stdev(times) / statistics.mean(times)
            for name, times in timings_by_secret.items()
        }

        # All CVs should be similar (timing shouldn't depend on secret)
        cv_values = list(cvs.values())
        cv_range = max(cv_values) - min(cv_values)

        # Note: Python's CPython makes true constant-time impossible
        # Allow 0.5 variation (50%) to account for Python's inherent variability
        # This test is best-effort and may be flaky in parallel mode
        assert cv_range < 0.5, (
            f"Timing varies significantly by secret value: {cvs}. "
            f"Range: {cv_range:.3f}. This may indicate timing side-channel."
        )

    def test_combine_timing_independent_of_secret_length(self) -> None:
        """Test that per-byte timing is consistent across different secret lengths.

        Each length is summarized by its median per-byte time. Unlike the mean,
        the median is not dragged upward by occasional outliers such as
        garbage-collection pauses or scheduler preemption.
        """
        threshold = 3
        num_trials = 50

        # Test different secret lengths
        secret_lengths = [10, 50, 100, 500, 1000]
        per_byte_timings: dict[int, float] = {}

        for length in secret_lengths:
            secret = b"X" * length
            # Untimed warm-up for this length.
            combine(split(secret, 5, threshold, rng=Random(0))[:threshold])

            timings = []
            for trial in range(num_trials):
                parts = split(secret, 5, threshold, rng=Random(trial))
                # Normalize by secret length
                timings.append(self._time_combine(parts[:threshold]) / length)

            per_byte_timings[length] = statistics.median(timings)

        # Per-byte timing should be relatively consistent
        timing_values = list(per_byte_timings.values())
        mean_timing = statistics.mean(timing_values)
        max_deviation = max(abs(t - mean_timing) / mean_timing for t in timing_values)

        # Allow 40% deviation (linear operations may have overhead)
        assert max_deviation < 0.4, (
            f"Per-byte timing varies significantly: {per_byte_timings}. "
            f"Max deviation: {max_deviation:.3f}"
        )

    def test_split_timing_independent_of_secret_value(self) -> None:
        """Test that split() timing doesn't correlate with secret byte values."""
        threshold = 3
        num_parts = 5
        num_trials = 100

        test_secrets = {
            "all_zero": bytes([0x00]) * 100,
            "all_one": bytes([0xFF]) * 100,
            "alternating_01": bytes([0x55]) * 100,
            "alternating_10": bytes([0xAA]) * 100,
        }

        timings_by_secret: dict[str, list[int]] = {}

        for secret_name, secret in test_secrets.items():
            # Untimed warm-up for this secret.
            split(secret, num_parts, threshold, rng=Random(0))

            timings = []
            for trial in range(num_trials):
                # Build the RNG before starting the clock so that only split()
                # itself is measured, not Mersenne Twister seeding.
                rng = Random(trial)
                start = time.perf_counter_ns()
                split(secret, num_parts, threshold, rng=rng)
                timings.append(time.perf_counter_ns() - start)

            timings_by_secret[secret_name] = timings

        # Compare coefficients of variation
        cvs = {
            name: statistics.stdev(times) / statistics.mean(times)
            for name, times in timings_by_secret.items()
        }

        cv_values = list(cvs.values())
        cv_range = max(cv_values) - min(cv_values)

        # Note: Python's CPython implementation has inherent timing variability
        # Allow 2.5 variation (250%) to account for parallel execution noise
        # This test is best-effort and may be flaky
        assert cv_range < 2.5, (
            f"Split timing varies significantly by secret value: {cvs}. "
            f"Range: {cv_range:.3f}. This may indicate timing side-channel."
        )

    @pytest.mark.parametrize(
        "secret_pair",
        [
            (b"\x00" * 100, b"\xff" * 100),  # Opposite extremes
            (b"\x00" * 100, b"\x01" * 100),  # Minimal difference
            (b"\x55" * 100, b"\xaa" * 100),  # Alternating patterns
        ],
    )
    def test_combine_timing_similar_for_different_secrets(
        self, secret_pair: tuple[bytes, bytes]
    ) -> None:
        """Test that combine timing is similar for different secrets of same length."""
        secret1, secret2 = secret_pair
        threshold = 3
        num_trials = 50

        # Untimed warm-up for both secrets.
        for secret in (secret1, secret2):
            combine(split(secret, 5, threshold, rng=Random(0))[:threshold])

        timings1: list[int] = []
        timings2: list[int] = []

        for trial in range(num_trials):
            quorum1 = split(secret1, 5, threshold, rng=Random(trial))[:threshold]
            quorum2 = split(secret2, 5, threshold, rng=Random(trial))[:threshold]

            # Alternate which secret is timed first so that any effect tied to
            # measurement position is spread evenly over both secrets rather
            # than always penalizing (or favoring) secret1.
            if trial % 2 == 0:
                timings1.append(self._time_combine(quorum1))
                timings2.append(self._time_combine(quorum2))
            else:
                timings2.append(self._time_combine(quorum2))
                timings1.append(self._time_combine(quorum1))

        # Compare median times (more robust than mean)
        median1 = statistics.median(timings1)
        median2 = statistics.median(timings2)

        # Times should be within 25% of each other
        ratio = max(median1, median2) / min(median1, median2)
        assert ratio < 1.25, (
            f"Timing differs significantly for different secrets: "
            f"{median1} vs {median2} (ratio: {ratio:.3f})"
        )


@pytest.mark.slow
class TestGF256ConstantTime:
    """Test that GF(256) operations are constant-time.

    A single GF(256) operation can complete faster than the resolution of
    ``time.perf_counter_ns`` on some platforms. Timing one call per sample
    then yields many zero or coarsely quantized readings, which makes the
    CV and median comparisons meaningless. Each sample here therefore times
    a batch of back-to-back calls.
    """

    # Number of calls timed together in one sample.
    _BATCH_SIZE = 100
    # Samples per operand pair (100 samples x 100 calls = 10,000 calls per pair).
    _NUM_SAMPLES = 100

    def _sample_timings(self, op, a: int, b: int) -> list[int]:
        """Return ``_NUM_SAMPLES`` durations (ns), each for ``_BATCH_SIZE`` calls of ``op(a, b)``.

        One untimed batch runs first as a warm-up.
        """
        batch = range(self._BATCH_SIZE)
        for _ in batch:  # untimed warm-up
            op(a, b)

        samples = []
        for _sample in range(self._NUM_SAMPLES):
            start = time.perf_counter_ns()
            for _ in batch:
                op(a, b)
            samples.append(time.perf_counter_ns() - start)
        return samples

    @staticmethod
    def _coefficients_of_variation(
        timings_by_pair: dict[tuple[int, int], list[int]],
    ) -> dict[tuple[int, int], float]:
        """Map each operand pair to stdev/mean of its samples (0 when the mean is 0)."""
        cvs: dict[tuple[int, int], float] = {}
        for pair, times in timings_by_pair.items():
            mean = statistics.mean(times)
            cvs[pair] = statistics.stdev(times) / mean if mean > 0 else 0
        return cvs

    def test_mul_timing_independent_of_operands(self) -> None:
        """Test that GF(256) multiplication time doesn't depend on operand values."""
        # Test different operand patterns
        operand_pairs = [
            (0, 0),  # Both zero
            (1, 1),  # Both one
            (255, 255),  # Both max
            (0, 255),  # Zero and max
            (85, 170),  # Alternating patterns
        ]

        timings_by_pair = {
            (a, b): self._sample_timings(mul, a, b) for a, b in operand_pairs
        }
        cvs = self._coefficients_of_variation(timings_by_pair)

        cv_values = [cv for cv in cvs.values() if cv > 0]
        if cv_values:
            cv_range = max(cv_values) - min(cv_values)

            # Very small operations like GF(256) mul stay noisy in Python even
            # when batched. The bound is an absolute difference between CVs (not
            # a ratio) and is deliberately loose: this test is best-effort only.
            assert cv_range < 50.0, (
                f"GF(256) mul timing varies significantly by operands: {cvs}. "
                f"Range: {cv_range:.3f}. This may indicate timing side-channel."
            )

    def test_div_timing_independent_of_operands(self) -> None:
        """Test that GF(256) division time doesn't depend on operand values."""
        # Test different operand patterns (avoid division by zero)
        operand_pairs = [
            (1, 1),
            (255, 255),
            (0, 255),
            (255, 1),
            (85, 170),
        ]

        timings_by_pair = {
            (a, b): self._sample_timings(div, a, b) for a, b in operand_pairs
        }
        cvs = self._coefficients_of_variation(timings_by_pair)

        cv_values = [cv for cv in cvs.values() if cv > 0]
        if cv_values:
            cv_range = max(cv_values) - min(cv_values)

            # Note: Very small operations have high measurement noise in Python
            assert cv_range < 5.0, (
                f"GF(256) div timing varies significantly by operands: {cvs}. "
                f"Range: {cv_range:.3f}. This may indicate timing side-channel."
            )

    def test_add_timing_independent_of_operands(self) -> None:
        """Test that GF(256) addition time doesn't depend on operand values."""
        operand_pairs = [
            (0, 0),
            (1, 1),
            (255, 255),
            (0, 255),
            (85, 170),
        ]

        timings_by_pair = {
            (a, b): self._sample_timings(add, a, b) for a, b in operand_pairs
        }

        # Compare median times
        medians = {
            pair: statistics.median(times) for pair, times in timings_by_pair.items()
        }

        median_values = list(medians.values())
        if max(median_values) > 0:
            # Guard against a zero minimum so the ratio stays finite.
            ratio = max(median_values) / max(min(median_values), 1)

            # Addition should be very fast and consistent
            assert ratio < 2.0, (
                f"GF(256) add timing varies by operands: {medians}. Ratio: {ratio:.3f}"
            )


@pytest.mark.slow
class TestTimingDocumentation:
    """Document which operations must be constant-time."""

    def test_constant_time_requirements_documented(self) -> None:
        """Document which operations MUST be constant-time for security.

        Each documented dotted name is resolved by importing its module and
        looking up the attribute. If one of these functions is renamed, moved,
        or removed, the test fails with the security reason recorded below.
        This replaces checking a hand-maintained mapping against itself.
        """
        import importlib

        # Operations that MUST be constant-time for security
        constant_time_required = {
            "shamir.math.add": "GF(256) addition must not leak operand info",
            "shamir.math.mul": "GF(256) multiplication must not leak operand info",
            "shamir.math.div": "GF(256) division must not leak operand info",
            "shamir.math.inverse": "GF(256) inverse must not leak operand info",
            "shamir.utils.interpolate": "Lagrange interpolation must not leak secret",
            "shamir.combine": "Secret reconstruction must not leak secret value",
        }

        for name, reason in constant_time_required.items():
            # "shamir.math.add" -> module "shamir.math", attribute "add".
            module_name, _, attr_name = name.rpartition(".")
            module = importlib.import_module(module_name)
            operation = getattr(module, attr_name, None)

            assert operation is not None, f"{name} not found (required: {reason})"
            assert callable(operation), f"{name} is not callable (required: {reason})"


@pytest.mark.slow
class TestTimingAttackResistance:
    """Test resistance to known timing attack vectors.

    Like the other timing classes in this module, these tests compare
    wall-clock measurements against fixed thresholds. They can fail under
    system noise, so they are marked ``slow`` and can be deselected during
    normal development.
    """

    @staticmethod
    def _time_combine(parts) -> int:
        """Return how long one ``combine(parts)`` call takes, in nanoseconds."""
        start = time.perf_counter_ns()
        combine(parts)
        return time.perf_counter_ns() - start

    def test_no_early_termination_on_zero(self) -> None:
        """Test that operations don't terminate early on zero values.

        Early termination on special values (like zero) is a common source
        of timing leaks.
        """
        threshold = 3
        num_trials = 100

        # Secret with many zeros
        secret_with_zeros = bytes([0] * 50 + [255] * 50)
        # Secret with no zeros
        secret_no_zeros = bytes([i % 255 + 1 for i in range(100)])

        # Untimed warm-up so cold-start costs are not charged to either secret.
        for secret in (secret_with_zeros, secret_no_zeros):
            combine(split(secret, 5, threshold, rng=Random(0))[:threshold])

        timings_with_zeros: list[int] = []
        timings_no_zeros: list[int] = []

        for trial in range(num_trials):
            zeros_parts = split(secret_with_zeros, 5, threshold, rng=Random(trial))
            plain_parts = split(secret_no_zeros, 5, threshold, rng=Random(trial))
            zeros_quorum = zeros_parts[:threshold]
            plain_quorum = plain_parts[:threshold]

            # Alternate which secret is timed first so that any effect tied to
            # measurement position is spread evenly over both secrets.
            if trial % 2 == 0:
                timings_with_zeros.append(self._time_combine(zeros_quorum))
                timings_no_zeros.append(self._time_combine(plain_quorum))
            else:
                timings_no_zeros.append(self._time_combine(plain_quorum))
                timings_with_zeros.append(self._time_combine(zeros_quorum))

        # Compare distributions
        median_with_zeros = statistics.median(timings_with_zeros)
        median_no_zeros = statistics.median(timings_no_zeros)

        ratio = max(median_with_zeros, median_no_zeros) / min(
            median_with_zeros, median_no_zeros
        )

        # Should be similar timing
        assert ratio < 1.3, (
            f"Timing differs for secrets with/without zeros: "
            f"{median_with_zeros} vs {median_no_zeros} (ratio: {ratio:.3f})"
        )

    def test_no_conditional_branching_on_secret_bits(self) -> None:
        """Test that timing doesn't correlate with Hamming weight (number of 1 bits)."""
        threshold = 3
        num_trials = 50

        # Secrets with different Hamming weights
        secrets = {
            "low_weight": bytes([0b00000001] * 100),  # Few 1 bits
            "mid_weight": bytes([0b00001111] * 100),  # Medium 1 bits
            "high_weight": bytes([0b11111111] * 100),  # Many 1 bits
        }

        timings_by_weight: dict[str, list[int]] = {}

        for name, secret in secrets.items():
            # Untimed warm-up for this secret.
            combine(split(secret, 5, threshold, rng=Random(0))[:threshold])

            timings = []
            for trial in range(num_trials):
                parts = split(secret, 5, threshold, rng=Random(trial))
                timings.append(self._time_combine(parts[:threshold]))

            timings_by_weight[name] = timings

        # Compare medians
        medians = {
            name: statistics.median(times) for name, times in timings_by_weight.items()
        }

        median_values = list(medians.values())
        ratio = max(median_values) / min(median_values)

        # Timing shouldn't depend on Hamming weight
        assert ratio < 1.3, (
            f"Timing correlates with Hamming weight: {medians}. Ratio: {ratio:.3f}"
        )
