import mock
import pydantic_ai
import pytest
from typing_extensions import TypedDict

from ddtrace.internal.utils.version import parse_version
from ddtrace.llmobs._utils import safe_json
from tests.contrib.pydantic_ai.utils import calculate_square_tool
from tests.contrib.pydantic_ai.utils import expected_agent_metadata
from tests.contrib.pydantic_ai.utils import expected_calculate_square_tool
from tests.contrib.pydantic_ai.utils import expected_foo_tool
from tests.contrib.pydantic_ai.utils import expected_run_agent_span_event
from tests.contrib.pydantic_ai.utils import expected_run_tool_span_event
from tests.contrib.pydantic_ai.utils import foo_tool
from tests.llmobs._utils import _expected_llmobs_non_llm_span_event


# Parsed version of the installed pydantic_ai package; compared against version tuples to
# skip tests that need features missing from older releases (see the toolset test's skipif).
PYDANTIC_AI_VERSION = parse_version(pydantic_ai.__version__)


@pytest.mark.parametrize(
    "ddtrace_global_config",
    [dict(_llmobs_enabled=True, _llmobs_ml_app="<ml-app-name>", _llmobs_agentless_enabled=True)],
)
class TestLLMObsPydanticAI:
    """LLM Observability span-event tests for the pydantic_ai agent entry points.

    Each test replays a recorded cassette, drives one agent API (``run``, ``run_sync``,
    ``run_stream`` or ``iter``) and compares the captured LLMObs events against the
    expected-event helpers imported from the test utilities.
    """

    async def test_agent_run(self, pydantic_ai, request_vcr, llmobs_events, mock_tracer):
        """``Agent.run`` yields one agent event carrying instructions, system prompt, settings and tools."""
        model_settings = {"max_tokens": 100, "temperature": 0.5}
        instructions = "dummy instructions"
        system_prompt = "dummy system prompt"
        with request_vcr.use_cassette("agent_iter.yaml"):
            agent = pydantic_ai.Agent(
                model="gpt-4o",
                name="test_agent",
                instructions=instructions,
                system_prompt=system_prompt,
                tools=[calculate_square_tool],
                model_settings=model_settings,
            )
            result = await agent.run("Hello, world!")
        span = mock_tracer.pop_traces()[0][0]
        assert len(llmobs_events) == 1
        assert llmobs_events[0] == expected_run_agent_span_event(
            span,
            result.output,
            instructions=instructions,
            system_prompt=system_prompt,
            model_settings=model_settings,
            tools=expected_calculate_square_tool(),
        )

    def test_agent_run_sync(self, pydantic_ai, request_vcr, llmobs_events, mock_tracer):
        """``Agent.run_sync`` yields one agent event whose output is the run result."""
        with request_vcr.use_cassette("agent_iter.yaml"):
            agent = pydantic_ai.Agent(model="gpt-4o", name="test_agent")
            result = agent.run_sync("Hello, world!")
        span = mock_tracer.pop_traces()[0][0]
        assert len(llmobs_events) == 1
        assert llmobs_events[0] == expected_run_agent_span_event(span, result.output)

    async def test_agent_run_stream(self, pydantic_ai, request_vcr, llmobs_events, mock_tracer):
        """Consuming ``result.stream`` yields one agent event whose output is the last chunk received."""
        output = ""
        with request_vcr.use_cassette("agent_run_stream.yaml"):
            agent = pydantic_ai.Agent(model="gpt-4o", name="test_agent")
            async with agent.run_stream("Hello, world!") as result:
                async for chunk in result.stream(debounce_by=None):
                    output = chunk
        span = mock_tracer.pop_traces()[0][0]
        assert len(llmobs_events) == 1
        assert llmobs_events[0] == expected_run_agent_span_event(span, output)

    @pytest.mark.parametrize("delta", [False, True])
    async def test_agent_run_stream_text(self, pydantic_ai, request_vcr, llmobs_events, mock_tracer, delta):
        """Consuming ``result.stream_text`` yields one agent event with the full streamed text.

        delta determines whether each chunk represents the entire output up to the current point or just the
        delta from the previous chunk
        """
        output = ""
        with request_vcr.use_cassette("agent_run_stream.yaml"):
            agent = pydantic_ai.Agent(model="gpt-4o", name="test_agent")
            async with agent.run_stream("Hello, world!") as result:
                async for chunk in result.stream_text(debounce_by=None, delta=delta):
                    # Deltas must be concatenated; cumulative chunks simply replace the previous value.
                    output = output + chunk if delta else chunk
        span = mock_tracer.pop_traces()[0][0]
        assert len(llmobs_events) == 1
        assert llmobs_events[0] == expected_run_agent_span_event(span, output)

    async def test_agent_run_stream_structured(self, pydantic_ai, request_vcr, llmobs_events, mock_tracer):
        """Consuming ``result.stream_structured`` yields one agent event with the final message content."""
        output = ""
        with request_vcr.use_cassette("agent_run_stream.yaml"):
            agent = pydantic_ai.Agent(model="gpt-4o", name="test_agent")
            async with agent.run_stream("Hello, world!") as result:
                # Each item is a (message, flag) pair; only a pair with a truthy flag sets the output.
                async for message, is_last in result.stream_structured():
                    if is_last:
                        output = message.parts[0].content
        span = mock_tracer.pop_traces()[0][0]
        assert len(llmobs_events) == 1
        assert llmobs_events[0] == expected_run_agent_span_event(span, output)

    async def test_agent_run_stream_get_output(self, pydantic_ai, request_vcr, llmobs_events, mock_tracer):
        """Awaiting ``result.get_output`` yields one agent event with that output."""
        output = ""
        with request_vcr.use_cassette("agent_run_stream.yaml"):
            agent = pydantic_ai.Agent(model="gpt-4o", name="test_agent")
            async with agent.run_stream("Hello, world!") as result:
                output = await result.get_output()
        span = mock_tracer.pop_traces()[0][0]
        assert len(llmobs_events) == 1
        assert llmobs_events[0] == expected_run_agent_span_event(span, output)

    async def test_agent_run_stream_with_tool(self, pydantic_ai, request_vcr, llmobs_events, mock_tracer):
        """A streamed run that calls a tool yields a tool event followed by the agent event."""
        instructions = "Use the provided tool to calculate the square of 2."
        # Pre-initialised so an empty stream fails on the event comparison below rather than
        # with an UnboundLocalError that hides the real problem.
        output = ""
        with request_vcr.use_cassette("agent_run_stream_with_tools.yaml"):
            agent = pydantic_ai.Agent(
                model="gpt-4o", name="test_agent", tools=[calculate_square_tool], instructions=instructions
            )
            async with agent.run_stream("What is the square of 2?") as result:
                async for chunk in result.stream():
                    output = chunk
        trace = mock_tracer.pop_traces()[0]
        agent_span = trace[0]
        tool_span = trace[1]
        assert len(llmobs_events) == 2
        # Events are checked in the order tool, then agent (the reverse of the span order in the trace).
        assert llmobs_events[0] == expected_run_tool_span_event(tool_span)
        assert llmobs_events[1] == expected_run_agent_span_event(
            agent_span,
            output,
            input_value="What is the square of 2?",
            instructions=instructions,
            tools=expected_calculate_square_tool(),
        )

    async def test_agent_run_stream_structured_with_tool(self, pydantic_ai, request_vcr, llmobs_events, mock_tracer):
        """A structured streamed run with a tool reports the JSON-serialized args of the last response."""

        class Output(TypedDict):
            original_number: int
            square: int

        instructions = "Use the provided tool to calculate the square of 2."
        # Stays None when the stream yields nothing, so that case is reported by the
        # explicit assertion below instead of an UnboundLocalError.
        last_response = None
        with request_vcr.use_cassette("agent_run_stream_structured_with_tool.yaml"):
            agent = pydantic_ai.Agent(
                model="gpt-4o",
                name="test_agent",
                tools=[calculate_square_tool],
                instructions=instructions,
                output_type=Output,
            )
            async with agent.run_stream("What is the square of 2?") as result:
                async for response, _ in result.stream_structured(debounce_by=None):
                    last_response = response
        assert last_response is not None, "stream_structured() yielded no responses"
        trace = mock_tracer.pop_traces()[0]
        agent_span = trace[0]
        tool_span = trace[1]
        assert len(llmobs_events) == 2
        assert llmobs_events[0] == expected_run_tool_span_event(tool_span)
        assert llmobs_events[1] == expected_run_agent_span_event(
            agent_span,
            safe_json(last_response.parts[0].args, ensure_ascii=False),
            input_value="What is the square of 2?",
            instructions=instructions,
            tools=expected_calculate_square_tool(),
        )

    async def test_agent_run_stream_error(self, pydantic_ai, request_vcr, llmobs_events, mock_tracer):
        """An exception raised while consuming the stream is recorded on the agent event.

        The expected event keeps the chunk received before the raise as its output value and
        carries the exception type and message.
        """
        output = ""
        with request_vcr.use_cassette("agent_run_stream.yaml"):
            agent = pydantic_ai.Agent(model="gpt-4o", name="test_agent")
            with pytest.raises(Exception, match="test error"):
                async with agent.run_stream("Hello, world!") as result:
                    stream = result.stream(debounce_by=None)
                    async for chunk in stream:
                        # Capture the first chunk, then fail immediately.
                        output = chunk
                        raise Exception("test error")

        span = mock_tracer.pop_traces()[0][0]
        assert len(llmobs_events) == 1
        assert llmobs_events[0] == _expected_llmobs_non_llm_span_event(
            span,
            "agent",
            input_value="Hello, world!",
            output_value=output,
            metadata=expected_agent_metadata(),
            tags={"ml_app": "<ml-app-name>", "service": "tests.contrib.pydantic_ai"},
            error="builtins.Exception",
            error_message="test error",
            error_stack=mock.ANY,
        )

    async def test_agent_iter(self, pydantic_ai, request_vcr, llmobs_events, mock_tracer):
        """Fully iterating ``agent.iter`` yields one agent event with the run's final output."""
        output = ""
        with request_vcr.use_cassette("agent_iter.yaml"):
            agent = pydantic_ai.Agent(model="gpt-4o", name="test_agent")
            async with agent.iter("Hello, world!") as agent_run:
                async for _ in agent_run:
                    pass
                output = agent_run.result.output
        span = mock_tracer.pop_traces()[0][0]
        assert len(llmobs_events) == 1
        assert llmobs_events[0] == expected_run_agent_span_event(span, output)

    async def test_agent_iter_error(self, pydantic_ai, request_vcr, llmobs_events):
        """An exception raised while iterating ``agent.iter`` marks the single event as an error."""
        with request_vcr.use_cassette("agent_iter.yaml"):
            agent = pydantic_ai.Agent(model="gpt-4o", name="test_agent")
            with pytest.raises(Exception, match="test error"):
                async with agent.iter("Hello, world!") as agent_run:
                    async for _ in agent_run:
                        raise Exception("test error")

        assert len(llmobs_events) == 1
        assert llmobs_events[0]["status"] == "error"
        assert llmobs_events[0]["meta"]["error"]["message"] == "test error"

    @pytest.mark.skipif(PYDANTIC_AI_VERSION < (0, 4, 4), reason="pydantic-ai < 0.4.4 does not support toolsets")
    async def test_agent_run_with_toolset(self, pydantic_ai, request_vcr, llmobs_events, mock_tracer):
        """Test that the agent manifest includes tools from both the function toolset and the user-defined toolsets"""
        # Imported here because the toolsets module is only available from pydantic-ai 0.4.4 (see skipif).
        from pydantic_ai.toolsets import FunctionToolset

        with request_vcr.use_cassette("agent_run_stream_with_toolset.yaml"):
            agent = pydantic_ai.Agent(
                model="gpt-4o",
                name="test_agent",
                toolsets=[FunctionToolset(tools=[calculate_square_tool])],
                tools=[foo_tool],
            )
            result = await agent.run("Hello, world!")
        span = mock_tracer.pop_traces()[0][0]
        assert len(llmobs_events) == 1
        assert llmobs_events[0] == expected_run_agent_span_event(
            span, result.output, tools=expected_calculate_square_tool() + expected_foo_tool()
        )


