from unittest.mock import patch

import pytest

from ddtrace.appsec._iast._iast_env import _get_iast_env
from ddtrace.appsec._iast._iast_request_context_base import set_iast_request_endpoint
from ddtrace.appsec._iast._taint_tracking._context import debug_context_array_free_slots_number
from ddtrace.appsec._iast.constants import VULN_CMDI
from ddtrace.appsec._iast.constants import VULN_PATH_TRAVERSAL
from ddtrace.appsec._iast.constants import VULN_SQL_INJECTION
from ddtrace.appsec._iast.constants import VULN_SSRF
from ddtrace.appsec._iast.constants import VULN_XSS
from ddtrace.appsec._iast.sampling.vulnerability_detection import init_request_vulnerability_maps
from ddtrace.appsec._iast.sampling.vulnerability_detection import reset_request_vulnerabilities
from ddtrace.appsec._iast.sampling.vulnerability_detection import should_process_vulnerability
from ddtrace.appsec._iast.sampling.vulnerability_detection import update_global_vulnerability_limit
from ddtrace.appsec._iast.taint_sinks._base import VulnerabilityBase
from tests.appsec.iast.iast_utils import _end_iast_context_and_oce
from tests.appsec.iast.iast_utils import _start_iast_context_and_oce
from tests.utils import override_global_config


def _get_global_limit():
    """Return the object currently bound to ``GLOBAL_VULNERABILITIES_LIMIT``.

    The attribute is read from the sampling module on every call instead of
    being imported once at the top of this file.
    NOTE(review): presumably this is because the module may rebind the name
    between requests -- confirm against ``vulnerability_detection``.
    """
    import ddtrace.appsec._iast.sampling.vulnerability_detection as _vulnerability_detection

    return _vulnerability_detection.GLOBAL_VULNERABILITIES_LIMIT


@pytest.fixture(autouse=True)
def iast_create_context():
    """Wrap every test in an IAST context with deduplication enabled.

    Setup overrides the global config, checks that at least one context slot
    is free and starts the IAST context. Teardown always ends the context and
    resets the per-request vulnerability state, even when the test fails.
    """
    iast_settings = {
        "_iast_enabled": True,
        "_iast_deduplication_enabled": True,
        "_iast_request_sampling": 100,
    }
    with override_global_config(iast_settings):
        free_slots = debug_context_array_free_slots_number()
        assert free_slots > 0
        _start_iast_context_and_oce()
        try:
            yield
        finally:
            # Teardown order matters: close the context first, then reset the request state.
            _end_iast_context_and_oce()
            reset_request_vulnerabilities()

def test_request_1_endpoint_with_6_vuln():
    """Replay six detections on ``GET /users`` in two consecutive requests.

    Checks the accept/reject verdict of each detection, the per-request
    counters exposed by the IAST env, and the global per-endpoint map after
    each request is closed.
    """
    assert _get_global_limit() == {}
    set_iast_request_endpoint("GET", "/users")

    # Request 1: the first two detections are accepted, the remaining four are rejected.
    first_request = (
        (VULN_SQL_INJECTION, True),
        (VULN_PATH_TRAVERSAL, True),
        (VULN_SQL_INJECTION, False),
        (VULN_SSRF, False),
        (VULN_CMDI, False),
        (VULN_SQL_INJECTION, False),
    )
    for vuln_type, accepted in first_request:
        assert should_process_vulnerability(vuln_type) is accepted

    env = _get_iast_env()
    assert env.is_first_vulnerability is False
    assert env.vulnerability_budget == 2
    for vuln_type, count in ((VULN_SQL_INJECTION, 1), (VULN_PATH_TRAVERSAL, 1)):
        assert env.vulnerabilities_request_limit[vuln_type] == count

    # Close request 1: publish its counters to the global map, then reset the request state.
    update_global_vulnerability_limit()
    reset_request_vulnerabilities()

    # The env object fetched during request 1 is expected to be back to its initial state.
    assert env.is_first_vulnerability is True
    assert env.vulnerability_budget == 0
    assert env.vulnerabilities_request_limit == {}
    assert _get_global_limit() == {"GET:/users": {"SQL_INJECTION": 1, "PATH_TRAVERSAL": 1}}

    # Request 2: the first occurrences are now skipped; the second SQL injection and the SSRF are accepted.
    second_request = (
        (VULN_SQL_INJECTION, False),
        (VULN_PATH_TRAVERSAL, False),
        (VULN_SQL_INJECTION, True),
        (VULN_SSRF, True),
        (VULN_CMDI, False),
        (VULN_SQL_INJECTION, False),
    )
    for vuln_type, accepted in second_request:
        assert should_process_vulnerability(vuln_type) is accepted

    # Per-request counters after request 2.
    env = _get_iast_env()
    assert env.is_first_vulnerability is False
    assert env.vulnerability_budget == 2
    for vuln_type, count in ((VULN_SQL_INJECTION, 2), (VULN_PATH_TRAVERSAL, 1), (VULN_SSRF, 1)):
        assert env.vulnerabilities_request_limit[vuln_type] == count

    # Close request 2: the global map only changes once the update is called.
    assert _get_global_limit() == {"GET:/users": {"SQL_INJECTION": 1, "PATH_TRAVERSAL": 1}}
    update_global_vulnerability_limit()
    reset_request_vulnerabilities()
    assert _get_global_limit() == {"GET:/users": {"PATH_TRAVERSAL": 1, "SQL_INJECTION": 2, "SSRF": 1}}


