"""Debug server combining ACP protocol with FastAPI for manual testing.

This module provides a debug server that runs both:
1. An ACP (Agent Client Protocol) server for testing client integration
2. A FastAPI web server for manually triggering all notification types

The server allows developers to test ACP client implementations by providing
mock responses and the ability to manually send any notification type from
the ACP schema through a web interface.
"""

from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
import logging
from pathlib import Path
import sys
import threading
from typing import TYPE_CHECKING, Any
import uuid

from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
import uvicorn

from acp.agent.protocol import Agent
from acp.schema import (
    AgentMessageChunk,
    AgentPlanUpdate,
    AgentThoughtChunk,
    AuthenticateResponse,
    AvailableCommand,
    AvailableCommandsUpdate,
    ContentToolCallContent,
    CreateTerminalResponse,
    # CurrentModelUpdate,
    CurrentModeUpdate,
    InitializeResponse,
    LoadSessionResponse,
    NewSessionResponse,
    PlanEntry,
    PromptResponse,
    ReadTextFileResponse,
    SessionNotification,
    ToolCallProgress,
    ToolCallStart,
    UserMessageChunk,
    WriteTextFileResponse,
)
from acp.stdio import stdio_streams
from llmling_agent.log import get_logger


if TYPE_CHECKING:
    from collections.abc import AsyncIterator

    from acp import SetSessionModelRequest, SetSessionModeRequest
    from acp.schema import (
        AuthenticateRequest,
        CancelNotification,
        CreateTerminalRequest,
        InitializeRequest,
        LoadSessionRequest,
        NewSessionRequest,
        PromptRequest,
        ReadTextFileRequest,
        SessionUpdate,
        WriteTextFileRequest,
    )


# Module logger from llmling_agent; calls in this module pass structured keyword
# arguments (e.g. ``session_id=...``) alongside the message.
logger = get_logger(__name__)
# Template returned by MockAgent.read_text_file; ``{path}`` is filled in with
# str.format, so any literal braces added here would need doubling. The body is
# kept valid Python so clients that parse or highlight it see realistic content.
MOCK_FILE = """\
# Mock file: {path}
# Generated by ACP Debug Server

def example_function():
    return "This is mock content for testing"

# Line count: 7 lines
"""


@dataclass()
class DebugSession:
    """A session created by the mock agent's ``new_session``."""

    # Identifier handed back to the client.
    session_id: str
    # Event-loop clock reading (loop.time(), monotonic) taken at creation.
    created_at: float
    # Working directory supplied by the client in the new-session request.
    cwd: str


@dataclass()
class NotificationRecord:
    """One notification pushed to the client through the web interface."""

    # Notification kind as given in the request (e.g. "agent_message").
    notification_type: str
    # Session the notification was addressed to.
    session_id: str
    # Event-loop clock reading (loop.time(), monotonic) taken when it was sent.
    timestamp: float


@dataclass
class DebugState:
    """Type-safe debug server state shared by the agent, server and web endpoints."""

    # Sessions known to the mock agent, keyed by session ID.
    sessions: dict[str, DebugSession] = field(default_factory=dict)
    # Most recently created or loaded session, if any.
    active_session_id: str | None = None
    # History of notifications sent through the web interface.
    notifications_sent: list[NotificationRecord] = field(default_factory=list)
    # ACP connection object; None while no connection is set up.
    client_connection: Any = None
    # Event loop that owns ``client_connection``, or None when not recorded.
    # asyncio objects are not thread-safe, so code running on another thread's
    # loop should schedule connection calls onto this loop instead of awaiting
    # them directly.
    acp_loop: asyncio.AbstractEventLoop | None = None