class TestLLMObsPydanticAISpanLinks:
    """Span-link checks between the LLM and tool span events of a tool-calling agent run."""

    async def test_agent_calls_tool(self, pydantic_ai, request_vcr, llmobs_events, openai_patched):
        """Stream a tool-calling agent run and verify the LLM -> tool -> LLM span links.

        Four span events are expected; the first three are read, in order, as the first
        LLM call, the tool call and the second LLM call.
        """
        expected_link_attributes = {"from": "output", "to": "input"}
        tool_instructions = "Use the provided tool to calculate the square of 2."

        with request_vcr.use_cassette("agent_run_stream_with_tools.yaml"):
            agent = pydantic_ai.Agent(
                model="gpt-4o",
                name="test_agent",
                tools=[calculate_square_tool],
                instructions=tool_instructions,
            )
            async with agent.run_stream("What is the square of 2?") as streamed:
                # Drain the stream; only the emitted span events matter here.
                async for _chunk in streamed.stream(debounce_by=None):
                    continue

        assert len(llmobs_events) == 4
        first_llm_span, tool_span, second_llm_span = (llmobs_events[index] for index in range(3))

        # Each pair is (event that must carry the link, event the link must point back to):
        # the LLM -> tool link lives on the tool span, the tool -> LLM link on the second LLM span.
        link_expectations = (
            (tool_span, first_llm_span),
            (second_llm_span, tool_span),
        )
        for linked_event, source_event in link_expectations:
            links = linked_event["span_links"]
            assert len(links) == 1
            assert links[0]["span_id"] == source_event["span_id"]
            assert links[0]["attributes"] == expected_link_attributes
