"""Orchestrates code generation for multiple tools."""

from __future__ import annotations

import contextlib
from dataclasses import dataclass
import time
from typing import TYPE_CHECKING, Any, Literal

from schemez import log
from schemez.code_generation.namespace_callable import NamespaceCallable
from schemez.code_generation.tool_code_generator import ToolCodeGenerator
from schemez.helpers import model_to_python_code


if TYPE_CHECKING:
    from collections.abc import Callable, Sequence

    from fastapi import FastAPI

    from schemez.functionschema import FunctionSchema


# Module-level logger obtained from schemez's logging helper.
logger = log.get_logger(__name__)

# Argument style of generated functions: a single Pydantic input model per tool
# ("model") or one parameter per schema property ("explicit").
ArgsFormat = Literal["model", "explicit"]
# Signature-only stubs for LLM consumption ("stubs") or working client code
# ("implementation").
OutputType = Literal["stubs", "implementation"]
# Transport used by generated implementations: direct HTTP requests or a
# FastMCP client.
ClientType = Literal["http", "fastmcp"]


# Module headers prepended to generated code. Each begins with a docstring and
# ``from __future__ import annotations``, which Python only accepts at the start
# of a module, so at most one header may appear, at the very top of the output.

# Header for generated HTTP clients (httpx-based).
HTTP_IMPORTS = '''"""Generated HTTP tool client code."""

from __future__ import annotations

from pydantic import BaseModel, Field
from typing import Any
import httpx

'''

# Header for generated FastMCP clients.
FASTMCP_IMPORTS = '''"""Generated FastMCP tool client code."""

from __future__ import annotations

from pydantic import BaseModel, Field
from typing import Any
from fastmcp import Client

'''

# Header for signature-only stubs; stubs need no transport library.
STUB_IMPORTS = '''"""Generated tool stubs for LLM consumption."""

from __future__ import annotations

from pydantic import BaseModel

'''


def clean_generated_code(code: str) -> str:
    """Remove codegen headers and ``__future__`` imports from generated code.

    Generated model snippets are concatenated into one module under a shared
    header, so each snippet's own ``from __future__ import annotations`` line
    (only legal at the start of a module) and its datamodel-codegen comment
    header are dropped.

    Args:
        code: Generated Python source.

    Returns:
        ``code`` with the matching lines removed; every other line keeps its
        original content and order.
    """
    # Lines containing any of these substrings are always dropped.
    skip_substrings = (
        "from __future__ import annotations",
        "# generated by datamodel-codegen",
        "# filename:",
        "# timestamp:",
    )
    # Header comments may pad the text after '#' (e.g. "#   filename:  x.json"),
    # which the fixed substrings above miss, so comment lines are also compared
    # with the '#' marker and surrounding whitespace stripped.
    header_prefixes = ("generated by datamodel-codegen", "filename:", "timestamp:")

    def is_unwanted(line: str) -> bool:
        if any(pattern in line for pattern in skip_substrings):
            return True
        stripped = line.strip()
        if not stripped.startswith("#"):
            return False
        return stripped.lstrip("#").strip().startswith(header_prefixes)

    return "\n".join(line for line in code.split("\n") if not is_unwanted(line))


@dataclass
class GeneratedCode:
    """Container for the separately generated pieces of a toolset's code.

    The fragments are kept apart so that callers can assemble whichever
    flavour of module they need through :meth:`get_client_code`.
    """

    models: str
    """Source of the generated Pydantic input models."""

    http_methods: str
    """Client functions that take the tool's input model."""

    clean_methods: str
    """Client functions with one parameter per tool argument."""

    model_stubs: str
    """Body-less functions taking an input model, for LLM prompts."""

    explicit_stubs: str
    """Body-less functions with explicit signatures, for LLM prompts."""

    imports: str = ""
    """Module header placed before implementation code."""

    def get_client_code(
        self,
        args_format: ArgsFormat = "explicit",
        output_type: OutputType = "implementation",
        client_type: ClientType = "http",
    ) -> str:
        """Assemble one module of source from the stored fragments.

        Args:
            args_format: ``"model"`` to use the Pydantic input models or
                ``"explicit"`` for per-argument signatures.
            output_type: ``"implementation"`` for working client code or
                ``"stubs"`` for signature-only stubs.
            client_type: Accepted but not read by this method.

        Returns:
            The non-empty sections for the chosen combination, joined by a
            blank line.

        Raises:
            ValueError: If ``args_format``/``output_type`` is not a known pair.
        """
        if output_type == "implementation" and args_format == "model":
            sections = (self.imports, self.models, self.http_methods)
        elif output_type == "implementation" and args_format == "explicit":
            sections = (self.imports, self.clean_methods)
        elif output_type == "stubs" and args_format == "model":
            # Stubs always use the fixed stub header rather than self.imports.
            sections = (STUB_IMPORTS, self.models, self.model_stubs)
        elif output_type == "stubs" and args_format == "explicit":
            sections = (STUB_IMPORTS, self.explicit_stubs)
        else:
            msg = (
                f"Unknown combination: args_format={args_format}, "
                f"output_type={output_type}"
            )
            raise ValueError(msg)

        return "\n\n".join(section for section in sections if section)


