import asyncio
import uuid
from collections.abc import AsyncGenerator, AsyncIterator, Callable
from contextlib import suppress
from copy import copy, deepcopy
from dataclasses import dataclass, field
from datetime import timedelta
from inspect import iscoroutinefunction
from types import ModuleType, SimpleNamespace
from typing import ClassVar, Generic, TypeVar, cast, overload

from pydantic import BaseModel, Field
from typing_extensions import Self

from ragbits import agents
from ragbits.agents.exceptions import (
    AgentInvalidPromptInputError,
    AgentMaxTokensExceededError,
    AgentMaxTurnsExceededError,
    AgentNextPromptOverLimitError,
    AgentToolDuplicateError,
    AgentToolExecutionError,
    AgentToolNotAvailableError,
    AgentToolNotSupportedError,
)
from ragbits.agents.mcp.server import MCPServer, MCPServerStdio, MCPServerStreamableHttp
from ragbits.agents.mcp.utils import get_tools
from ragbits.agents.tool import Tool, ToolCallResult, ToolChoice
from ragbits.core.audit.traces import trace
from ragbits.core.llms.base import LLM, LLMClientOptionsT, LLMResponseWithMetadata, ToolCall, Usage
from ragbits.core.options import Options
from ragbits.core.prompt.base import BasePrompt, ChatFormat, SimplePrompt
from ragbits.core.prompt.prompt import Prompt, PromptInputT, PromptOutputT
from ragbits.core.types import NOT_GIVEN, NotGiven
from ragbits.core.utils.config_handling import ConfigurableComponent
from ragbits.core.utils.decorators import requires_dependencies

with suppress(ImportError):
    from a2a.types import AgentCapabilities, AgentCard, AgentSkill
    from pydantic_ai import Agent as PydanticAIAgent
    from pydantic_ai import mcp

    from ragbits.core.llms import LiteLLM


@dataclass
class AgentResult(Generic[PromptOutputT]):
    """
    Result of the agent run.
    """

    content: PromptOutputT
    """The output content of the agent."""
    metadata: dict
    """The additional data returned by the agent."""
    history: ChatFormat
    """The history of the agent."""
    tool_calls: list[ToolCallResult] | None = None
    """Tool calls run by the agent."""
    # This is a stdlib dataclass, so the factory must come from `dataclasses.field`; pydantic's `Field`
    # would be stored verbatim, making the default a `FieldInfo` object instead of a fresh `Usage()`.
    usage: Usage = field(default_factory=Usage)
    """The token usage of the agent run."""


class AgentOptions(Options, Generic[LLMClientOptionsT]):
    """
    Settings that control a single agent run: the LLM options plus turn and token limits.
    """

    llm_options: LLMClientOptionsT | None | NotGiven = NOT_GIVEN
    """Options passed to the LLM; when not given (or None) the LLM's own default options are used."""
    max_turns: int | None | NotGiven = NOT_GIVEN
    """Upper bound on agent turns. NOT_GIVEN falls back to 10 turns;
    None removes the turn limit so the agent may run indefinitely."""
    max_total_tokens: int | None | NotGiven = NOT_GIVEN
    """Budget for all tokens (prompt + completion) used by the run;
    NOT_GIVEN or None means no limit."""
    max_prompt_tokens: int | None | NotGiven = NOT_GIVEN
    """Budget for prompt tokens used by the run;
    NOT_GIVEN or None means no limit."""
    max_completion_tokens: int | None | NotGiven = NOT_GIVEN
    """Budget for completion tokens used by the run;
    NOT_GIVEN or None means no limit."""


# Type of the dependency object wrapped by `AgentDependencies` and carried by `AgentRunContext`.
DepsT = TypeVar("DepsT")