def test_multiple_endpoints():
    """Check that the global map keeps a separate entry for each endpoint."""
    assert _get_global_limit() == {}

    # One request per endpoint; each reports two vulnerability types that must all be accepted.
    requests = (
        ("GET", "/users", (VULN_SQL_INJECTION, VULN_PATH_TRAVERSAL)),
        ("POST", "/users", (VULN_SQL_INJECTION, VULN_XSS)),
        ("PUT", "/users/1", (VULN_SSRF, VULN_CMDI)),
    )
    for method, route, vuln_types in requests:
        set_iast_request_endpoint(method, route)
        for vuln_type in vuln_types:
            assert should_process_vulnerability(vuln_type) is True
        update_global_vulnerability_limit()
        reset_request_vulnerabilities()

    # Every endpoint must be present in the global map.
    assert len(_get_global_limit()) == 3, f"incorrect number of endpoints: {_get_global_limit()}"
    for endpoint_key in ("GET:/users", "POST:/users", "PUT:/users/1"):
        assert endpoint_key in _get_global_limit()

    # Back on the first endpoint: known types are skipped, a type new to this endpoint is accepted.
    set_iast_request_endpoint("GET", "/users")
    revisit = (
        (VULN_SQL_INJECTION, False),
        (VULN_PATH_TRAVERSAL, False),
        (VULN_XSS, True),
    )
    for vuln_type, accepted in revisit:
        assert should_process_vulnerability(vuln_type) is accepted


def test_budget_exhaustion():
    """Check that detections are rejected once the per-request budget is used up."""
    assert _get_global_limit() == {}

    set_iast_request_endpoint("GET", "/api/items")

    # Two detections are accepted; the three that follow are rejected.
    first_request = (
        (VULN_SQL_INJECTION, True),
        (VULN_PATH_TRAVERSAL, True),
        (VULN_XSS, False),
        (VULN_SSRF, False),
        (VULN_CMDI, False),
    )
    for vuln_type, accepted in first_request:
        assert should_process_vulnerability(vuln_type) is accepted

    env = _get_iast_env()
    assert env.vulnerability_budget == 2

    # Close the request and move on to the next one.
    update_global_vulnerability_limit()
    reset_request_vulnerabilities()

    # The budget counter is back to zero: a known type is skipped, a new type is accepted.
    assert env.vulnerability_budget == 0
    second_request = (
        (VULN_SQL_INJECTION, False),
        (VULN_XSS, True),
    )
    for vuln_type, accepted in second_request:
        assert should_process_vulnerability(vuln_type) is accepted


def test_lru_cache_limit():
    """Check that the global map keeps at most ``MAX_ENDPOINTS`` endpoints."""
    assert _get_global_limit() == {}

    # Shrink MAX_ENDPOINTS to 3 so that five endpoints are enough to overflow the map.
    with patch("ddtrace.appsec._iast.sampling.vulnerability_detection.MAX_ENDPOINTS", 3):
        for index in range(5):
            set_iast_request_endpoint("GET", f"/endpoint{index}")
            # The same type is reported twice per request and accepted both times.
            for _ in range(2):
                assert should_process_vulnerability(VULN_SQL_INJECTION) is True
            update_global_vulnerability_limit()
            reset_request_vulnerabilities()

        # The two oldest endpoints are gone; the three latest remain.
        assert len(_get_global_limit()) == 3
        for evicted_key in ("GET:/endpoint0", "GET:/endpoint1"):
            assert evicted_key not in _get_global_limit()
        for retained_key in ("GET:/endpoint2", "GET:/endpoint3", "GET:/endpoint4"):
            assert retained_key in _get_global_limit()


