from collections.abc import Callable

from starlette.background import BackgroundTasks
from starlette.concurrency import iterate_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp

from logging_midlleware.background import create_request_log


class LoggingMidlleware(BaseHTTPMiddleware):
    """Middleware that schedules a ``create_request_log`` call per request.

    For responses that are produced normally the log call is attached to the
    response as a background task. When the downstream application raises,
    the log call is executed immediately and the exception is re-raised.
    """

    def __init__(self, app: ASGIApp, ignore_fields: list[str] | None = None) -> None:
        """Store the wrapped app and the logging configuration.

        Args:
            app: The ASGI application wrapped by this middleware.
            ignore_fields: Passed unchanged to ``create_request_log`` as its
                ``ignore_fields`` argument (presumably names of fields to
                leave out of the log entry -- confirm against that helper).
        """
        super().__init__(app)
        self.ignore_fields = ignore_fields

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the request through the app and schedule its log entry.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request downstream and
                returns the response.

        Returns:
            The downstream response with the logging task attached as its
            background task.

        Raises:
            Exception: Whatever ``call_next`` raised, after the failure log
                task has been run. If the log task itself raises, that error
                propagates instead, with the original one as its context.
        """
        # Read the body before dispatching so it is available for the log
        # entry on both the success and the failure path.
        request_body = await request.body()

        try:
            response = await call_next(request)
        except Exception as e:
            failure_tasks = BackgroundTasks()
            failure_tasks.add_task(
                create_request_log,
                request=request,
                status_code=500,
                request_body=request_body,
                error=e,
                message='Reqest_failed',
                ignore_fields=self.ignore_fields
            )
            # No response object exists on this path to carry background
            # tasks, so the task must be run here; otherwise it would be
            # dropped when the exception propagates.
            await failure_tasks()
            raise

        if response.status_code >= 400:
            # The body iterator can only be consumed once: drain it for the
            # log message, then install a fresh iterator over the same
            # chunks so the client still receives the body.
            response_body_chunks = [chunk async for chunk in response.body_iterator]
            response.body_iterator = iterate_in_threadpool(iter(response_body_chunks))
            response_body_bytes = b''.join(response_body_chunks)
            # Error bodies are not guaranteed to be valid UTF-8; never let
            # logging turn an error response into a decode failure.
            message = response_body_bytes.decode(errors='replace')
        else:
            message = 'Request success'

        background_tasks = BackgroundTasks()
        background_tasks.add_task(
            create_request_log,
            request=request,
            request_body=request_body,
            status_code=response.status_code,
            message=message,
            ignore_fields=self.ignore_fields
        )
        response.background = background_tasks
        return response