class MockAgent(Agent):
    """Mock ACP agent for the debug server.

    Every request gets a canned, successful response. Sessions created or
    loaded are recorded in the shared :class:`DebugState` so the web interface
    can target them.
    """

    def __init__(self, debug_state: DebugState) -> None:
        """Initialize with debug state.

        Args:
            debug_state: State object that the debug server also exposes to the
                FastAPI endpoints.
        """
        self.debug_state = debug_state

    async def initialize(self, params: InitializeRequest) -> InitializeResponse:
        """Handle initialize request with mock capabilities (session loading on)."""
        return InitializeResponse.create(
            title="MockAgent",
            name="MockAgent",
            version="1.0",
            protocol_version=1,
            load_session=True,
        )

    async def new_session(self, params: NewSessionRequest) -> NewSessionResponse:
        """Create a new debug session and make it the active one."""
        session_id = str(uuid.uuid4())
        session = DebugSession(
            session_id=session_id,
            # Inside a coroutine the asyncio docs prefer get_running_loop() over
            # get_event_loop(); the value is the loop's monotonic clock.
            created_at=asyncio.get_running_loop().time(),
            cwd=params.cwd,
        )
        self.debug_state.sessions[session_id] = session
        self.debug_state.active_session_id = session_id
        return NewSessionResponse(session_id=session_id)

    async def load_session(self, params: LoadSessionRequest) -> LoadSessionResponse:
        """Load an existing debug session and make it the active one.

        Raises:
            HTTPException: 404 if the session ID is unknown.
        """
        if params.session_id in self.debug_state.sessions:
            self.debug_state.active_session_id = params.session_id
            return LoadSessionResponse()
        # NOTE(review): HTTPException is a FastAPI type; how the ACP connection
        # reports it to the client is not visible here — TODO confirm.
        raise HTTPException(status_code=404, detail="Session not found")

    async def prompt(self, params: PromptRequest) -> PromptResponse:
        """Log the prompt and immediately end the turn."""
        logger.info("Received prompt", session_id=params.session_id)
        return PromptResponse(stop_reason="end_turn")

    async def cancel(self, params: CancelNotification) -> None:
        """Handle cancellation by logging it."""
        logger.info("Received cancellation request")

    async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None:
        """Mock authentication - always succeeds."""
        return AuthenticateResponse()

    async def read_text_file(self, params: ReadTextFileRequest) -> ReadTextFileResponse:
        """Return the MOCK_FILE template rendered for the requested path."""
        mock_content = MOCK_FILE.format(path=params.path)
        return ReadTextFileResponse(content=mock_content)

    async def write_text_file(self, params: WriteTextFileRequest) -> WriteTextFileResponse:
        """Log the write request; nothing is written to disk."""
        logger.info("Mock write", path=params.path, content_length=len(params.content))
        return WriteTextFileResponse()

    async def create_terminal(self, params: CreateTerminalRequest) -> CreateTerminalResponse:
        """Return a fresh random terminal ID; no terminal is started."""
        terminal_id = str(uuid.uuid4())
        return CreateTerminalResponse(terminal_id=terminal_id)

    async def set_session_mode(self, params: SetSessionModeRequest) -> None:
        """Mock session mode change (logged only)."""
        logger.info("Mock session mode change")

    async def set_session_model(self, params: SetSessionModelRequest) -> None:
        """Mock session model change (logged only)."""
        logger.info("Mock session model change")

    async def ext_notification(self, method: str, params: dict[str, Any]) -> None:
        """Mock extensibility notification (logged only)."""
        logger.info("Mock ext notification", method=method)

    async def ext_method(self, method: str, params: dict[str, Any]) -> Any:
        """Mock extensibility method; always returns a fixed payload."""
        logger.info("Mock ext method", method=method)
        return {"result": "mock response"}


# FastAPI models for web interface
class NotificationRequest(BaseModel):
    """Request to send a notification."""

    # Session the notification is addressed to; must already exist.
    session_id: str = Field(description="Target session ID")
    # Selects which kind of session update gets built.
    notification_type: str = Field(description="Type of notification to send")
    # Optional per-type overrides; an empty dict means "use mock defaults".
    data: dict[str, Any] = Field(description="Notification data", default_factory=dict)


