"""
FastAPI server for EvenAge.

Refactored API with:
- New REST endpoints for jobs, agents, traces, metrics
- WebSocket for real-time events (throttled)
- No OTEL/Prometheus (internal observability only)
- Uses new Agent/AgentRunner architecture
"""

from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
from datetime import datetime, timezone
import logging
import time
from typing import Any
from urllib.parse import urlsplit, urlunsplit

from evenage import __version__ as pkg_version
from evenage.core import (
    AgentRunner,
    BackendFactory,
    run_agent,
)
from evenage.core.config import load_runtime_config, EvenAgeConfig
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel


logger = logging.getLogger(__name__)  # module-scoped logger used by every endpoint
logging.basicConfig(level=logging.INFO)  # process-wide default: INFO and above


# Request/Response models
class SubmitJobRequest(BaseModel):
    """Body of ``POST /api/jobs`` (legacy "job" terminology).

    The endpoint hands ``pipeline_name`` to ``run_agent`` as the agent name
    and forwards ``inputs`` unchanged.
    """

    # Agent to execute (called a "pipeline" for backward compatibility).
    pipeline_name: str
    # Inputs passed to the agent as-is.
    inputs: dict[str, Any]


class SubmitPromptRequest(BaseModel):
    """Body of ``POST /api/prompts`` (preferred "prompt" terminology).

    The endpoint hands ``prompt_name`` to ``run_agent`` as the agent name and
    forwards ``inputs`` unchanged.
    """

    # Agent/prompt to execute.
    prompt_name: str
    # Inputs passed to the agent as-is.
    inputs: dict[str, Any]


class ReplayJobRequest(BaseModel):
    """Body of ``POST /api/jobs/{job_id}/replay``."""

    # Forwarded to AgentRunner.replay(); per the replay endpoint's contract,
    # False returns the stored result and True re-executes the job.
    reexecute: bool = False


class AgentChatRequest(BaseModel):
    """Body of ``POST /api/agents/{agent_name}/chat``."""

    # Text published to the agent as the task's ``message`` field.
    message: str
    # True: wait for the agent's reply (bounded by ``timeout``);
    # False: return the task id immediately.
    wait: bool = True
    # Passed to message_bus.wait_for_response(); presumably seconds.
    timeout: int = 30


class AgentInfo(BaseModel):
    """Description of one registered agent, as served by ``GET /api/agents``."""

    name: str
    role: str
    goal: str
    status: str
    # Tool identifiers reported by the agent registry.
    tools: list[str]
    # None when the registry has no ``last_seen`` entry for the agent.
    last_seen: str | None = None


# Global services
# Service handles shared by all endpoints. They are only declared here (no
# value is bound); ``lifespan`` assigns them at startup, so using them before
# the app has started raises NameError.
config: EvenAgeConfig  # result of load_runtime_config()
factory: BackendFactory  # builds the database and queue backends
database: Any  # from factory.create_database_backend()
message_bus: Any  # from factory.create_queue_backend()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize and cleanup services."""
    global config, factory, database, message_bus

    # Load configuration (prefer project .evenage/config.json)
    config = load_runtime_config()
    factory = BackendFactory(config)

    # Initialize services
    database = factory.create_database_backend()
    message_bus = factory.create_queue_backend()

    # Create database tables
    database.create_tables()

    logger.info("EvenAge API server started")
    logger.info(f"Database: {config.database_url}")
    logger.info(f"Queue: {config.queue_backend}")

    yield

    logger.info("EvenAge API server shutting down")


# Create FastAPI app
# Application instance; ``lifespan`` wires up the backends at startup.
app = FastAPI(
    lifespan=lifespan,
    version=pkg_version,
    title="EvenAge API",
    description="API for EvenAge distributed agent framework",
)

