from Crypto.PublicKey.RSA import RsaKey
from datetime import datetime, timezone
from fastapi import FastAPI, Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import (
    BaseHTTPMiddleware,
    RequestResponseEndpoint,
)
from starlette.types import ASGIApp
from typing import Optional
from uuid import uuid4
from maleo_soma.enums.logging import LogLevel
from maleo_soma.enums.operation import (
    OperationOrigin,
    OperationLayer,
    OperationTarget,
    SystemOperationType,
)
from maleo_soma.exceptions import Error, TooManyRequests, InternalServerError
from maleo_soma.schemas.authentication import Authentication
from maleo_soma.schemas.operation.context import generate_operation_context
from maleo_soma.schemas.operation.request import (
    CreateFailedRequestOperationSchema,
    ReadFailedRequestOperationSchema,
    UpdateFailedRequestOperationSchema,
    DeleteFailedRequestOperationSchema,
    CreateSuccessfulRequestOperationSchema,
    ReadSuccessfulRequestOperationSchema,
    UpdateSuccessfulRequestOperationSchema,
    DeleteSuccessfulRequestOperationSchema,
)
from maleo_soma.schemas.operation.resource.action import (
    extract_resource_operation_action,
    CreateResourceOperationAction,
    ReadResourceOperationAction,
    UpdateResourceOperationAction,
    DeleteResourceOperationAction,
)
from maleo_soma.schemas.operation.system import SuccessfulSystemOperationSchema
from maleo_soma.schemas.operation.system.action import SystemOperationActionSchema
from maleo_soma.schemas.operation.timestamp import OperationTimestamp
from maleo_soma.schemas.response import InternalServerErrorResponseSchema
from maleo_soma.schemas.service import ServiceContext
from maleo_soma.schemas.request import RequestContext
from maleo_soma.schemas.response import ResponseContext, TooManyRequestsResponseSchema
from maleo_soma.types.base import OptionalUUID
from maleo_soma.utils.logging import MiddlewareLogger
from maleo_soma.utils.name import get_fully_qualified_name
from .rate_limit import RateLimiter
from .response_builder import ResponseBuilder