class DebugStatus(BaseModel):
    """Current debug server status."""

    active_sessions: list[str] = Field(description="IDs of all sessions known to the mock agent")
    current_session: str | None = Field(description="Most recently created or loaded session")
    notifications_sent: int = Field(description="Number of notifications sent so far")
    acp_connected: bool = Field(description="Whether an ACP connection object is set")


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Log start-up and shutdown of the FastAPI application.

    Args:
        app: Application being served; unused, but required by the lifespan
            signature FastAPI calls.
    """
    logger.info("Debug server FastAPI starting up")
    # Requests are served while the generator is suspended here.
    yield
    logger.info("Debug server FastAPI shutting down")


# FastAPI application backing the debug web interface; uvicorn serves it from
# ACPDebugServer.
app = FastAPI(
    lifespan=lifespan,
    title="ACP Debug Server",
    version="1.0.0",
    description="Debug interface for Agent Client Protocol testing",
)


@app.get("/", response_class=HTMLResponse)
async def debug_interface() -> HTMLResponse:
    """Serve the debug interface HTML from ``debug.html`` next to this module.

    Raises:
        HTTPException: 404 if ``debug.html`` does not exist.
    """
    html_path = Path(__file__).parent / "debug.html"
    try:
        # Decode explicitly: without ``encoding`` read_text() uses the locale's
        # preferred encoding, which is not UTF-8 everywhere (assumes the file is
        # saved as UTF-8, the usual convention for HTML).
        html = html_path.read_text(encoding="utf-8")
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail="debug.html not found") from e
    return HTMLResponse(html)


@app.get("/status")
async def get_status() -> DebugStatus:
    """Get current debug server status: sessions, notification count, connection."""
    debug_state = _get_debug_state()
    connected = debug_state.client_connection is not None
    return DebugStatus(
        active_sessions=[*debug_state.sessions],
        current_session=debug_state.active_session_id,
        notifications_sent=len(debug_state.notifications_sent),
        acp_connected=connected,
    )


@app.post("/send-notification")
async def send_notification(request: NotificationRequest) -> dict[str, Any]:
    """Build a session update from *request* and send it through ACP.

    Raises:
        HTTPException: 503 if no ACP client is connected, 404 for an unknown
            session, 400 for an unknown notification type or rejected data,
            500 if building or delivering the notification fails otherwise.
    """
    state = _get_debug_state()

    if not state.client_connection:
        raise HTTPException(status_code=503, detail="ACP client not connected")

    if request.session_id not in state.sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    try:
        update = await _create_notification_update(request.notification_type, request.data)
        notification = SessionNotification(session_id=request.session_id, update=update)
    except ValueError as e:
        # Unknown notification types raise ValueError; if the schema models are
        # pydantic, their ValidationError subclasses ValueError too. Both are
        # caller mistakes rather than server faults.
        logger.warning(
            "Rejected notification request",
            notification_type=request.notification_type,
            error=str(e),
        )
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Failed to send notification")
        raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        await _deliver_on_acp_loop(state, notification)
    except Exception as e:
        logger.exception("Failed to send notification")
        raise HTTPException(status_code=500, detail=str(e)) from e

    record = NotificationRecord(
        notification_type=request.notification_type,
        session_id=request.session_id,
        timestamp=asyncio.get_running_loop().time(),
    )
    state.notifications_sent.append(record)
    logger.info(
        "Sent notification to session",
        notification_type=request.notification_type,
        session_id=request.session_id,
    )
    return {"success": True, "message": "Notification sent"}


async def _deliver_on_acp_loop(state: DebugState, notification: SessionNotification) -> None:
    """Await ``session_update`` on the event loop that owns the ACP connection.

    This endpoint may run on a different event loop (in a different thread) than
    the ACP connection. asyncio streams and primitives are not thread-safe, so
    when the owning loop is known and differs from the current one, the call is
    scheduled there with ``run_coroutine_threadsafe``. Otherwise it is awaited
    directly.
    """
    connection = state.client_connection
    # Use the recorded owning loop if there is one; absent or None means await here.
    acp_loop = getattr(state, "acp_loop", None)
    if acp_loop is None or acp_loop is asyncio.get_running_loop():
        await connection.session_update(notification)
        return
    future = asyncio.run_coroutine_threadsafe(connection.session_update(notification), acp_loop)
    await asyncio.wrap_future(future)


async def _create_notification_update(  # noqa: PLR0911
    notification_type: str, data: dict[str, Any]
) -> SessionUpdate:
    """Create appropriate notification update object."""
    match notification_type:
        case "agent_message":
            return AgentMessageChunk.text(text=data.get("text", "Mock agent message"))
        case "user_message":
            return UserMessageChunk.text(text=data.get("text", "Mock user message"))
        case "agent_thought":
            return AgentThoughtChunk.text(text=data.get("text", "Mock agent thought"))
        case "tool_call_start":
            return ToolCallStart(
                tool_call_id=data.get("tool_call_id", f"tool-{uuid.uuid4()}"),
                title=data.get("title", "Mock Tool Call"),
                status="pending",
                kind=data.get("kind", "other"),
            )
        case "tool_call_progress":
            return ToolCallProgress(
                tool_call_id=data.get("tool_call_id", "tool-123"),
                status=data.get("status", "completed"),
                raw_output=data.get("output"),
                content=[ContentToolCallContent.text(text=data.get("output", "Tool completed"))]
                if data.get("output")
                else None,
            )
        case "plan_update":
            entries = [
                PlanEntry(
                    content="Mock Plan Entry 1",
                    priority="high",
                    status="completed",
                ),
                PlanEntry(
                    content="Mock Plan Entry 2",
                    priority="medium",
                    status="in_progress",
                ),
            ]
            return AgentPlanUpdate(entries=entries)
        case "commands_update":
            commands = [
                AvailableCommand(
                    name="mock-command",
                    description="A mock command for testing",
                ),
            ]
            return AvailableCommandsUpdate(available_commands=commands)
        case "mode_update":
            return CurrentModeUpdate(current_mode_id=data.get("mode_id", "debug"))
        # case "model_update":
        #     return CurrentModelUpdate(current_model_id=data.get("model_id", "None"))
        case _:
            msg = f"Unknown notification type: {notification_type}"
            raise ValueError(msg)


# Module-level handle on the server's DebugState. The FastAPI route functions are
# plain module functions with no access to an ACPDebugServer instance, so the
# server registers its state here when it is constructed.
_global_debug_state: DebugState | None = None


def _set_debug_state(state: DebugState) -> None:
    """Register *state* as the instance the FastAPI endpoints operate on."""
    global _global_debug_state
    _global_debug_state = state


def _get_debug_state() -> DebugState:
    """Return the registered debug state.

    Raises:
        RuntimeError: If no state has been registered yet.
    """
    current = _global_debug_state
    if current is not None:
        return current
    msg = "Debug state not initialized"
    raise RuntimeError(msg)


class ACPDebugServer:
    """Combined ACP and FastAPI debug server.

    The ACP side runs on the event loop that awaits :meth:`run` and talks to
    the client over stdio. The FastAPI web interface is served by uvicorn in a
    daemon thread, which runs its own event loop.
    """

    def __init__(self, *, fastapi_port: int = 8000, fastapi_host: str = "127.0.0.1") -> None:
        """Initialize the debug server.

        Args:
            fastapi_port: Port for FastAPI web interface
            fastapi_host: Host for FastAPI web interface
        """
        self.fastapi_port = fastapi_port
        self.fastapi_host = fastapi_host
        self.debug_state = DebugState()
        self.agent = MockAgent(self.debug_state)
        self._running = False
        self._fastapi_thread: threading.Thread | None = None
        # Kept so shutdown() can ask uvicorn to exit rather than abandoning the thread.
        self._uvicorn_server: uvicorn.Server | None = None

        # FastAPI endpoints are module-level functions; hand them this state.
        _set_debug_state(self.debug_state)

    async def run(self) -> None:
        """Run both ACP server (stdio) and FastAPI server until shut down.

        Raises:
            RuntimeError: If the server is already running.
        """
        if self._running:
            msg = "Server already running"
            raise RuntimeError(msg)

        self._running = True
        logger.info("Starting ACP Debug Server")

        try:
            self._start_fastapi()
            await self._run_acp_server()
        except Exception:
            logger.exception("Error running debug server")
            raise
        finally:
            # The ACP poll loop only ends when shutdown() clears _running, and
            # shutdown() raises if called while not running; only shut down here
            # if nobody else already did.
            if self._running:
                await self.shutdown()

    def _start_fastapi(self) -> None:
        """Start the FastAPI server (uvicorn) in a daemon thread."""
        config = uvicorn.Config(
            app,
            host=self.fastapi_host,
            port=self.fastapi_port,
            log_level="info",
            # uvicorn's default logging config writes access logs to stdout, which
            # carries the ACP protocol here; a single access line would corrupt
            # the JSON-RPC stream, so access logging is turned off.
            access_log=False,
        )
        server = uvicorn.Server(config)
        self._uvicorn_server = server
        self._fastapi_thread = threading.Thread(
            target=server.run,
            name="acp-debug-fastapi",
            daemon=True,
        )
        self._fastapi_thread.start()
        url = f"http://{self.fastapi_host}:{self.fastapi_port}"
        logger.info("FastAPI debug interface started", url=url)

    async def _run_acp_server(self) -> None:
        """Run ACP server on stdio, polling until ``_running`` is cleared."""
        from acp import AgentSideConnection

        try:
            logger.info("Starting ACP server on stdio")
            reader, writer = await stdio_streams()

            def agent_factory(connection: AgentSideConnection) -> Agent:
                # Every connection shares the single mock agent (and its state).
                return self.agent

            filename = "acp-debug-server.jsonl"
            conn = AgentSideConnection(agent_factory, writer, reader, debug_file=filename)
            # The FastAPI endpoints run on uvicorn's loop in another thread. Record
            # the loop that owns the connection (before publishing the connection)
            # so they can schedule session updates back onto it.
            self.debug_state.acp_loop = asyncio.get_running_loop()
            self.debug_state.client_connection = conn
            logger.info("ACP Debug Server ready - connect your client!")
            url = f"http://{self.fastapi_host}:{self.fastapi_port}"
            logger.info("Web interface", url=url)
            while self._running:  # Keep server running
                await asyncio.sleep(0.1)

        except Exception:
            logger.exception("ACP server error")
            raise

    async def shutdown(self) -> None:
        """Shutdown the debug server: close the ACP connection, stop uvicorn.

        Raises:
            RuntimeError: If the server is not running.
        """
        if not self._running:
            msg = "Server is not running"
            raise RuntimeError(msg)

        self._running = False
        logger.info("Shutting down ACP Debug Server")

        # Clean up connection (best effort: close errors are logged, not raised).
        if self.debug_state.client_connection:
            try:
                await self.debug_state.client_connection.close()
            except Exception as e:  # noqa: BLE001
                logger.warning("Error closing ACP connection", error=e)
            finally:
                self.debug_state.client_connection = None
        self.debug_state.acp_loop = None

        await self._stop_fastapi()

    async def _stop_fastapi(self, timeout: float = 5.0) -> None:
        """Ask uvicorn to exit and wait up to *timeout* seconds for its thread."""
        server, thread = self._uvicorn_server, self._fastapi_thread
        self._uvicorn_server = None
        self._fastapi_thread = None
        if server is not None:
            # uvicorn.Server leaves its serve loop once should_exit is set.
            server.should_exit = True
        if thread is not None and thread.is_alive():
            # join() blocks, so wait from a worker thread to keep this loop free.
            # The thread is a daemon, so a join timeout cannot hang process exit.
            await asyncio.to_thread(thread.join, timeout)


async def main() -> None:
    """Entry point for debug server.

    Parses ``--port``, ``--host`` and ``--log-level``, configures stdlib
    logging and runs :class:`ACPDebugServer`; exits with status 1 if the
    server raises an exception.
    """
    import argparse

    parser = argparse.ArgumentParser(description="ACP Debug Server")
    parser.add_argument("--port", type=int, default=7777, help="FastAPI port")
    parser.add_argument("--host", default="127.0.0.1", help="FastAPI host")
    parser.add_argument(
        "--log-level",
        default="info",
        # Normalise case and accept only real level names on the logging module,
        # so a typo fails with a usage error instead of an AttributeError (or a
        # non-level attribute such as BASIC_FORMAT being passed as the level).
        type=str.lower,
        choices=("critical", "fatal", "error", "warning", "warn", "info", "debug", "notset"),
        help="Logging level",
    )

    args = parser.parse_args()

    # basicConfig's default handler writes to stderr, keeping log output off
    # stdout, which carries the ACP protocol.
    level = getattr(logging, args.log_level.upper())
    logging.basicConfig(
        level=level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    server = ACPDebugServer(fastapi_port=args.port, fastapi_host=args.host)

    try:
        await server.run()
    except KeyboardInterrupt:
        logger.info("Debug server interrupted")
    except Exception:
        logger.exception("Debug server error")
        sys.exit(1)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # On Ctrl+C asyncio.run() cancels main() and re-raises KeyboardInterrupt
        # here, so the handler inside main() generally never sees it; exit
        # quietly instead of printing a traceback.
        logger.info("Debug server interrupted")