# Add CORS middleware
# NOTE(review): while "*" is present the two explicit dashboard origins are
# presumably redundant, and a wildcard combined with allow_credentials=True
# should be narrowed before production use — confirm against the Starlette
# CORSMiddleware documentation.
_cors_origins = [
    "http://localhost:5173",
    "http://dashboard:5173",
    "*",  # Configure appropriately for production
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Return a small service descriptor: name, version, status and links."""
    return dict(
        service="EvenAge API",
        version=pkg_version,
        status="running",
        dashboard="http://localhost:5173",
        docs="/docs",
    )


@app.get("/api/health")
async def health_check():
    """Probe the message bus and the database.

    Returns a ``healthy`` payload when both probes succeed. Otherwise responds
    with HTTP 503, describing either which dependency is unhealthy or the
    exception raised while probing.
    """

    def probe_database() -> bool:
        # A metrics query doubles as the database liveness probe.
        try:
            database.get_metrics()
        except Exception:
            return False
        return True

    try:
        bus_healthy = await message_bus.health_check()
        db_healthy = probe_database()

        if bus_healthy and db_healthy:
            return {
                "status": "healthy",
                "queue": bus_healthy,
                "database": db_healthy,
                "version": pkg_version,
            }
        raise HTTPException(
            status_code=503,
            detail=f"Services unhealthy: bus={bus_healthy}, db={db_healthy}",
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=503, detail=str(e))


@app.post("/api/jobs")
async def submit_job(request: SubmitJobRequest):
    """Run the named agent and return its outcome (legacy job API).

    ``request.pipeline_name`` selects the agent and ``request.inputs`` is
    passed through unchanged. Any failure is logged and surfaced as HTTP 500.

    Returns:
        Dict with ``job_id``, ``status``, ``result``, ``error``,
        ``duration_ms`` and ``agent``.
    """
    try:
        outcome = await run_agent(
            agent_name=request.pipeline_name,
            inputs=request.inputs,
            env_config=config,
        )

        response = {"job_id": outcome["job_id"], "status": outcome["status"]}
        for key in ("result", "error", "duration_ms", "agent"):
            response[key] = outcome.get(key)
        return response

    except Exception as e:
        logger.error(f"Error submitting job: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/jobs/{job_id}")
async def get_job(job_id: str):
    """Fetch a stored job record by id.

    Args:
        job_id: Job identifier.

    Returns:
        Dict with ``job_id`` plus ``status``, ``pipeline_name``, ``inputs``,
        ``outputs``, ``error``, ``created_at`` and ``completed_at``.

    Raises:
        HTTPException: 404 when no record exists; 500 on any other failure.
    """
    fields = (
        "status",
        "pipeline_name",
        "inputs",
        "outputs",
        "error",
        "created_at",
        "completed_at",
    )
    try:
        record = database.get_result(job_id)
        if not record:
            raise HTTPException(status_code=404, detail="Job not found")
        return {"job_id": job_id, **{name: record.get(name) for name in fields}}

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to fetch job")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/jobs")
async def list_jobs(limit: int = 50, offset: int = 0):
    """
    List recent jobs.

    Args:
        limit: Maximum number of jobs to return
        offset: Number of jobs to skip, applied to the order returned by
            ``database.list_jobs``; negative values are treated as 0

    Returns:
        Dict with the job summaries and the effective ``limit``/``offset``
    """
    offset = max(offset, 0)
    try:
        if offset:
            # The database layer only accepts ``limit``; emulate paging by
            # over-fetching ``limit + offset`` rows and skipping the first
            # ``offset`` of them, so each page actually differs.
            jobs = list(database.list_jobs(limit=limit + offset))[offset:]
        else:
            jobs = database.list_jobs(limit=limit)

        return {
            "jobs": [
                {
                    "job_id": job.get("job_id"),
                    "status": job.get("status"),
                    "pipeline_name": job.get("pipeline_name"),
                    "created_at": job.get("created_at"),
                    "completed_at": job.get("completed_at")
                }
                for job in jobs
            ],
            "limit": limit,
            "offset": offset
        }

    except Exception as e:
        logger.exception("Failed to list jobs")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/jobs/{job_id}/replay")
async def replay_job(job_id: str, request: ReplayJobRequest):
    """Replay a previous job through ``AgentRunner.replay``.

    ``request.reexecute`` selects between the stored result and a fresh
    execution. The runner's return value is passed straight back to the
    client; any failure becomes HTTP 500.
    """
    try:
        # NOTE(review): replay() is called synchronously inside an async
        # handler; a long re-execution would block the event loop.
        runner = AgentRunner(env_config=config)
        return runner.replay(job_id, reexecute=request.reexecute)
    except Exception as e:
        logger.exception("Failed to replay job")
        raise HTTPException(status_code=500, detail=str(e))


# --- New Prompt Endpoints (preferred terminology) ---

@app.post("/api/prompts")
async def submit_prompt(request: SubmitPromptRequest):
    """Run the named agent, persist the execution, and return its outcome.

    The execution is recorded with ``database.create_prompt`` and then
    finalised with ``database.save_prompt_result``. The status is
    ``"completed"`` when the agent reports ``"success"`` and ``"failed"``
    otherwise. Any failure is logged and surfaced as HTTP 500.

    Returns:
        Dict with ``prompt_id``, ``status``, ``result``, ``error``,
        ``duration_ms`` and ``agent``.
    """
    try:
        outcome = await run_agent(
            agent_name=request.prompt_name,
            inputs=request.inputs,
            env_config=config,
        )

        # run_agent still reports the identifier under the old "job_id" key.
        prompt_id = outcome.get("job_id")
        database.create_prompt(prompt_id, request.prompt_name, request.inputs)

        if outcome.get("status") == "success":
            payload, final_status = outcome.get("result", {}), "completed"
        else:
            payload, final_status = {"error": outcome.get("error")}, "failed"
        database.save_prompt_result(prompt_id, payload, status=final_status)

        return {
            "prompt_id": prompt_id,
            "status": outcome["status"],
            **{key: outcome.get(key) for key in ("result", "error", "duration_ms", "agent")},
        }

    except Exception as e:
        logger.error(f"Error submitting prompt: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/prompts/{prompt_id}")
async def get_prompt(prompt_id: str):
    """Fetch a stored prompt execution by id.

    Args:
        prompt_id: Prompt execution identifier.

    Returns:
        Dict with ``prompt_id`` plus ``status``, ``prompt_name``, ``inputs``,
        ``outputs``, ``error``, ``created_at`` and ``completed_at``.

    Raises:
        HTTPException: 404 when nothing is stored under ``prompt_id``;
            500 on any other failure.
    """
    try:
        stored = database.get_prompt_result(prompt_id)
        if not stored:
            raise HTTPException(status_code=404, detail="Prompt execution not found")

        payload = {"prompt_id": prompt_id}
        for field in (
            "status",
            "prompt_name",
            "inputs",
            "outputs",
            "error",
            "created_at",
            "completed_at",
        ):
            payload[field] = stored.get(field)
        return payload

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to fetch prompt")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/prompts")
async def list_prompts(limit: int = 50):
    """List recent prompt executions.

    Args:
        limit: Maximum number of prompt executions to fetch.

    Returns:
        ``{"prompts": [...], "limit": limit}``. Each entry summarises one
        execution: id, status, prompt name and timestamps.
    """
    summary_fields = ("prompt_id", "status", "prompt_name", "created_at", "completed_at")
    try:
        rows = database.list_prompts(limit=limit)
        summaries = [{field: row.get(field) for field in summary_fields} for row in rows]
        return {"prompts": summaries, "limit": limit}

    except Exception as e:
        logger.exception("Failed to list prompts")
        raise HTTPException(status_code=500, detail=str(e))


# --- Agent, Trace and Metrics Endpoints ---

@app.get("/api/agents")
async def list_agents():
    """List the agents registered with the message bus.

    Returns:
        ``{"agents": [...]}``, where each entry is an ``AgentInfo`` dict.
        When the registry omits a field, role and status default to
        ``"unknown"``, goal to ``""`` and tools to ``[]``.
    """

    def describe(name, meta):
        info = AgentInfo(
            name=name,
            role=meta.get("role", "unknown"),
            goal=meta.get("goal", ""),
            status=meta.get("status", "unknown"),
            tools=meta.get("tools", []),
            last_seen=meta.get("last_seen"),
        )
        return info.dict()

    try:
        registry = await message_bus.get_registered_agents()
        return {"agents": [describe(name, meta) for name, meta in registry.items()]}

    except Exception as e:
        logger.exception("Failed to list agents")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/agents/{agent_name}/chat")
async def chat_with_agent(agent_name: str, request: AgentChatRequest):
    """Publish a chat message to an agent, optionally waiting for its reply.

    Args:
        agent_name: Target agent.
        request: Message plus wait/timeout options.

    Returns:
        ``status`` is ``"submitted"`` when not waiting, ``"completed"`` with
        the ``response`` when a reply arrives, and ``"timeout"`` otherwise.
        Every variant carries the ``task_id``.
    """
    try:
        task_id = await message_bus.publish_task(agent_name, {"message": request.message})

        if not request.wait:
            return {"task_id": task_id, "status": "submitted"}

        response = await message_bus.wait_for_response(task_id, timeout=request.timeout)
        if not response:
            # NOTE(review): any falsy reply (not only None) is reported as a timeout.
            return {
                "task_id": task_id,
                "status": "timeout",
                "message": "Agent did not respond in time",
            }
        return {"task_id": task_id, "status": "completed", "response": response}

    except Exception as e:
        logger.exception("Failed to chat with agent")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/traces")
async def list_traces(
    job_id: str | None = None,
    agent_name: str | None = None,
    limit: int = 100,
):
    """List trace events with optional filters and include payload.

    A ``job_id`` takes precedence. In that case all of that job's traces come
    from ``database.get_job_traces`` and ``agent_name``/``limit`` are not
    applied. Otherwise ``database.list_traces`` is queried with the agent
    filter and limit.
    """
    try:
        traces = (
            database.get_job_traces(job_id)
            if job_id
            else database.list_traces(agent=agent_name, limit=limit)
        )
        return {"traces": traces}

    except Exception as e:
        logger.exception("Failed to list traces")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/traces/{trace_id}")
async def get_trace(trace_id: str):
    """Return one trace, including its full payload.

    Raises:
        HTTPException: 404 if no trace has this id; 500 on lookup failure.
    """
    try:
        trace = database.get_trace(trace_id)
        if trace:
            return trace
        raise HTTPException(status_code=404, detail="Trace not found")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to get trace")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/metrics")
async def get_metrics():
    """
    Get system metrics.

    Returns:
        The counters from ``database.get_metrics()`` (a missing counter
        defaults to 0), plus ``timestamp``: an ISO-8601 string in UTC with
        no offset suffix.
    """
    counters = (
        "total_jobs",
        "avg_latency_ms",
        "active_agents_count",
        "cache_hits",
        "tool_calls",
    )
    try:
        metrics = database.get_metrics()

        snapshot = {name: metrics.get(name, 0) for name in counters}
        # datetime.utcnow() is deprecated since Python 3.12. Derive the same
        # offset-free UTC string from an aware "now" so clients see no change.
        snapshot["timestamp"] = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        return snapshot

    except Exception as e:
        logger.exception("Failed to get metrics")
        raise HTTPException(status_code=500, detail=str(e))


# WebSocket for real-time events (throttled)
class EventThrottler:
    """Throttle WebSocket events to avoid overwhelming clients."""

    def __init__(self, max_per_second: int = 10):
        self.max_per_second = max_per_second
        self.last_sent = 0.0
        self.buffer: list[dict] = []

    def add(self, event: dict) -> dict | None:
        """Add event, return it if should be sent immediately."""
        now = time.time()
        elapsed = now - self.last_sent

        if elapsed >= (1.0 / self.max_per_second):
            self.last_sent = now
            return event
        # Buffer for batching
        self.buffer.append(event)
        if len(self.buffer) >= 5:  # Batch size
            batch = self.buffer[:]
            self.buffer.clear()
            self.last_sent = now
            return {"type": "batch", "events": batch}

        return None


def _trace_time_utc(ts: Any) -> datetime | None:
    """Parse a trace ``timestamp`` into a naive UTC datetime.

    Accepts ISO-8601 strings (a trailing ``Z`` is allowed) and ``datetime``
    objects. Naive values are taken to be UTC already; aware values are
    converted to UTC. Returns None when the value cannot be interpreted.
    """
    if isinstance(ts, datetime):
        parsed = ts
    elif isinstance(ts, str) and ts:
        try:
            parsed = datetime.fromisoformat(ts.replace("Z", "+00:00"))
        except ValueError:
            return None
    else:
        return None
    if parsed.tzinfo is not None:
        # Convert before dropping tzinfo; dropping alone would misplace
        # timestamps carrying a non-UTC offset.
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed


@app.websocket("/ws/events")
async def websocket_events(websocket: WebSocket):
    """
    WebSocket endpoint for real-time events.

    The endpoint polls the latest 20 traces, sleeping 0.5 s between polls.
    Traces whose timestamp falls within the last 5 seconds, and that have not
    been sent on this connection yet, are pushed as ``{"type": "trace", ...}``
    messages through an ``EventThrottler``. The throttler may coalesce them
    into ``{"type": "batch", ...}`` messages. Anything left in the
    throttler's buffer is flushed at the end of each poll so events are not
    held back indefinitely.
    """
    await websocket.accept()
    throttler = EventThrottler(max_per_second=10)
    window_seconds = 5
    # Traces already sent on this connection, mapped to their UTC time so
    # entries can be forgotten once they fall outside the polling window.
    sent: dict[tuple, datetime] = {}

    try:
        while True:
            traces = database.list_traces(limit=20)
            now = datetime.now(timezone.utc).replace(tzinfo=None)

            # Entries older than the window can never be selected again.
            expired = [
                key for key, seen_at in sent.items()
                if (now - seen_at).total_seconds() >= window_seconds
            ]
            for key in expired:
                del sent[key]

            for trace in traces:
                ts = trace.get("timestamp")
                ts_dt = _trace_time_utc(ts)
                if ts_dt is None or (now - ts_dt).total_seconds() >= window_seconds:
                    continue

                data = {
                    "job_id": trace.get("job_id"),
                    "agent_name": trace.get("agent_name"),
                    "event_type": trace.get("event_type"),
                    # send_json cannot serialise datetime objects.
                    "timestamp": ts.isoformat() if isinstance(ts, datetime) else ts,
                }
                # The same trace is returned by several consecutive polls;
                # send each distinct event only once.
                key = tuple(data.values())
                if key in sent:
                    continue
                sent[key] = ts_dt

                to_send = throttler.add({"type": "trace", "data": data})
                if to_send:
                    await websocket.send_json(to_send)

            # Flush leftovers so a quiet period does not strand buffered events.
            if throttler.buffer:
                pending = throttler.buffer[:]
                throttler.buffer.clear()
                throttler.last_sent = time.time()
                await websocket.send_json({"type": "batch", "events": pending})

            # Sleep before next poll
            await asyncio.sleep(0.5)

    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")
    except Exception:
        logger.exception("WebSocket error")


if __name__ == "__main__":
    # Development entry point: serve the app on every interface, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