class BaseMiddleware(BaseHTTPMiddleware):
    """Base Middleware for Maleo.

    For every HTTP request this middleware:

    * reads the operation id, request context, authentication and resource
      operation action for the request (returning a bare 500 if any of them
      cannot be obtained);
    * rejects the request with ``429 Too Many Requests`` when the configured
      ``RateLimiter`` reports it as rate limited;
    * converts any exception raised downstream into a ``500`` response;
    * adds response headers through ``ResponseBuilder.add_headers``;
    * logs the request as a successful (2xx, INFO) or failed (>= 400, ERROR)
      request operation, using the schema that matches the action type.
    """

    key = "base_middleware"
    name = "Base Middleware"

    # Ordered (action type, successful-request schema, failed-request schema)
    # triples. The order mirrors an isinstance() chain: the first entry whose
    # action type matches the request's operation action is used.
    _REQUEST_OPERATION_SCHEMAS = (
        (
            CreateResourceOperationAction,
            CreateSuccessfulRequestOperationSchema,
            CreateFailedRequestOperationSchema,
        ),
        (
            ReadResourceOperationAction,
            ReadSuccessfulRequestOperationSchema,
            ReadFailedRequestOperationSchema,
        ),
        (
            UpdateResourceOperationAction,
            UpdateSuccessfulRequestOperationSchema,
            UpdateFailedRequestOperationSchema,
        ),
        (
            DeleteResourceOperationAction,
            DeleteSuccessfulRequestOperationSchema,
            DeleteFailedRequestOperationSchema,
        ),
    )

    def __init__(
        self,
        app: ASGIApp,
        logger: MiddlewareLogger,
        private_key: RsaKey,
        rate_limiter: RateLimiter,
        response_builder: ResponseBuilder,
        service_context: Optional[ServiceContext] = None,
        operation_id: OptionalUUID = None,
    ) -> None:
        """Store collaborators and log a successful-initialization operation.

        Args:
            app: The wrapped ASGI application.
            logger: Logger used for the initialization entry and per-request
                operation logs.
            private_key: RSA private key, stored as ``self._private_key``
                (not used elsewhere in this class).
            rate_limiter: Limiter consulted on every request.
            response_builder: Builder used to add headers to every response.
            service_context: Service context; ``ServiceContext.from_env()`` is
                used when omitted.
            operation_id: ID of the initialization operation; a random UUID4 is
                generated when omitted.
        """
        super().__init__(app, None)
        self._logger = logger
        self._private_key = private_key

        self._service_context = (
            service_context
            if service_context is not None
            else ServiceContext.from_env()
        )
        operation_id = operation_id if operation_id is not None else uuid4()

        self.rate_limiter = rate_limiter

        self._response_builder = response_builder

        operation_context = generate_operation_context(
            origin=OperationOrigin.SERVICE,
            layer=OperationLayer.MIDDLEWARE,
            layer_details={
                "identifier": {
                    "key": self.key,
                    "name": self.name,
                }
            },
            target=OperationTarget.INTERNAL,
            target_details={"fully_qualified_name": get_fully_qualified_name()},
        )

        operation_action = SystemOperationActionSchema(
            type=SystemOperationType.INITIALIZATION,
            details={
                "type": "class_initialization",
                "class_key": self.key,
                "class_name": self.name,
            },
        )

        SuccessfulSystemOperationSchema(
            service_context=self._service_context,
            id=operation_id,
            context=operation_context,
            timestamp=OperationTimestamp(
                executed_at=datetime.now(tz=timezone.utc),
                completed_at=None,
                duration=0,
            ),
            summary=f"Successfully initialized {self.name}",
            request_context=None,
            authentication=None,
            action=operation_action,
            result=None,
        ).log(logger=self._logger, level=LogLevel.INFO)

    @classmethod
    def _select_request_operation_schema(
        cls, action: object, *, successful: bool
    ) -> Optional[type]:
        """Return the request-operation schema class for ``action``, or ``None``.

        Args:
            action: The resource operation action extracted from the request.
            successful: Pick the successful-request schema when True, the
                failed-request schema otherwise.
        """
        for action_type, successful_schema, failed_schema in (
            cls._REQUEST_OPERATION_SCHEMAS
        ):
            if isinstance(action, action_type):
                return successful_schema if successful else failed_schema
        return None

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Rate-limit, execute, decorate and log a single request.

        Args:
            request: Incoming request. ``request.state.operation_id`` and the
                authentication data read with ``from_state=True`` are expected
                to have been set earlier in the request lifecycle.
            call_next: Forwards the request to the next ASGI layer.

        Returns:
            The downstream response, a 429 response when rate limited, or a 500
            response when request state could not be read or downstream raised.
        """
        # Get all necessary states
        try:
            operation_id = request.state.operation_id
            request_context = RequestContext.from_request(request=request)
            authentication = Authentication.from_request(
                request=request, from_state=True
            )
            resource_operation_action = extract_resource_operation_action(
                request=request
            )
        except Exception:
            # Without these values nothing below can be built or logged, so a
            # bare 500 is returned (no extra headers, no operation log).
            # NOTE(review): the original exception is discarded; consider
            # logging it once the MiddlewareLogger API for that is confirmed.
            return JSONResponse(
                content=InternalServerErrorResponseSchema().model_dump(),
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        operation_context = generate_operation_context(
            origin=OperationOrigin.SERVICE,
            layer=OperationLayer.MIDDLEWARE,
            layer_details={"type": "base"},
            target=OperationTarget.INTERNAL,
            target_details={"fully_qualified_name": get_fully_qualified_name()},
        )

        executed_at = datetime.now(tz=timezone.utc)
        error: Optional[Error] = None

        try:
            token = authentication.credentials.token
            user_id = token.payload.u_i if token is not None else None
            organization_id = token.payload.o_i if token is not None else None
            if self.rate_limiter.is_rate_limited(
                ip_address=request_context.ip_address,
                user_id=user_id,
                organization_id=organization_id,
            ):
                completed_at = datetime.now(tz=timezone.utc)
                raise TooManyRequests(
                    service_context=self._service_context,
                    operation_id=operation_id,
                    operation_context=operation_context,
                    operation_timestamp=OperationTimestamp(
                        executed_at=executed_at,
                        completed_at=completed_at,
                        duration=(completed_at - executed_at).total_seconds(),
                    ),
                    operation_summary="Too many requests",
                    request_context=request_context,
                    authentication=authentication,
                    operation_action=resource_operation_action,
                )
            response = await call_next(request)
        except TooManyRequests as tmr:
            error = tmr
            response = JSONResponse(
                content=TooManyRequestsResponseSchema().model_dump(),
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            )
        except Exception as e:
            # Any other failure (including one raised downstream) becomes a 500;
            # the original exception is kept as the error's details.
            completed_at = datetime.now(tz=timezone.utc)
            error = InternalServerError(
                service_context=self._service_context,
                operation_id=operation_id,
                operation_context=operation_context,
                operation_timestamp=OperationTimestamp(
                    executed_at=executed_at,
                    completed_at=completed_at,
                    duration=(completed_at - executed_at).total_seconds(),
                ),
                operation_summary="Failed processing request",
                operation_action=resource_operation_action,
                request_context=request_context,
                authentication=authentication,
                details=e,
            )
            response = JSONResponse(
                content=InternalServerErrorResponseSchema().model_dump(),
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        completed_at = datetime.now(tz=timezone.utc)
        duration = (completed_at - executed_at).total_seconds()

        response = self._response_builder.add_headers(
            operation_id=operation_id,
            request_context=request_context,
            response=response,
            responded_at=completed_at,
            process_time=duration,
        )

        operation_timestamp = OperationTimestamp(
            executed_at=executed_at, completed_at=completed_at, duration=duration
        )

        response_context = ResponseContext(
            status_code=response.status_code,
            media_type=response.media_type,
            headers=response.headers.items(),
            # Streaming responses (e.g. what Starlette's ``call_next`` returns)
            # have no ``body`` attribute, so read it defensively instead of
            # raising AttributeError after the request was already served.
            # NOTE(review): confirm ResponseContext accepts ``body=None``.
            body=getattr(response, "body", None),
        )

        # Fields shared by every successful/failed request-operation schema.
        schema_fields = {
            "service_context": self._service_context,
            "id": operation_id,
            "context": operation_context,
            "timestamp": operation_timestamp,
            "request_context": request_context,
            "authentication": authentication,
            "action": resource_operation_action,
            "response_context": response_context,
        }

        status_code = response.status_code
        if 200 <= status_code < 300:
            schema_cls = self._select_request_operation_schema(
                resource_operation_action, successful=True
            )
            if schema_cls is not None:
                schema_cls(**schema_fields).log(
                    logger=self._logger, level=LogLevel.INFO
                )
        elif status_code >= 400:
            schema_cls = self._select_request_operation_schema(
                resource_operation_action, successful=False
            )
            if schema_cls is not None:
                schema_cls(
                    **schema_fields,
                    error=error.schema if isinstance(error, Error) else None,
                ).log(logger=self._logger, level=LogLevel.ERROR)
        # 1xx and 3xx responses, and actions matching none of the known types,
        # are not logged.

        return response


def add_base_middleware(
    app: FastAPI,
    *,
    logger: MiddlewareLogger,
    private_key: RsaKey,
    rate_limiter: RateLimiter,
    response_builder: ResponseBuilder,
    service_context: Optional[ServiceContext] = None,
    operation_id: OptionalUUID = None,
) -> None:
    """
    Add Base middleware to the FastAPI application.

    Registers ``BaseMiddleware`` via ``app.add_middleware``, forwarding every
    argument below unchanged.

    Args:
        app: FastAPI application instance.
        logger: Middleware logger used for initialization and request logs.
        private_key: RSA private key handed to the middleware.
        rate_limiter: Rate limiter consulted on every request.
        response_builder: Builder that adds headers to every response.
        service_context: Service context; when None the middleware falls back
            to ``ServiceContext.from_env()``.
        operation_id: ID for the middleware's initialization log entry; when
            None the middleware generates a random UUID4.

    Example:
        ```python
        add_base_middleware(
            app,
            logger=middleware_logger,
            private_key=private_key,
            rate_limiter=rate_limiter,
            response_builder=response_builder,
        )
        ```
    """
    app.add_middleware(
        BaseMiddleware,
        logger=logger,
        private_key=private_key,
        rate_limiter=rate_limiter,
        response_builder=response_builder,
        service_context=service_context,
        operation_id=operation_id,
    )