class AgentDependencies(BaseModel, Generic[DepsT]):
    """
    Container for agent runtime dependencies.

    Wraps a single dependency object (`value`) and proxies unknown attribute lookups to it.
    Becomes immutable after first attribute access: the first lookup routed through `__getattr__`
    sets the internal frozen flag, after which assigning attributes (including `value`) raises
    `RuntimeError`.
    """

    model_config = {"arbitrary_types_allowed": True}

    # Internal state: the frozen flag and the wrapped dependency object.
    # NOTE(review): pydantic v2 treats underscore-prefixed annotations as private attributes and keeps
    # them in `__pydantic_private__` rather than the instance `__dict__`, and this class overrides
    # pydantic's own `__getattr__`. Confirm that the `object.__getattribute__(self, ...)` lookups below
    # can actually see `_frozen` / `_value` with the installed pydantic version.
    _frozen: bool
    _value: DepsT | None

    def __init__(self, value: DepsT | None = None) -> None:
        """
        Initialize the container.

        Args:
            value: The dependency object to wrap, or None for an empty container.
        """
        super().__init__()
        self._value = value
        self._frozen = False

    def __setattr__(self, name: str, value: object) -> None:
        """
        Set an attribute, refusing to do so once the container is frozen.

        The `_frozen` flag itself may always be written.

        Raises:
            RuntimeError: If the container is frozen and `name` is not `_frozen`.
        """
        is_frozen = False
        if name != "_frozen":
            # Read the flag without going through the proxying `__getattr__`; a flag that is not set yet
            # (e.g. while `__init__` assigns `_value` before `_frozen`) counts as not frozen.
            try:
                is_frozen = object.__getattribute__(self, "_frozen")
            except AttributeError:
                is_frozen = False

        if is_frozen and name not in {"_frozen"}:
            raise RuntimeError("Dependencies are immutable after first access")

        super().__setattr__(name, value)

    @property
    def value(self) -> DepsT | None:
        """Return the wrapped dependency object (None when the container is empty)."""
        return self._value

    @value.setter
    def value(self, value: DepsT) -> None:
        """
        Replace the wrapped dependency object.

        Raises:
            RuntimeError: If the container is already frozen.
        """
        if self._frozen:
            raise RuntimeError("Dependencies are immutable after first access")
        self._value = value

    def _freeze(self) -> None:
        """Mark the container as frozen; calling it again has no further effect."""
        if not self._frozen:
            self._frozen = True

    def __getattr__(self, name: str) -> object:
        """
        Proxy an attribute lookup to the wrapped dependency object, freezing the container.

        The container is frozen before the lookup, so it stays frozen even if the wrapped
        object does not have `name`.

        Raises:
            AttributeError: If the container is empty or the wrapped object lacks `name`.
        """
        value = object.__getattribute__(self, "_value")
        if value is None:
            raise AttributeError(name)
        self._freeze()
        return getattr(value, name)

    def __contains__(self, key: str) -> bool:
        """Return True if the wrapped object has attribute `key`; always False for an empty container."""
        value = object.__getattribute__(self, "_value")
        return hasattr(value, key) if value is not None else False


class AgentRunContext(BaseModel, Generic[DepsT]):
    """
    Per-run state: external dependencies handed to tools and the token usage accumulated so far.
    """

    deps: AgentDependencies[DepsT] = Field(default_factory=lambda: AgentDependencies(value=None))
    """External dependencies available during the run (an empty container by default)."""
    usage: Usage = Field(default_factory=Usage)
    """Token usage accumulated over the run."""


class AgentResultStreaming(AsyncIterator[str | ToolCall | ToolCallResult]):
    """
    An async iterator that will collect all yielded items by LLM.generate_streaming(). This object is returned
    by `run_streaming`. It can be used in an `async for` loop to process items as they arrive. After the loop completes,
    all items are available under the same names as in AgentResult class.
    """

    def __init__(
        self, generator: AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage]
    ):
        self._generator = generator
        self.content: str = ""
        self.tool_calls: list[ToolCallResult] | None = None
        self.metadata: dict = {}
        self.history: ChatFormat
        self.usage: Usage = Usage()

    def __aiter__(self) -> AsyncIterator[str | ToolCall | ToolCallResult]:
        return self

    async def __anext__(self) -> str | ToolCall | ToolCallResult:
        try:
            item = await self._generator.__anext__()
            match item:
                case str():
                    self.content += item
                case ToolCall():
                    pass
                case ToolCallResult():
                    if self.tool_calls is None:
                        self.tool_calls = []
                    self.tool_calls.append(item)
                case BasePrompt():
                    item.add_assistant_message(self.content)
                    self.history = item.chat
                    item = await self._generator.__anext__()
                    item = cast(SimpleNamespace, item)
                    item.result = {
                        "content": self.content,
                        "metadata": self.metadata,
                        "tool_calls": self.tool_calls,
                    }
                    raise StopAsyncIteration
                case Usage():
                    self.usage = item
                    return await self.__anext__()
                case _:
                    raise ValueError(f"Unexpected item: {item}")
            return item
        except StopAsyncIteration:
            raise


