"""MCP client integration for LLMling agent."""

from __future__ import annotations

from contextlib import AsyncExitStack, suppress
import shutil
from typing import TYPE_CHECKING, Any, Self, TextIO

from llmling_agent.log import get_logger


if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable
    from types import TracebackType

    import mcp
    from mcp import ClientSession
    from mcp.client.session import RequestContext
    from mcp.types import Tool, Tool as MCPTool

    from llmling_agent_config.mcp_server import TransportType

logger = get_logger(__name__)


def mcp_tool_to_fn_schema(tool: MCPTool) -> dict[str, Any]:
    """Build an OpenAI-style function schema dict from an MCP tool definition.

    Args:
        tool: The MCP tool whose name, description and input schema are used.

    Returns:
        A dict with ``name``, ``description`` and ``parameters`` keys. A missing
        or empty description is replaced by a placeholder string.
    """
    schema: dict[str, Any] = {"name": tool.name}
    schema["description"] = (
        tool.description if tool.description else "No description provided"
    )
    # The MCP input schema is already JSON Schema, so it is passed through as-is.
    schema["parameters"] = tool.inputSchema
    return schema


class MCPClient:
    """MCP client for communicating with MCP servers."""

    def __init__(
        self,
        transport_mode: TransportType = "stdio",
        elicitation_callback: Callable[
            [RequestContext, mcp.types.ElicitRequestParams],
            Awaitable[mcp.types.ElicitResult | mcp.types.ErrorData],
        ]
        | None = None,
        sampling_callback: Callable[
            [RequestContext, mcp.types.CreateMessageRequestParams],
            Awaitable[mcp.types.CreateMessageResult | mcp.types.ErrorData],
        ]
        | None = None,
    ):
        self.exit_stack = AsyncExitStack()
        self.session: ClientSession | None = None
        self._available_tools: list[Tool] = []
        self._old_stdout: TextIO | None = None
        self._transport_mode = transport_mode
        self._elicitation_callback = elicitation_callback
        self._sampling_callback = sampling_callback

    async def __aenter__(self) -> Self:
        """Enter context and redirect stdout if in stdio mode."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ):
        """Restore stdout if redirected and cleanup."""
        try:
            await self.cleanup()
        except RuntimeError as e:
            if "exit cancel scope in a different task" in str(e):
                logger.warning("Ignoring known MCP cleanup issue: Task context mismatch")
            else:
                raise
        except Exception:
            logger.exception("Error during MCP client cleanup")
            raise

    async def cleanup(self):
        """Clean up resources."""
        with suppress(RuntimeError) as cm:
            await self.exit_stack.aclose()

        if cm and cm.error and "exit cancel scope in a different task" in str(cm.error):
            logger.warning("Ignoring known MCP cleanup issue: Task context mismatch")
        elif cm and cm.error:
            raise cm.error

    async def _default_elicitation_callback(
        self,
        context: RequestContext,
        params: mcp.types.ElicitRequestParams,
    ) -> mcp.types.ElicitResult | mcp.types.ErrorData:
        """Default elicitation callback that returns not supported."""
        import mcp

        return mcp.types.ErrorData(
            code=mcp.types.INVALID_REQUEST,
            message="Elicitation not supported",
        )

    async def connect(
        self,
        command: str,
        args: list[str],
        env: dict[str, str] | None = None,
        url: str | None = None,
    ):
        """Connect to an MCP server.

        Args:
            command: Command to run (for stdio servers)
            args: Command arguments (for stdio servers)
            env: Optional environment variables
            url: Server URL (for SSE servers)
        """
        from mcp import ClientSession, StdioServerParameters
        from mcp.client.stdio import stdio_client

        if url:
            # SSE connection - just a placeholder for now
            logger.info("SSE servers not yet implemented")
            self.session = None
            return
        command = shutil.which(command) or command
        # Stdio connection
        params = StdioServerParameters(command=command, args=args, env=env)
        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(params))
        stdio, write = stdio_transport

        # Create a wrapper that matches the expected signature
        async def elicitation_wrapper(context, params):
            if self._elicitation_callback:
                return await self._elicitation_callback(context, params)
            return await self._default_elicitation_callback(context, params)

        async def sampling_wrapper(
            context: RequestContext,
            params: mcp.types.CreateMessageRequestParams,
        ) -> mcp.types.CreateMessageResult | mcp.types.ErrorData:
            if self._sampling_callback:
                return await self._sampling_callback(context, params)
            # If no callback provided, let MCP SDK handle with its default
            import mcp

            return mcp.types.ErrorData(
                code=mcp.types.INVALID_REQUEST,
                message="Sampling not supported",
            )

        session = ClientSession(
            stdio,
            write,
            elicitation_callback=elicitation_wrapper,
            sampling_callback=sampling_wrapper,
        )
        self.session = await self.exit_stack.enter_async_context(session)
        assert self.session
        init_result = await self.session.initialize()
        info = init_result.serverInfo
        # Get available tools
        result = await self.session.list_tools()
        self._available_tools = result.tools
        logger.info("Connected to MCP server %s (%s)", info.name, info.version)
        logger.info("Available tools: %s", len(self._available_tools))

    def get_tools(self) -> list[dict]:
        """Get tools in OpenAI function format."""
        return [
            {"type": "function", "function": mcp_tool_to_fn_schema(tool)}
            for tool in self._available_tools
        ]

    async def list_prompts(self) -> mcp.types.ListPromptsResult:
        """Get available prompts from the server."""
        if not self.session:
            msg = "Not connected to MCP server"
            raise RuntimeError(msg)
        return await self.session.list_prompts()

    async def list_resources(self) -> mcp.types.ListResourcesResult:
        """Get available resources from the server."""
        if not self.session:
            msg = "Not connected to MCP server"
            raise RuntimeError(msg)
        return await self.session.list_resources()

    async def get_prompt(self, name: str) -> mcp.types.GetPromptResult:
        """Get a specific prompt's content."""
        if not self.session:
            msg = "Not connected to MCP server"
            raise RuntimeError(msg)
        return await self.session.get_prompt(name)

    def create_tool_callable(self, tool: MCPTool) -> Callable[..., Awaitable[str]]:
        """Create a properly typed callable from MCP tool schema."""
        from py2openai.functionschema import FunctionSchema

        schema = mcp_tool_to_fn_schema(tool)
        fn_schema = FunctionSchema.from_dict(schema)
        sig = fn_schema.to_python_signature()

        async def tool_callable(**kwargs: Any) -> str:
            """Dynamically generated MCP tool wrapper."""
            return await self.call_tool(tool.name, kwargs)

        # Set proper signature and docstring
        tool_callable.__signature__ = sig  # type: ignore
        tool_callable.__annotations__ = fn_schema.get_annotations()
        tool_callable.__name__ = tool.name
        tool_callable.__doc__ = tool.description or "No description provided."
        return tool_callable

    async def call_tool(self, name: str, arguments: dict | None = None) -> str:
        """Call an MCP tool."""
        from mcp.types import TextContent, TextResourceContents

        if not self.session:
            msg = "Not connected to MCP server"
            raise RuntimeError(msg)

        try:
            result = await self.session.call_tool(name, arguments or {})
            if not isinstance(result.content[0], TextResourceContents | TextContent):
                msg = "Tool returned a non-text response"
                raise TypeError(msg)  # noqa: TRY301
            return result.content[0].text
        except Exception as e:
            msg = f"MCP tool call failed: {e}"
            raise RuntimeError(msg) from e
