# ruff: noqa: E402

import dotenv

# Load variables from a .env file before the remaining imports (hence the
# E402 suppression at the top of the file), presumably so that modules
# imported below (e.g. blaxel) can read them at import time — confirm.
dotenv.load_dotenv()

import traceback
from contextlib import asynccontextmanager
from logging import getLogger
from time import perf_counter, time

import uvicorn
from asgi_correlation_id import CorrelationIdMiddleware
from fastapi import FastAPI, Request, Response, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from llama_index.core.agent.workflow import AgentOutput, ReActAgent, ToolCallResult
from llama_index.core.tools import FunctionTool
from llama_index.core.workflow import Context
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor

from blaxel.llamaindex import bl_model, bl_tools

# Module-level logger shared by the lifespan hook, the request-logging
# middleware, the exception handler and the agent endpoint.
logger = getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Announce application startup and shutdown.

    FastAPI runs the code before ``yield`` when the app starts and the code
    after it when the app stops.
    """
    # NOTE(review): the port is hard-coded here and must be kept in sync with
    # the uvicorn.run() call in the __main__ guard.
    startup_message = "Server running on port 1338"
    shutdown_message = "Server shutting down"

    logger.info(startup_message)
    yield
    logger.info(shutdown_message)


# Model name passed to blaxel's bl_model() by the agent endpoint.
# Other model names are kept commented out for quick switching.
# MODEL = "gpt-4o-mini"
MODEL = "claude-3-7-sonnet-20250219"
# MODEL = "xai-grok-beta"
# MODEL = "cohere-command-r-plus"  # not usable here: tool calls are not supported
# MODEL = "gemini-2-5-pro-preview-03-25"
# MODEL = "deepseek-chat"
# MODEL = "mistral-large-latest"

# The app logs startup/shutdown through `lifespan`. CorrelationIdMiddleware is
# presumably what attaches the X-Request-Id header that log_requests reads
# from responses — confirm against asgi_correlation_id's configuration.
app = FastAPI(lifespan=lifespan)
app.add_middleware(CorrelationIdMiddleware)


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log one access line per request: method, path, status, latency and request id.

    The request id is taken from the response headers, preferring
    ``X-Request-Id`` and falling back to ``X-Blaxel-Request-Id``.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request down the middleware
            stack and returns the response.

    Returns:
        The downstream response, unchanged.
    """
    # perf_counter() is monotonic, unlike time(): a wall-clock adjustment
    # (NTP step, manual change) cannot produce negative or skewed latencies.
    started = perf_counter()
    response: Response = await call_next(request)
    # NOTE(review): if call_next raises, nothing is logged here; the error
    # is only reported by the exception handler.
    elapsed_ms = (perf_counter() - started) * 1000

    headers = response.headers
    request_id = headers.get("X-Request-Id") or headers.get("X-Blaxel-Request-Id")
    # Lazy %-style args: the message is only formatted if the record is emitted.
    logger.info(
        "%s %s %s %.2fms rid=%s",
        request.method,
        request.url.path,
        response.status_code,
        elapsed_ms,
        request_id,
    )
    return response


@app.exception_handler(Exception)
async def validation_exception_handler(request: Request, e: Exception):
    """Catch-all handler: log the error with its traceback and return a JSON 500.

    Despite its name, this is registered for ``Exception``, not only for
    validation errors. It serves as the fallback for unhandled exceptions.

    Args:
        request: The request being processed when ``e`` was raised.
        e: The unhandled exception.

    Returns:
        A 500 ``JSONResponse`` whose body is ``{"detail": str(e)}``.
    """
    stacktrace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
    # Emit the message and the traceback as ONE log record. Two separate
    # logger.error calls can be interleaved with lines from concurrent requests.
    logger.error(
        "Error on request %s %s: %s\nStacktrace: %s",
        request.method,
        request.url.path,
        e,
        stacktrace,
    )
    # NOTE(review): str(e) is sent to the client verbatim and may leak internal
    # details (paths, upstream errors); consider a generic message in production.
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=jsonable_encoder({"detail": str(e)}),
    )


async def weather(city: str) -> str:
    """Get the weather in a given city"""
    # Stub tool: always reports sunny weather. The docstring above is kept
    # verbatim because FunctionTool.from_defaults presumably exposes it to the
    # LLM as the tool description — confirm before editing it.
    template = "The weather in {} is sunny"
    return template.format(city)


@app.post("/")
async def handle_request(request: Request):
    """Run a ReAct agent on the prompt in the request body and return its answer.

    The body must be a JSON object. The prompt is read from its ``"inputs"``
    key (empty string if absent). The agent gets the ``blaxel-search`` tools
    plus the local ``weather`` tool and uses the model named by ``MODEL``.

    Returns:
        A ``text/plain`` response with the content of the agent's last output,
        or a 400 JSON error when the body is not valid JSON or not an object.

    Raises:
        RuntimeError: if the agent run emits no ``AgentOutput`` event. The
            catch-all exception handler turns this into a 500.
    """
    # Validate the body before provisioning tools and the model, so malformed
    # requests fail fast with a 400 instead of a 500 from deep inside.
    try:
        body = await request.json()
    except ValueError:  # JSONDecodeError and UnicodeDecodeError subclass ValueError
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "Request body must be valid JSON"},
        )
    if not isinstance(body, dict):
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "Request body must be a JSON object"},
        )
    user_input = body.get("inputs", "")

    prompt = "You are a helpful assistant that can answer questions and help with tasks."
    tools = await bl_tools(["blaxel-search"]) + [FunctionTool.from_defaults(async_fn=weather)]
    model = await bl_model(MODEL)

    agent = ReActAgent(llm=model, tools=tools, system_prompt=prompt)
    context = Context(agent)
    handler = agent.run(user_input, ctx=context)

    # Every AgentOutput is logged, but only the last one is returned, so there
    # is no need to keep the earlier ones.
    last_output = None
    async for ev in handler.stream_events():
        if isinstance(ev, ToolCallResult):
            logger.info("Call %s with %s", ev.tool_name, ev.tool_kwargs)
        if isinstance(ev, AgentOutput):
            # Pass LLM text as an argument, never as the format string.
            logger.info("%s", ev.response.content)
            last_output = ev

    if last_output is None:
        raise RuntimeError("Agent run finished without producing any output")
    return Response(last_output.response.content, media_type="text/plain")


# Enable OpenTelemetry tracing for the app. This runs after all routes,
# middleware and handlers above are registered. The per-message ASGI
# "receive"/"send" spans are excluded, presumably to keep traces to one
# span per request — confirm.
FastAPIInstrumentor.instrument_app(app, exclude_spans=["receive", "send"])

if __name__ == "__main__":
    # Local entry point: serve on all interfaces with the asyncio loop.
    # uvicorn's own log output is limited to critical messages.
    # NOTE(review): keep the port in sync with the message logged by lifespan().
    server_options = {
        "host": "0.0.0.0",
        "port": 1338,
        "log_level": "critical",
        "loop": "asyncio",
    }
    uvicorn.run(app, **server_options)