@dataclass
class ToolsetCodeGenerator:
    """Generates code artifacts for multiple tools.

    Produces tool descriptions, an execution namespace of callables, HTTP and
    FastMCP client modules, and signature stubs from one ``ToolCodeGenerator``
    per tool.
    """

    generators: Sequence[ToolCodeGenerator]
    """ToolCodeGenerator instances for each tool."""

    include_docstrings: bool = True
    """Include function docstrings in documentation."""

    @classmethod
    def from_callables(
        cls,
        callables: Sequence[Callable[..., Any]],
        include_docstrings: bool = True,
        exclude_types: list[type] | None = None,
    ) -> ToolsetCodeGenerator:
        """Create a ToolsetCodeGenerator from a sequence of callables.

        Args:
            callables: Functions to expose as tools, in order.
            include_docstrings: Stored as ``include_docstrings``.
            exclude_types: Passed through to ``ToolCodeGenerator.from_callable``.

        Returns:
            A generator holding one ``ToolCodeGenerator`` per callable.
        """
        generators = [
            ToolCodeGenerator.from_callable(func, exclude_types=exclude_types)
            for func in callables
        ]
        return cls(generators, include_docstrings)

    @classmethod
    def from_schemas(
        cls,
        schemas: Sequence[FunctionSchema],
        include_docstrings: bool = True,
    ) -> ToolsetCodeGenerator:
        """Create a ToolsetCodeGenerator from schemas only (no execution capability).

        Args:
            schemas: One function schema per tool, in order.
            include_docstrings: Stored as ``include_docstrings``.
        """
        generators = [ToolCodeGenerator.from_schema(schema) for schema in schemas]
        return cls(generators, include_docstrings)

    def generate_tool_description(self) -> str:
        """Describe the available tool functions.

        Returns:
            ``"Available functions:"`` followed by each generator's function
            definition on its own line, or a fixed notice when there are no
            tools.
        """
        if not self.generators:
            return "Execute Python code (no tools available)"

        parts = ["Available functions:"]
        parts.extend(
            generator.get_function_definition(self.include_docstrings)
            for generator in self.generators
        )
        return "\n".join(parts)

    def generate_execution_namespace(self) -> dict[str, Any]:
        """Build Python namespace with tool functions and generated models.

        Each tool is bound under its name as a ``NamespaceCallable``. The
        generated return models are then executed into the same namespace on a
        best-effort basis: a failure is logged and the namespace is still
        returned without them.

        Raises:
            ValueError: If any generator lacks a callable (raised by
                ``NamespaceCallable.from_generator``).
        """
        namespace: dict[str, Any] = {"__builtins__": __builtins__, "_result": None}

        # Add tool functions - all generators must have callables for execution
        for generator in self.generators:
            namespace[generator.name] = NamespaceCallable.from_generator(generator)

        if models_code := self.generate_return_models():
            # NOTE(review): this execs code generated from the tool schemas;
            # it is only as trustworthy as those schemas.
            try:
                exec(models_code, namespace)
            except Exception as exc:  # best-effort: missing models must not block tools
                logger.warning(
                    "Failed to define return models in execution namespace: %s", exc
                )

        return namespace

    def generate_return_models(self) -> str:
        """Generate Pydantic models for tool return types.

        Returns:
            The non-empty per-tool model sources joined by a blank line, or
            ``""`` when no tool produced one.
        """
        return "\n\n".join(
            code for gen in self.generators if (code := gen.generate_return_model())
        )

    def get_fastmcp_client_code(self, server_configs: dict[str, str]) -> str:
        """Generate FastMCP client code with connection management.

        The emitted code defines ``SERVERS`` (the given mapping), a client cache,
        ``get_mcp_client``, ``call_mcp_tool`` and ``close_all_mcp_connections``.

        Args:
            server_configs: Dict mapping server names to their transport URLs
                          e.g., {"github": "sse://localhost:8080/github"}

        Returns:
            FastMCP client code as string
        """
        server_config_code = "SERVERS = " + repr(server_configs)

        return f'''{server_config_code}

# Connection management
_clients: dict[str, Client] = {{}}

async def get_mcp_client(server_name: str) -> Client:
    """Get or create FastMCP client for server."""
    if server_name not in _clients:
        if server_name not in SERVERS:
            raise ValueError(f"Unknown MCP server: {{server_name}}")

        transport_url = SERVERS[server_name]
        client = Client(transport_url)
        _clients[server_name] = client

    return _clients[server_name]

async def call_mcp_tool(
    server_name: str, tool_name: str, arguments: dict[str, Any]
) -> Any:
    """Call a FastMCP tool and return the result."""
    client = await get_mcp_client(server_name)

    async with client:
        result = await client.call_tool(tool_name, arguments)

        # FastMCP returns CallToolResult with structured content handling
        if result.data is not None:
            return result.data  # Parsed structured content
        elif result.content:
            # Extract text from content blocks
            text_parts = []
            for content_block in result.content:
                if hasattr(content_block, 'text'):
                    text_parts.append(content_block.text)
            return "\\n".join(text_parts) if text_parts else ""

        return ""

async def close_all_mcp_connections():
    """Close all MCP client connections."""
    for server_name, client in _clients.items():
        try:
            await client.close()
        except Exception as e:
            print(f"Error closing connection to {{server_name}}: {{e}}")
    _clients.clear()
'''

    @staticmethod
    def _docstring_for(generator: ToolCodeGenerator) -> str:
        """Return the tool's description, escaped for a ``\"\"\"`` string literal.

        Falls back to ``"Call <name> tool"`` when the schema has no description.
        """
        text = generator.schema.description or f"Call {generator.name} tool"
        # Raw backslashes would be read as escape sequences, and a triple quote
        # or a trailing quote would terminate the generated docstring early.
        text = text.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')
        if text.endswith('"'):
            text = text[:-1] + '\\"'
        return text

    @staticmethod
    def _build_input_model(
        generator: ToolCodeGenerator,
    ) -> tuple[str | None, str | None]:
        """Generate the Pydantic input model for one tool's parameters.

        Returns:
            ``(class_name, model_code)``. ``class_name`` is ``None`` when the
            parameter schema has no properties or generation raised
            ``ValueError``/``TypeError``/``AttributeError``. ``model_code`` is
            ``None`` whenever no model source was produced.
        """
        try:
            params_schema = generator.schema.parameters
            if not params_schema.get("properties"):
                return None, None
            # snake_case tool name -> PascalCase class name + "Input" suffix.
            words = [word.title() for word in generator.name.split("_")]
            class_name = f"{''.join(words)}Input"
            model_code = model_to_python_code(params_schema, class_name=class_name)
            cleaned = clean_generated_code(model_code) if model_code else None
        except (ValueError, TypeError, AttributeError):
            return None, None
        return class_name, cleaned

    @staticmethod
    def _render_model_method(
        generator: ToolCodeGenerator,
        input_class_name: str | None,
        doc: str,
        base_url: str,
        path_prefix: str,
        client_type: ClientType,
    ) -> str:
        """Render a client function taking the tool's input model.

        When there is no input model, a zero-argument function is rendered
        instead.
        """
        name = generator.name
        if input_class_name:
            if client_type == "fastmcp":
                return f'''
async def {name}(input: {input_class_name}) -> Any:
    """{doc}

    Args:
        input: Function parameters

    Returns:
        Result from the MCP server
    """
    result = await call_mcp_tool("{base_url}", "{name}", input.model_dump())
    return result
'''
            return f'''
async def {name}(input: {input_class_name}) -> str:
    """{doc}

    Args:
        input: Function parameters

    Returns:
        String response from the tool server
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(
            "{base_url}{path_prefix}/{name}",
            params=input.model_dump() if hasattr(input, 'model_dump') else {{}},
            timeout=30.0
        )
        response.raise_for_status()
        return response.text
'''
        if client_type == "fastmcp":
            return f'''
async def {name}() -> Any:
    """{doc}

    Returns:
        Result from the MCP server
    """
    result = await call_mcp_tool("{base_url}", "{name}", {{}})
    return result
'''
        return f'''
async def {name}() -> str:
    """{doc}

    Returns:
        String response from the tool server
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(
            "{base_url}{path_prefix}/{name}",
            timeout=30.0
        )
        response.raise_for_status()
        return response.text
'''

    @staticmethod
    def _render_clean_method(
        generator: ToolCodeGenerator,
        signature_str: str,
        doc: str,
        base_url: str,
        path_prefix: str,
        client_type: ClientType,
    ) -> str:
        """Render a client function with the tool's natural signature.

        Arguments whose value is ``None`` are dropped before the call.
        """
        param_names = list(generator.schema.parameters.get("properties", {}).keys())
        # Dict-literal body mapping each parameter name to the same-named local.
        params_items = ", ".join(f'"{pname}": {pname}' for pname in param_names)

        if client_type == "fastmcp":
            return f'''
async def {signature_str}:
    """{doc}"""
    # Build parameters dict
    params = {{{params_items}}}
    # Remove None values
    clean_params = {{k: v for k, v in params.items() if v is not None}}

    result = await call_mcp_tool("{base_url}", "{generator.name}", clean_params)
    return result
'''
        return f'''
async def {signature_str}:
    """{doc}"""
    # Build parameters dict
    params = {{{params_items}}}
    # Remove None values
    clean_params = {{k: v for k, v in params.items() if v is not None}}

    async with httpx.AsyncClient() as client:
        response = await client.get(
            "{base_url}{path_prefix}/{generator.name}",
            params=clean_params,
            timeout=30.0
        )
        response.raise_for_status()
        # Parse JSON response and return the result
        result = response.json()
        return result.get("result", response.text)
'''

    def generate_structured_code(
        self,
        base_url: str = "http://localhost:8000",
        path_prefix: str = "/tools",
        client_type: ClientType = "http",
    ) -> GeneratedCode:
        """Generate structured code with all components.

        Tool descriptions are escaped before being placed in generated
        docstrings, so quotes and backslashes cannot break the output.

        Args:
            base_url: Base URL of the tool server (HTTP) or server name (FastMCP)
            path_prefix: Path prefix for routes (HTTP only)
            client_type: Type of client to generate ("http" or "fastmcp")

        Returns:
            GeneratedCode with all components separated
        """
        start_time = time.perf_counter()
        logger.info("Starting structured code generation")

        models_parts: list[str] = []
        http_methods_parts: list[str] = []
        clean_methods_parts: list[str] = []
        model_stubs_parts: list[str] = []
        explicit_stubs_parts: list[str] = []

        for generator in self.generators:
            input_class_name, model_code = self._build_input_model(generator)
            if model_code is not None:
                models_parts.append(model_code)

            doc = self._docstring_for(generator)
            http_methods_parts.append(
                self._render_model_method(
                    generator, input_class_name, doc, base_url, path_prefix, client_type
                )
            )

            signature_str = generator.get_function_signature()
            clean_methods_parts.append(
                self._render_clean_method(
                    generator, signature_str, doc, base_url, path_prefix, client_type
                )
            )

            if input_class_name:
                model_stub = f'''
async def {generator.name}(input: {input_class_name}) -> Any:
    """{doc}

    Args:
        input: Function parameters

    Returns:
        Function result
    """
    ...
'''
                model_stubs_parts.append(model_stub)

            explicit_stub = f'''
async def {signature_str}:
    """{doc}"""
    ...
'''
            explicit_stubs_parts.append(explicit_stub)

        # Choose appropriate imports based on client type
        imports = FASTMCP_IMPORTS if client_type == "fastmcp" else HTTP_IMPORTS

        elapsed = time.perf_counter() - start_time
        logger.info("Structured code generation completed in %.2fs", elapsed)

        return GeneratedCode(
            models="\n".join(models_parts),
            http_methods="\n".join(http_methods_parts),
            clean_methods="\n".join(clean_methods_parts),
            model_stubs="\n".join(model_stubs_parts),
            explicit_stubs="\n".join(explicit_stubs_parts),
            imports=imports,
        )

    def generate_code(
        self,
        base_url: str = "http://localhost:8000",
        path_prefix: str = "/tools",
        args_format: ArgsFormat = "explicit",
        output_type: OutputType = "implementation",
    ) -> str:
        """Generate HTTP client code with specified format and type.

        Args:
            base_url: Base URL of the tool server
            path_prefix: Path prefix for routes
            args_format: Format for function arguments
            output_type: Type of output to generate

        Returns:
            Generated client code
        """
        structured_code = self.generate_structured_code(base_url, path_prefix)
        return structured_code.get_client_code(args_format, output_type)

    def generate_mcp_code(
        self,
        server_configs: dict[str, str],
        args_format: ArgsFormat = "explicit",
        output_type: OutputType = "implementation",
        server_name: str | None = None,
    ) -> str:
        """Generate complete FastMCP client code with tools.

        Args:
            server_configs: FastMCP server configurations mapping server names to
                           transport URLs
                           e.g., {"github": "sse://localhost:8080/github"}
            args_format: Format for function arguments
            output_type: Type of output to generate
            server_name: Key of ``server_configs`` that the generated tool
                functions call. Defaults to the first key of ``server_configs``.

        Returns:
            Complete FastMCP client code as a single module with one header
        """
        if server_name is None:
            # Generated functions resolve their server via SERVERS, whose keys
            # are the keys of server_configs. With no servers configured, fall
            # back to generate_structured_code's default base_url.
            server_name = next(iter(server_configs), "http://localhost:8000")

        structured = self.generate_structured_code(
            base_url=server_name, client_type="fastmcp"
        )
        mcp_client_code = self.get_fastmcp_client_code(server_configs)
        tools_code = structured.get_client_code(
            args_format, output_type, client_type="fastmcp"
        )

        # get_client_code starts its output with a module header
        # (structured.imports for implementations, STUB_IMPORTS for stubs). Both
        # contain `from __future__ import annotations`, which is a SyntaxError
        # anywhere but the top of a module, and FASTMCP_IMPORTS is emitted below.
        for header in (structured.imports, STUB_IMPORTS):
            if header:
                tools_code = tools_code.removeprefix(header)
        tools_code = tools_code.lstrip("\n")

        # Combine client and tools code
        return f"""{FASTMCP_IMPORTS}
{mcp_client_code}

{tools_code}"""

    def add_all_routes(self, app: FastAPI, path_prefix: str = "/tools") -> None:
        """Add FastAPI routes for all tools under ``path_prefix``."""
        for generator in self.generators:
            generator.add_route_to_app(app, path_prefix)


if __name__ == "__main__":
    # Two sample tools. Their names, parameters and docstrings feed the
    # generated schemas, so they are part of the demo's printed output.
    def greet(name: str, greeting: str = "Hello") -> str:
        """Greet someone with a custom message."""
        return f"{greeting}, {name}!"

    def add_numbers(a: int, b: int) -> int:
        """Add two numbers together."""
        return a + b

    toolset = ToolsetCodeGenerator.from_callables([greet, add_numbers])

    # FastMCP client wired to a single sample server.
    demo_servers = {"test_server": "stdio://npx -y @test/server"}
    fastmcp_source = toolset.generate_mcp_code(
        demo_servers, args_format="explicit", output_type="implementation"
    )
    print("=== FastMCP Code ===")
    print(fastmcp_source)

    rule = "=" * 50
    print(f"\n{rule}\n")

    # Plain HTTP client using the default base URL and path prefix.
    http_source = toolset.generate_code(
        args_format="explicit", output_type="implementation"
    )
    print("=== HTTP Code ===")
    print(http_source)