def test_same_vulnerability_multiple_times():
    """Report one vulnerability type three times within a single request."""
    assert _get_global_limit() == {}

    set_iast_request_endpoint("GET", "/search")

    # The first and second reports are accepted, the third is rejected.
    for accepted in (True, True, False):
        assert should_process_vulnerability(VULN_SQL_INJECTION) is accepted

    env = _get_iast_env()
    assert env.vulnerabilities_request_limit[VULN_SQL_INJECTION] == 2
    assert env.vulnerability_budget == 2

    update_global_vulnerability_limit()

    # The global entry for the endpoint carries the accepted count.
    search_entry = _get_global_limit()["GET:/search"]
    assert search_entry[VULN_SQL_INJECTION] == 2


def test_mixed_vulnerability_types():
    """Report three different vulnerability types within a single request."""
    assert _get_global_limit() == {}

    set_iast_request_endpoint("POST", "/comments")

    # Two distinct types are accepted; the third is rejected.
    detections = (
        (VULN_XSS, True),
        (VULN_SQL_INJECTION, True),
        (VULN_PATH_TRAVERSAL, False),
    )
    for vuln_type, accepted in detections:
        assert should_process_vulnerability(vuln_type) is accepted

    env = _get_iast_env()
    assert env.vulnerability_budget == 2
    for recorded_type in (VULN_XSS, VULN_SQL_INJECTION):
        assert recorded_type in env.vulnerabilities_request_limit
    assert VULN_PATH_TRAVERSAL not in env.vulnerabilities_request_limit

    update_global_vulnerability_limit()

    # Only the accepted types are published to the global map.
    for recorded_type in (VULN_XSS, VULN_SQL_INJECTION):
        assert _get_global_limit()["POST:/comments"][recorded_type] == 1
    assert VULN_PATH_TRAVERSAL not in _get_global_limit()["POST:/comments"]


def test_init_request_vulnerability_maps():
    """Check what ``init_request_vulnerability_maps`` returns for known and unknown endpoints."""
    assert _get_global_limit() == {}

    # Seed the global map with one finished request on GET /product.
    set_iast_request_endpoint("GET", "/product")
    for vuln_type in (VULN_SQL_INJECTION, VULN_XSS):
        assert should_process_vulnerability(vuln_type) is True
    update_global_vulnerability_limit()
    reset_request_vulnerabilities()

    # Known endpoint: the returned map equals the global entry for that endpoint.
    product_snapshot = init_request_vulnerability_maps(_get_iast_env())
    assert product_snapshot == _get_global_limit()["GET:/product"]
    for vuln_type in (VULN_SQL_INJECTION, VULN_XSS):
        assert vuln_type in product_snapshot

    # Writing to the returned map must not leak into the global entry.
    product_snapshot[VULN_CMDI] = 5
    assert VULN_CMDI not in _get_global_limit()["GET:/product"]

    # Unknown endpoint: the returned map is empty.
    set_iast_request_endpoint("DELETE", "/user")
    unknown_snapshot = init_request_vulnerability_maps(_get_iast_env())
    assert unknown_snapshot == {}


def test_with_modified_max_vulnerabilities_config():
    """Check the budget when ``_iast_max_vulnerabilities_per_requests`` is overridden to 3."""
    assert _get_global_limit() == {}

    with override_global_config({"_iast_max_vulnerabilities_per_requests": 3}):
        set_iast_request_endpoint("GET", "/config_test")

        # Three distinct types fit in the raised limit.
        for vuln_type in (VULN_SQL_INJECTION, VULN_XSS, VULN_SSRF):
            assert should_process_vulnerability(vuln_type) is True

        # A fourth type is rejected.
        assert should_process_vulnerability(VULN_PATH_TRAVERSAL) is False

        env = _get_iast_env()
        assert env.vulnerability_budget == 3

        # After the update, the endpoint entry lists the three accepted types.
        update_global_vulnerability_limit()
        config_entry = _get_global_limit()["GET:/config_test"]
        assert len(config_entry) == 3


def test_quota_out_of_context():
    """Check that no IAST env and no quota are available once the context has been ended."""
    # End the context that the autouse fixture started for this test.
    _end_iast_context_and_oce()

    assert _get_iast_env() is None
    assert VulnerabilityBase.has_quota() is False
