"""Middleware to log per-request usage by email."""

import logging
import time
from typing import Awaitable, Callable, Optional

from google.cloud import firestore
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import Response

from agenticlypay.utils.firestore_storage import FirestoreStorage

# Module-level clients, constructed once when this module is imported and
# shared by every request handled by the middleware below.
# NOTE(review): because construction happens at import time, importing this
# module presumably requires Google Cloud credentials/project configuration to
# be available — confirm for test and local environments.
# `storage` is the project wrapper used to write usage records (`log_usage`).
storage = FirestoreStorage()
# `db` is used directly to read the `account_email_mapping` collection.
db = firestore.Client()


_logger = logging.getLogger(__name__)


def _account_id_from_request(request: Request) -> Optional[str]:
    """Return the account id a request targets, or ``None`` if there is none.

    The path segment immediately following the first ``accounts`` segment is
    preferred (``.../accounts/<id>/...``); otherwise the
    ``developer_account_id`` query parameter is used. Empty values count as
    absent.
    """
    path_parts = request.url.path.split("/")
    if "accounts" in path_parts:
        # Membership was checked above, so index() cannot raise here.
        id_index = path_parts.index("accounts") + 1
        if id_index < len(path_parts) and path_parts[id_index]:
            return path_parts[id_index]
    return request.query_params.get("developer_account_id") or None


def _email_for_account(account_id: str) -> Optional[str]:
    """Return the email mapped to *account_id* in ``account_email_mapping``.

    This performs a blocking Firestore read, so run it off the event loop.
    Returns ``None`` when no mapping exists or the lookup fails; the failure
    is logged rather than raised because usage logging is best-effort.
    """
    try:
        mapping_doc = db.collection("account_email_mapping").document(account_id).get()
    except Exception:
        _logger.warning(
            "Email lookup failed for account %s", account_id, exc_info=True
        )
        return None
    if not mapping_doc.exists:
        return None
    return (mapping_doc.to_dict() or {}).get("email")


async def usage_middleware(
    request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
    """Log API usage to Firestore, attributed by email.

    The email is taken from the ``email`` query parameter or, failing that,
    looked up from the account the request targets (see
    ``_account_id_from_request``). Requests that cannot be attributed to an
    email are not logged.

    Logging is best-effort: any Firestore failure is logged and the
    downstream response is always returned unchanged.

    Args:
        request: The incoming request.
        call_next: The next ASGI handler in the chain; awaited to get the response.

    Returns:
        The response produced by ``call_next``.
    """
    response = await call_next(request)

    # NOTE(review): the ``email`` query parameter is client-supplied and not
    # verified here, so a caller can attribute usage to any address.
    email = request.query_params.get("email")
    if not email:
        account_id = _account_id_from_request(request)
        if account_id:
            # The Firestore client is synchronous; keep it off the event loop.
            email = await run_in_threadpool(_email_for_account, account_id)

    if email:
        try:
            await run_in_threadpool(
                storage.log_usage,
                email=email,
                endpoint=f"{request.method} {request.url.path}",
                status_code=response.status_code,
            )
        except Exception:
            # The handler has already run; a usage-logging failure must not
            # turn its result into a 500 for the client.
            _logger.exception(
                "Failed to log usage for %s %s", request.method, request.url.path
            )

    return response