class Agent(
    ConfigurableComponent[AgentOptions[LLMClientOptionsT]], Generic[LLMClientOptionsT, PromptInputT, PromptOutputT]
):
    """
    Agent class that orchestrates the LLM and the prompt, and can call tools.

    Current implementation is highly experimental, and the API is subject to change.
    """

    options_cls: type[AgentOptions] = AgentOptions
    default_module: ClassVar[ModuleType | None] = agents
    configuration_key: ClassVar[str] = "agent"

    def __init__(
        self,
        llm: LLM[LLMClientOptionsT],
        prompt: str | type[Prompt[PromptInputT, PromptOutputT]] | Prompt[PromptInputT, PromptOutputT] | None = None,
        *,
        history: ChatFormat | None = None,
        keep_history: bool = False,
        tools: list[Callable] | None = None,
        mcp_servers: list[MCPServer] | None = None,
        default_options: AgentOptions[LLMClientOptionsT] | None = None,
    ) -> None:
        """
        Initialize the agent instance.

        Args:
            llm: The LLM to run the agent.
            prompt: The prompt for the agent. Can be:
                - str: A string prompt that will be used as system message when combined with string input,
                    or as the user message when no input is provided during run().
                - type[Prompt]: A structured prompt class that will be instantiated with the input.
                - Prompt: Already instantiated prompt instance
                - None: No predefined prompt. The input provided to run() will be used as the complete prompt.
            history: The history of the agent.
            keep_history: Whether to keep the history of the agent.
            tools: The tools available to the agent.
            mcp_servers: The MCP servers available to the agent.
            default_options: The default options for the agent run.
        """
        super().__init__(default_options)
        self.id = uuid.uuid4().hex[:8]
        self.llm = llm
        self.prompt = prompt
        self.tools = [Tool.from_callable(tool) for tool in tools or []]
        self.mcp_servers = mcp_servers or []
        self.history = history or []
        self.keep_history = keep_history

    @overload
    async def run(
        self: "Agent[LLMClientOptionsT, None, PromptOutputT]",
        input: str | None = None,
        options: AgentOptions[LLMClientOptionsT] | None = None,
        context: AgentRunContext | None = None,
        tool_choice: ToolChoice | None = None,
    ) -> AgentResult[PromptOutputT]: ...

    @overload
    async def run(
        self: "Agent[LLMClientOptionsT, PromptInputT, PromptOutputT]",
        input: PromptInputT,
        options: AgentOptions[LLMClientOptionsT] | None = None,
        context: AgentRunContext | None = None,
        tool_choice: ToolChoice | None = None,
    ) -> AgentResult[PromptOutputT]: ...

    async def run(
        self,
        input: str | PromptInputT | None = None,
        options: AgentOptions[LLMClientOptionsT] | None = None,
        context: AgentRunContext | None = None,
        tool_choice: ToolChoice | None = None,
    ) -> AgentResult[PromptOutputT]:
        """
        Run the agent. The method is experimental, inputs and outputs may change in the future.

        Args:
            input: The input for the agent run. Can be:
                - str: A string input that will be used as user message.
                - PromptInputT: Structured input for use with structured prompt classes.
                - None: No input. Only valid when a string prompt was provided during initialization.
            options: The options for the agent run.
            context: The context for the agent run.
            tool_choice: Parameter that allows to control what tool is used at first call. Can be one of:
                - "auto": let model decide if tool call is needed
                - "none": do not call tool
                - "required": enforce tool usage (model decides which one)
                - Callable: one of provided tools

        Returns:
            The result of the agent run.

        Raises:
            AgentToolDuplicateError: If the tool names are duplicated.
            AgentToolNotSupportedError: If the selected tool type is not supported.
            AgentToolNotAvailableError: If the selected tool is not available.
            AgentToolExecutionError: If a tool raises while being executed.
            AgentInvalidPromptInputError: If the prompt/input combination is invalid.
            AgentMaxTurnsExceededError: If the maximum number of turns is exceeded.
            AgentMaxTokensExceededError: If a configured token limit is exceeded.
            AgentNextPromptOverLimitError: If the next prompt would not fit in the remaining total token budget.
        """
        if context is None:
            context = AgentRunContext()

        input = cast(PromptInputT, input)
        merged_options = (self.default_options | options) if options else self.default_options
        llm_options = merged_options.llm_options or self.llm.default_options

        prompt_with_history = self._get_prompt_with_history(input)
        tools_mapping = await self._get_all_tools()
        tool_calls = []

        turn_count = 0
        max_turns = merged_options.max_turns
        max_turns = 10 if max_turns is NOT_GIVEN else max_turns
        with trace(input=input, options=merged_options) as outputs:
            # Only None means "no turn limit"; every int (including 0) is a hard cap.
            while max_turns is None or turn_count < max_turns:
                self._check_token_limits(merged_options, context.usage, prompt_with_history, self.llm)
                response = cast(
                    LLMResponseWithMetadata[PromptOutputT],
                    await self.llm.generate_with_metadata(
                        prompt=prompt_with_history,
                        tools=[tool.to_function_schema() for tool in tools_mapping.values()],
                        tool_choice=tool_choice if tool_choice and turn_count == 0 else None,
                        options=self._get_llm_options(llm_options, merged_options, context.usage),
                    ),
                )
                context.usage += response.usage or Usage()

                if not response.tool_calls:
                    break

                for tool_call in response.tool_calls:
                    result = await self._execute_tool(tool_call=tool_call, tools_mapping=tools_mapping, context=context)
                    tool_calls.append(result)

                    prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__)

                turn_count += 1
            else:
                raise AgentMaxTurnsExceededError(cast(int, max_turns))

            outputs.result = {
                "content": response.content,
                "metadata": response.metadata,
                "tool_calls": tool_calls or None,
            }

            prompt_with_history = prompt_with_history.add_assistant_message(response.content)

            if self.keep_history:
                self.history = prompt_with_history.chat

            return AgentResult(
                content=response.content,
                metadata=response.metadata,
                tool_calls=tool_calls or None,
                history=prompt_with_history.chat,
                usage=context.usage,
            )

    @overload
    def run_streaming(
        self: "Agent[LLMClientOptionsT, None, PromptOutputT]",
        input: str | None = None,
        options: AgentOptions[LLMClientOptionsT] | None = None,
        context: AgentRunContext | None = None,
        tool_choice: ToolChoice | None = None,
    ) -> AgentResultStreaming: ...

    @overload
    def run_streaming(
        self: "Agent[LLMClientOptionsT, PromptInputT, PromptOutputT]",
        input: PromptInputT,
        options: AgentOptions[LLMClientOptionsT] | None = None,
        context: AgentRunContext | None = None,
        tool_choice: ToolChoice | None = None,
    ) -> AgentResultStreaming: ...

    def run_streaming(
        self,
        input: str | PromptInputT | None = None,
        options: AgentOptions[LLMClientOptionsT] | None = None,
        context: AgentRunContext | None = None,
        tool_choice: ToolChoice | None = None,
    ) -> AgentResultStreaming:
        """
        This method returns an `AgentResultStreaming` object that can be asynchronously
        iterated over. After the loop completes, all items are available under the same names as in AgentResult class.

        Args:
            input: The input for the agent run.
            options: The options for the agent run.
            context: The context for the agent run.
            tool_choice: Parameter that allows to control what tool is used at first call. Can be one of:
                - "auto": let model decide if tool call is needed
                - "none": do not call tool
                - "required": enforce tool usage (model decides which one)
                - Callable: one of provided tools

        Returns:
            An `AgentResultStreaming` object for iteration and collection.

        Raises:
            Nothing is executed until iteration starts, so these are raised while iterating:
            AgentToolDuplicateError: If the tool names are duplicated.
            AgentToolNotSupportedError: If the selected tool type is not supported.
            AgentToolNotAvailableError: If the selected tool is not available.
            AgentToolExecutionError: If a tool raises while being executed.
            AgentInvalidPromptInputError: If the prompt/input combination is invalid.
            AgentMaxTurnsExceededError: If the maximum number of turns is exceeded.
            AgentMaxTokensExceededError: If a configured token limit is exceeded.
            AgentNextPromptOverLimitError: If the next prompt would not fit in the remaining total token budget.
        """
        generator = self._stream_internal(input, options, context, tool_choice)
        return AgentResultStreaming(generator)

    async def _stream_internal(
        self,
        input: str | PromptInputT | None = None,
        options: AgentOptions[LLMClientOptionsT] | None = None,
        context: AgentRunContext | None = None,
        tool_choice: ToolChoice | None = None,
    ) -> AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage]:
        """
        Async generator backing `run_streaming`.

        Yields every chunk of the LLM stream, plus a `ToolCallResult` right after each `ToolCall` chunk is
        executed. When the loop ends it yields, in this order, the accumulated `Usage`, the final prompt and the
        trace outputs namespace; `AgentResultStreaming` depends on that order. `self.history` is updated only
        after the prompt has been yielded and the consumer resumes the generator.
        """
        if context is None:
            context = AgentRunContext()

        input = cast(PromptInputT, input)
        merged_options = (self.default_options | options) if options else self.default_options
        llm_options = merged_options.llm_options or self.llm.default_options

        prompt_with_history = self._get_prompt_with_history(input)
        tools_mapping = await self._get_all_tools()
        turn_count = 0
        max_turns = merged_options.max_turns
        max_turns = 10 if max_turns is NOT_GIVEN else max_turns
        with trace(input=input, options=merged_options) as outputs:
            # Only None means "no turn limit"; every int (including 0) is a hard cap.
            while max_turns is None or turn_count < max_turns:
                returned_tool_call = False
                self._check_token_limits(merged_options, context.usage, prompt_with_history, self.llm)
                streaming_result = self.llm.generate_streaming(
                    prompt=prompt_with_history,
                    tools=[tool.to_function_schema() for tool in tools_mapping.values()],
                    tool_choice=tool_choice if tool_choice and turn_count == 0 else None,
                    options=self._get_llm_options(llm_options, merged_options, context.usage),
                )
                async for chunk in streaming_result:
                    yield chunk

                    if isinstance(chunk, ToolCall):
                        result = await self._execute_tool(tool_call=chunk, tools_mapping=tools_mapping, context=context)
                        yield result
                        prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__)
                        returned_tool_call = True
                turn_count += 1
                if streaming_result.usage:
                    context.usage += streaming_result.usage

                if not returned_tool_call:
                    break
            else:
                raise AgentMaxTurnsExceededError(cast(int, max_turns))

            yield context.usage
            yield prompt_with_history
            if self.keep_history:
                self.history = prompt_with_history.chat
            yield outputs

    @staticmethod
    def _check_token_limits(
        options: AgentOptions[LLMClientOptionsT], usage: Usage, prompt: BasePrompt, llm: LLM[LLMClientOptionsT]
    ) -> None:
        """
        Validate the configured token limits before the next LLM call.

        When a prompt or total limit is set, the next prompt is counted with `llm.count_tokens` and must fit
        in what is left of that limit. Usage accumulated so far must also not exceed any configured limit.

        Raises:
            AgentMaxTokensExceededError: If the next prompt exceeds the remaining prompt budget, or usage so far
                exceeds a total, prompt or completion limit.
            AgentNextPromptOverLimitError: If the next prompt exceeds the remaining total budget.
        """
        if options.max_prompt_tokens or options.max_total_tokens:
            next_prompt_tokens = llm.count_tokens(prompt)
            # NOTE(review): unlike the total-limit branch below, this reports `next_prompt_tokens` through
            # AgentMaxTokensExceededError rather than AgentNextPromptOverLimitError — confirm this is intended.
            if options.max_prompt_tokens and next_prompt_tokens > options.max_prompt_tokens - usage.prompt_tokens:
                raise AgentMaxTokensExceededError("prompt", options.max_prompt_tokens, next_prompt_tokens)
            if options.max_total_tokens and next_prompt_tokens > options.max_total_tokens - usage.total_tokens:
                raise AgentNextPromptOverLimitError(
                    "total", options.max_total_tokens, usage.total_tokens, next_prompt_tokens
                )

        if options.max_total_tokens and usage.total_tokens > options.max_total_tokens:
            raise AgentMaxTokensExceededError("total", options.max_total_tokens, usage.total_tokens)
        if options.max_prompt_tokens and usage.prompt_tokens > options.max_prompt_tokens:
            raise AgentMaxTokensExceededError("prompt", options.max_prompt_tokens, usage.prompt_tokens)
        if options.max_completion_tokens and usage.completion_tokens > options.max_completion_tokens:
            raise AgentMaxTokensExceededError("completion", options.max_completion_tokens, usage.completion_tokens)

    @staticmethod
    def _get_llm_options(
        llm_options: LLMClientOptionsT, options: AgentOptions[LLMClientOptionsT], usage: Usage
    ) -> LLMClientOptionsT:
        """
        Return the LLM options for the next call, capping `max_tokens` by the agent's token limits.

        If no token limit is configured, `llm_options` is returned unchanged. Otherwise a shallow copy is returned
        with `max_tokens` set to the smallest configured limit minus the total tokens used so far. The passed
        object is never mutated.
        """
        actual_limits: list[int] = [
            limit
            for limit in (options.max_total_tokens, options.max_prompt_tokens, options.max_completion_tokens)
            if isinstance(limit, int)
        ]

        if not actual_limits:
            return llm_options

        # Copy rather than mutate: `llm_options` is usually the LLM's (or the caller's) shared options object,
        # and writing `max_tokens` into it would leak the cap into later, unrelated runs.
        # NOTE(review): total usage is subtracted even when the smallest limit is the completion or prompt
        # limit, which can yield a small or negative cap — confirm this is the intended budget.
        capped_options = copy(llm_options)
        capped_options.max_tokens = min(actual_limits) - usage.total_tokens
        return capped_options

    def _get_prompt_with_history(self, input: PromptInputT) -> SimplePrompt | Prompt[PromptInputT, PromptOutputT]:
        """
        Build the prompt for this run from `self.prompt`, a copy of the stored history and `input`.

        - Prompt class: instantiated with `input` and the history copy; with `keep_history` the instance
          replaces `self.prompt` so later runs reuse it.
        - Prompt instance: `input` is appended to it as a user message (the instance is mutated).
        - str prompt + str input: the prompt becomes the system message (replacing the first existing one,
          or inserted first) and the input is appended as the user message.
        - str prompt + no input: the prompt itself is sent as the user message.
        - no prompt + str input: the input is sent as the user message.

        Raises:
            AgentInvalidPromptInputError: For any other prompt/input combination.
        """
        curr_history = deepcopy(self.history)
        if isinstance(self.prompt, type) and issubclass(self.prompt, Prompt):
            if self.keep_history:
                self.prompt = self.prompt(input, curr_history)
                return self.prompt
            else:
                return self.prompt(input, curr_history)

        if isinstance(self.prompt, Prompt):
            self.prompt.add_user_message(input)
            return self.prompt

        if isinstance(self.prompt, str) and isinstance(input, str):
            system_prompt = {"role": "system", "content": self.prompt}
            if len(curr_history) == 0:
                curr_history.append(system_prompt)
            else:
                system_idx = next((i for i, msg in enumerate(curr_history) if msg["role"] == "system"), -1)
                if system_idx == -1:
                    curr_history.insert(0, system_prompt)
                else:
                    curr_history[system_idx] = system_prompt
            incoming_user_prompt = input
        elif isinstance(self.prompt, str) and input is None:
            incoming_user_prompt = self.prompt
        elif isinstance(input, str):
            incoming_user_prompt = input
        else:
            raise AgentInvalidPromptInputError(self.prompt, input)

        curr_history.append({"role": "user", "content": incoming_user_prompt})

        return SimplePrompt(curr_history)

    async def _get_all_tools(self) -> dict[str, Tool]:
        """
        Collect the agent's own tools plus the tools of every MCP server (fetched concurrently), keyed by name.

        Raises:
            AgentToolDuplicateError: If two tools share the same name.
        """
        tools_mapping = {}
        all_tools = list(self.tools)

        server_tools = await asyncio.gather(*[get_tools(server) for server in self.mcp_servers])
        for tools in server_tools:
            all_tools.extend(tools)

        for tool in all_tools:
            if tool.name in tools_mapping:
                raise AgentToolDuplicateError(tool.name)
            tools_mapping[tool.name] = tool

        return tools_mapping

    async def _execute_tool(
        self, tool_call: ToolCall, tools_mapping: dict[str, Tool], context: AgentRunContext | None = None
    ) -> ToolCallResult:
        """
        Execute a single tool call inside a trace span and wrap its output.

        When the tool declares a `context_var_name`, `context` is passed under that argument name.
        Coroutine tools are awaited; plain callables are called directly.

        Raises:
            AgentToolNotSupportedError: If the call type is not "function".
            AgentToolNotAvailableError: If no tool with that name is registered.
            AgentToolExecutionError: If the tool itself raises (the original exception is chained).
        """
        if tool_call.type != "function":
            raise AgentToolNotSupportedError(tool_call.type)

        if tool_call.name not in tools_mapping:
            raise AgentToolNotAvailableError(tool_call.name)

        tool = tools_mapping[tool_call.name]

        with trace(agent_id=self.id, tool_name=tool_call.name, tool_arguments=tool_call.arguments) as outputs:
            try:
                # Copy so injecting the context does not alter the recorded tool-call arguments.
                call_args = tool_call.arguments.copy()
                if tool.context_var_name:
                    call_args[tool.context_var_name] = context

                tool_output = (
                    await tool.on_tool_call(**call_args)
                    if iscoroutinefunction(tool.on_tool_call)
                    else tool.on_tool_call(**call_args)
                )

                outputs.result = {
                    "tool_output": tool_output,
                    "tool_call_id": tool_call.id,
                }
            except Exception as e:
                outputs.result = {
                    "error": str(e),
                    "tool_call_id": tool_call.id,
                }
                raise AgentToolExecutionError(tool_call.name, e) from e

        return ToolCallResult(
            id=tool_call.id,
            name=tool_call.name,
            arguments=tool_call.arguments,
            result=tool_output,
        )

    @requires_dependencies(["a2a.types"], "a2a")
    async def get_agent_card(
        self,
        name: str,
        description: str,
        version: str = "0.0.0",
        host: str = "127.0.0.1",
        port: int = 8000,
        protocol: str = "http",
        default_input_modes: list[str] | None = None,
        default_output_modes: list[str] | None = None,
        capabilities: "AgentCapabilities | None" = None,
        skills: list["AgentSkill"] | None = None,
    ) -> "AgentCard":
        """
        Create an AgentCard that encapsulates metadata about the agent,
        such as its name, version, description, network location, supported input/output modes,
        capabilities, and skills.

        Args:
            name: Human-readable name of the agent.
            description: A brief description of the agent.
            version: Version string of the agent. Defaults to "0.0.0".
            host: Hostname or IP where the agent will be served. Defaults to "127.0.0.1".
            port: Port number on which the agent listens. Defaults to 8000.
            protocol: URL scheme (e.g. "http" or "https"). Defaults to "http".
            default_input_modes: List of input content modes supported by the agent. Defaults to ["text"].
            default_output_modes: List of output content modes supported. Defaults to ["text"].
            capabilities: Agent capabilities; if None, defaults to empty capabilities.
            skills: List of AgentSkill objects representing the agent's skills.
                If None (or empty), skills are extracted from the agent's tools, including MCP server tools.

        Returns:
            An A2A-compliant agent descriptor including URL and capabilities.
        """
        return AgentCard(
            name=name,
            version=version,
            description=description,
            url=f"{protocol}://{host}:{port}",
            defaultInputModes=default_input_modes or ["text"],
            defaultOutputModes=default_output_modes or ["text"],
            skills=skills or await self._extract_agent_skills(),
            capabilities=capabilities or AgentCapabilities(),
        )

    async def _extract_agent_skills(self) -> list["AgentSkill"]:
        """
        Build one skill (name, id, description, tags) per available tool, including MCP server tools.
        """
        all_tools = await self._get_all_tools()
        return [
            AgentSkill(
                name=tool.name.replace("_", " ").title(),
                id=tool.name,
                description=f"{tool.description}\n\nParameters:\n{tool.parameters}",
                tags=[],
            )
            for tool in all_tools.values()
        ]

    @requires_dependencies("pydantic_ai")
    def to_pydantic_ai(self) -> "PydanticAIAgent":
        """
        Convert ragbits agent instance into a `pydantic_ai.Agent` representation.

        Returns:
            PydanticAIAgent: The equivalent Pydantic-based agent configuration.

        Raises:
            ValueError: If `prompt` is missing/empty, or if a non-string prompt has no system prompt.
        """
        mcp_servers: list[mcp.MCPServerStdio | mcp.MCPServerHTTP] = []

        if not self.prompt:
            raise ValueError("Prompt is required but was None.")

        if isinstance(self.prompt, str):
            system_prompt = self.prompt
        else:
            if not self.prompt.system_prompt:
                raise ValueError("System prompt is required but was None.")
            system_prompt = self.prompt.system_prompt

        for mcp_server in self.mcp_servers:
            if isinstance(mcp_server, MCPServerStdio):
                mcp_servers.append(
                    mcp.MCPServerStdio(
                        command=mcp_server.params.command, args=mcp_server.params.args, env=mcp_server.params.env
                    )
                )
            elif isinstance(mcp_server, MCPServerStreamableHttp):
                params = mcp_server.params
                # Forward optional keys only when present (they appear to be optional in the params dict —
                # verify), so a missing key no longer raises KeyError and pydantic_ai's own defaults apply.
                http_kwargs: dict = {"url": params["url"]}
                if "headers" in params:
                    http_kwargs["headers"] = params["headers"]
                for timeout_key in ("timeout", "sse_read_timeout"):
                    if timeout_key in params:
                        timeout_value = params[timeout_key]
                        http_kwargs[timeout_key] = (
                            timeout_value.total_seconds() if isinstance(timeout_value, timedelta) else timeout_value
                        )

                mcp_servers.append(mcp.MCPServerHTTP(**http_kwargs))
        return PydanticAIAgent(
            model=self.llm.model_name,
            system_prompt=system_prompt,
            tools=[tool.to_pydantic_ai() for tool in self.tools],
            mcp_servers=mcp_servers,
        )

    @classmethod
    @requires_dependencies("pydantic_ai")
    def from_pydantic_ai(cls, pydantic_ai_agent: "PydanticAIAgent") -> Self:
        """
        Construct an agent instance from a `pydantic_ai.Agent` representation.

        Args:
            pydantic_ai_agent: A Pydantic-based agent configuration.

        Returns:
            An instance of the agent class initialized from the Pydantic representation.

        Raises:
            ValueError: If the pydantic_ai agent has no model configured.
        """
        mcp_servers: list[MCPServerStdio | MCPServerStreamableHttp] = []
        for mcp_server in pydantic_ai_agent._mcp_servers:
            if isinstance(mcp_server, mcp.MCPServerStdio):
                mcp_servers.append(
                    MCPServerStdio(
                        params={
                            "command": mcp_server.command,
                            "args": list(mcp_server.args),
                            "env": mcp_server.env or {},
                        }
                    )
                )
            elif isinstance(mcp_server, mcp.MCPServerHTTP):
                headers = mcp_server.headers or {}

                mcp_servers.append(
                    MCPServerStreamableHttp(
                        params={
                            "url": mcp_server.url,
                            "headers": {str(k): str(v) for k, v in headers.items()},
                            "sse_read_timeout": mcp_server.sse_read_timeout,
                            "timeout": mcp_server.timeout,
                        }
                    )
                )

        if not pydantic_ai_agent.model:
            raise ValueError("Missing LLM in `pydantic_ai.Agent` instance")
        elif isinstance(pydantic_ai_agent.model, str):
            model_name = pydantic_ai_agent.model
        else:
            model_name = pydantic_ai_agent.model.model_name

        return cls(
            llm=LiteLLM(model_name=model_name),  # type: ignore[arg-type]
            prompt="\n".join(pydantic_ai_agent._system_prompts),
            tools=[tool.function for tool in pydantic_ai_agent._function_tools.values()],
            mcp_servers=cast(list[MCPServer], mcp_servers),
        )
