from pydantic import BaseModel, Field, model_validator
from typing import Annotated, Generic, Literal, Self, TypeVar, overload
from uuid import UUID
from maleo.enums.status import (
    DataStatus,
    SimpleDataStatusMixin,
    ListOfDataStatuses,
    FULL_DATA_STATUSES,
)
from maleo.schemas.mixins.filter import convert as convert_filter
from maleo.schemas.mixins.identity import (
    DataIdentifier,
    IdentifierMixin,
    Ids,
    UUIDs,
    Names,
    UUIDOrganizationId,
    UUIDOrganizationIds,
    UUIDPatientIds,
    UUIDUserId,
    UUIDUserIds,
)
from maleo.schemas.mixins.sort import convert as convert_sort
from maleo.schemas.mixins.timestamp import LifecycleTimestamp
from maleo.schemas.operation.enums import ResourceOperationStatusUpdateType
from maleo.schemas.parameter import (
    ReadSingleParameter as BaseReadSingleParameter,
    ReadPaginatedMultipleParameter,
    StatusUpdateParameter as BaseStatusUpdateParameter,
    DeleteSingleParameter as BaseDeleteSingleParameter,
)
from maleo.types.dict import StrToAnyDict
from maleo.types.integer import OptListOfInts
from maleo.types.string import OptStr, ListOfStrs
from maleo.types.uuid import OptUUID, ListOfUUIDs, OptListOfUUIDs
from ..enums.session import IdentifierType, SessionType as SessionTypeEnum
from ..mixins.common import ClientId, ClientIds, ParameterIds
from ..mixins.session import SessionType, Name, SessionIdentifier
from ..types.session import IdentifierValueType
from .document import Document, DocumentMixin, ListOfDocuments, DocumentsMixin


class IndividualSessionInfo(DocumentMixin[Document], Name[OptStr]):
    """Optional session name and the document resolved for one patient."""


class CreateIndividualParameter(
    DocumentsMixin[ListOfDocuments],
    UUIDPatientIds[ListOfUUIDs],
    ParameterIds[OptListOfUUIDs],
    Names[ListOfStrs],
):
    """Parameters for creating one individual session per patient.

    ``names`` and ``documents`` are paired with ``patient_ids`` by prefix: an
    entry belongs to a patient when its text (for documents, its ``filename``)
    starts with ``"<patient_id>_"``. For names, whatever follows that prefix is
    the session name; an empty remainder means "no name".
    """

    def _validate_patient_info(self) -> Self:
        """Check that every patient has exactly one name and one document.

        Returns:
            ``self``, so it can be returned directly from a model validator.

        Raises:
            ValueError: if ``names``, ``patient_ids`` and ``documents`` differ
                in length, if a patient ID is repeated, or if a patient has no
                name / document starting with ``"<patient_id>_"``.
        """
        names_len = len(self.names)
        patient_ids_len = len(self.patient_ids)
        documents_len = len(self.documents)

        if not (names_len == patient_ids_len == documents_len):
            raise ValueError(
                f"Mismatched count - names: {names_len}, patient_ids: {patient_ids_len}, documents: {documents_len}"
            )

        # ``str(UUID)`` has a fixed length, so prefixes of distinct IDs never
        # overlap; equal counts plus "every patient has a match" then implies
        # an exact one-to-one pairing. A repeated ID, however, passes the
        # count check while leaving some name/document unclaimed (and would
        # collapse into a single entry in ``patient_info``), so reject it.
        seen_patient_ids: set[UUID] = set()
        for patient_id in self.patient_ids:
            if patient_id in seen_patient_ids:
                raise ValueError(f"Duplicate patient_id: '{patient_id}'")
            seen_patient_ids.add(patient_id)

            patient_id_prefix = f"{patient_id}_"

            if not any(name.startswith(patient_id_prefix) for name in self.names):
                raise ValueError(
                    f"Unable to determine session name for patient: '{patient_id}'"
                )

            if not any(
                document.filename.startswith(patient_id_prefix)
                for document in self.documents
            ):
                raise ValueError(
                    f"Unable to determine document for patient: '{patient_id}'"
                )

        return self

    @model_validator(mode="after")
    def validate_patient_info(self) -> Self:
        """Pydantic after-validator delegating to ``_validate_patient_info``."""
        return self._validate_patient_info()

    @property
    def patient_info(self) -> dict[UUID, IndividualSessionInfo]:
        """Pair each patient ID with its parsed session name and document.

        Re-runs ``_validate_patient_info`` first (raising ``ValueError`` on any
        inconsistency), which guarantees each ``next(...)`` below finds a match.
        """
        self._validate_patient_info()
        info: dict[UUID, IndividualSessionInfo] = {}
        for patient_id in self.patient_ids:
            patient_id_prefix = f"{patient_id}_"

            # Session name: strip the "<patient_id>_" prefix; empty -> None.
            raw_name = next(
                name for name in self.names if name.startswith(patient_id_prefix)
            )
            parsed_name = raw_name.removeprefix(patient_id_prefix) or None

            document = next(
                doc
                for doc in self.documents
                if doc.filename.startswith(patient_id_prefix)
            )

            info[patient_id] = IndividualSessionInfo(
                name=parsed_name, document=document
            )
        return info


class CreateGroupParameter(
    DocumentMixin[Document],
    ClientId[UUID],
    ParameterIds[OptListOfUUIDs],
    Name[str],
):
    """Parameters for creating a group session.

    Carries a single document, a client ID, optional parameter IDs and a
    session name.
    """


# Either create-parameter shape: a group session (single client ID and
# document) or a batch of individual sessions (one per patient).
AnyCreateParameter = CreateGroupParameter | CreateIndividualParameter


class ReadMultipleParameter(
    ReadPaginatedMultipleParameter,
    ClientIds[OptListOfUUIDs],
    UUIDOrganizationIds[OptListOfUUIDs],
    UUIDUserIds[OptListOfUUIDs],
    UUIDs[OptListOfUUIDs],
    Ids[OptListOfInts],
):
    """Paginated session listing with optional ID-based filters.

    Every ID filter below defaults to ``None``; ``None`` values are omitted
    from the output of ``to_query_params``.
    """

    ids: Annotated[OptListOfInts, Field(None, description="Ids")] = None
    uuids: Annotated[OptListOfUUIDs, Field(None, description="UUIDs")] = None
    user_ids: Annotated[OptListOfUUIDs, Field(None, description="User's IDs")] = None
    organization_ids: Annotated[
        OptListOfUUIDs, Field(None, description="Organization's IDs")
    ] = None
    client_ids: Annotated[OptListOfUUIDs, Field(None, description="Client's Ids")] = (
        None
    )

    @property
    def _query_param_fields(self) -> set[str]:
        """Names of the fields dumped verbatim into the query parameters."""
        return {
            "ids",
            "uuids",
            "statuses",
            "user_ids",
            "organization_ids",
            "client_ids",
            "search",
            "page",
            "limit",
            "use_cache",
        }

    def to_query_params(self) -> StrToAnyDict:
        """Serialize this parameter into a query-string-ready dict.

        Fields listed in ``_query_param_fields`` are dumped in JSON mode with
        ``None`` values dropped. ``filters`` and ``sorts`` are always added,
        built from ``self.date_filters`` / ``self.sort_columns`` by the shared
        ``convert_filter`` / ``convert_sort`` helpers.
        """
        # ``model_dump`` already returns a fresh dict, so no extra copy is needed.
        params = self.model_dump(
            mode="json", include=self._query_param_fields, exclude_none=True
        )
        params["filters"] = convert_filter(self.date_filters)
        params["sorts"] = convert_sort(self.sort_columns)
        return params


class ReadSingleParameter(BaseReadSingleParameter[SessionIdentifier]):
    """Parameters for reading one session, addressed by a ``SessionIdentifier``.

    In both constructors, omitting ``statuses`` (or passing ``None``) uses a
    fresh copy of ``FULL_DATA_STATUSES``.
    """

    @classmethod
    def from_identifier(
        cls,
        identifier: SessionIdentifier,
        statuses: ListOfDataStatuses | None = None,
        use_cache: bool = True,
    ) -> "ReadSingleParameter":
        """Build a read parameter from an already-constructed identifier.

        Args:
            identifier: Identifier of the session to read.
            statuses: Value for the ``statuses`` field; ``None`` means a fresh
                ``list(FULL_DATA_STATUSES)``.
            use_cache: Value for the ``use_cache`` field.
        """
        if statuses is None:
            statuses = list(FULL_DATA_STATUSES)
        return cls(identifier=identifier, statuses=statuses, use_cache=use_cache)

    @overload
    @classmethod
    def new(
        cls,
        identifier_type: Literal[IdentifierType.ID],
        identifier_value: int,
        statuses: ListOfDataStatuses | None = None,
        use_cache: bool = True,
    ) -> "ReadSingleParameter": ...
    @overload
    @classmethod
    def new(
        cls,
        identifier_type: Literal[IdentifierType.UUID],
        identifier_value: UUID,
        statuses: ListOfDataStatuses | None = None,
        use_cache: bool = True,
    ) -> "ReadSingleParameter": ...
    @overload
    @classmethod
    def new(
        cls,
        identifier_type: IdentifierType,
        identifier_value: IdentifierValueType,
        statuses: ListOfDataStatuses | None = None,
        use_cache: bool = True,
    ) -> "ReadSingleParameter": ...
    @classmethod
    def new(
        cls,
        identifier_type: IdentifierType,
        identifier_value: IdentifierValueType,
        statuses: ListOfDataStatuses | None = None,
        use_cache: bool = True,
    ) -> "ReadSingleParameter":
        """Build a read parameter from an identifier type and value.

        Args:
            identifier_type: Kind of identifier (e.g. ``IdentifierType.ID``).
            identifier_value: Identifier value matching ``identifier_type``.
            statuses: Value for the ``statuses`` field; ``None`` means a fresh
                ``list(FULL_DATA_STATUSES)``.
            use_cache: Value for the ``use_cache`` field.
        """
        # Resolved per call: a ``list(...)`` default in the signature would be
        # evaluated once at definition time and shared by every call.
        if statuses is None:
            statuses = list(FULL_DATA_STATUSES)
        return cls(
            identifier=SessionIdentifier(
                type=identifier_type,
                value=identifier_value,
            ),
            statuses=statuses,
            use_cache=use_cache,
        )

    def to_query_params(self) -> StrToAnyDict:
        """Dump only ``statuses`` and ``use_cache`` (JSON mode, ``None`` dropped)."""
        return self.model_dump(
            mode="json", include={"statuses", "use_cache"}, exclude_none=True
        )


class FullUpdateData(
    ParameterIds[OptListOfUUIDs],
    Name[str],
):
    """Payload for a full session update: a ``str`` name plus parameter IDs."""


class PartialUpdateData(
    ParameterIds[OptListOfUUIDs],
    Name[OptStr],
):
    """Payload for a partial session update; each field may be omitted."""

    # Both inherited fields are re-declared with an explicit ``None`` default.
    name: Annotated[
        OptStr,
        Field(default=None, description="Session's name", max_length=50),
    ] = None
    parameter_ids: Annotated[
        OptListOfUUIDs,
        Field(default=None, description="Parameter's Ids"),
    ] = None


# Constrained (not bound) type variable: an update payload is exactly either
# ``FullUpdateData`` or ``PartialUpdateData``.
UpdateDataT = TypeVar("UpdateDataT", FullUpdateData, PartialUpdateData)


class UpdateDataMixin(BaseModel, Generic[UpdateDataT]):
    """Mixin contributing a required ``data`` field holding the update payload."""

    data: UpdateDataT = Field(default=..., description="Update data")


class UpdateParameter(
    UpdateDataMixin[UpdateDataT],
    IdentifierMixin[SessionIdentifier],
    Generic[UpdateDataT],
):
    """Identifier of the session to update together with the update payload."""

    @overload
    @classmethod
    def new(
        cls,
        identifier_type: Literal[IdentifierType.ID],
        identifier_value: int,
        data: UpdateDataT,
    ) -> "UpdateParameter": ...
    @overload
    @classmethod
    def new(
        cls,
        identifier_type: Literal[IdentifierType.UUID],
        identifier_value: UUID,
        data: UpdateDataT,
    ) -> "UpdateParameter": ...
    @overload
    @classmethod
    def new(
        cls,
        identifier_type: IdentifierType,
        identifier_value: IdentifierValueType,
        data: UpdateDataT,
    ) -> "UpdateParameter": ...
    @classmethod
    def new(
        cls,
        identifier_type: IdentifierType,
        identifier_value: IdentifierValueType,
        data: UpdateDataT,
    ) -> "UpdateParameter":
        """Build an update parameter from an identifier type/value and payload."""
        target = SessionIdentifier(type=identifier_type, value=identifier_value)
        return cls(identifier=target, data=data)


class StatusUpdateParameter(
    BaseStatusUpdateParameter[SessionIdentifier],
):
    """Identifier of the session whose status changes, plus the update type."""

    @overload
    @classmethod
    def new(
        cls,
        identifier_type: Literal[IdentifierType.ID],
        identifier_value: int,
        type: ResourceOperationStatusUpdateType,
    ) -> "StatusUpdateParameter": ...
    @overload
    @classmethod
    def new(
        cls,
        identifier_type: Literal[IdentifierType.UUID],
        identifier_value: UUID,
        type: ResourceOperationStatusUpdateType,
    ) -> "StatusUpdateParameter": ...
    @overload
    @classmethod
    def new(
        cls,
        identifier_type: IdentifierType,
        identifier_value: IdentifierValueType,
        type: ResourceOperationStatusUpdateType,
    ) -> "StatusUpdateParameter": ...
    @classmethod
    def new(
        cls,
        identifier_type: IdentifierType,
        identifier_value: IdentifierValueType,
        type: ResourceOperationStatusUpdateType,
    ) -> "StatusUpdateParameter":
        """Build a status-update parameter from an identifier and update type.

        ``type`` shadows the builtin but is kept because it is part of the
        public keyword interface and matches the model field name.
        """
        target = SessionIdentifier(type=identifier_type, value=identifier_value)
        return cls(identifier=target, type=type)


class DeleteSingleParameter(BaseDeleteSingleParameter[SessionIdentifier]):
    """Parameters for deleting one session, addressed by a SessionIdentifier."""

    @overload
    @classmethod
    def new(
        cls, identifier_type: Literal[IdentifierType.ID], identifier_value: int
    ) -> "DeleteSingleParameter": ...
    @overload
    @classmethod
    def new(
        cls, identifier_type: Literal[IdentifierType.UUID], identifier_value: UUID
    ) -> "DeleteSingleParameter": ...
    @overload
    @classmethod
    def new(
        cls, identifier_type: IdentifierType, identifier_value: IdentifierValueType
    ) -> "DeleteSingleParameter": ...
    @classmethod
    def new(
        cls, identifier_type: IdentifierType, identifier_value: IdentifierValueType
    ) -> "DeleteSingleParameter":
        """Build a delete parameter from an identifier type and value."""
        target = SessionIdentifier(type=identifier_type, value=identifier_value)
        return cls(identifier=target)


class SessionSchema(
    ParameterIds[ListOfUUIDs],
    Name[str],
    ClientId[OptUUID],
    SessionType[SessionTypeEnum],
    UUIDOrganizationId[UUID],
    UUIDUserId[UUID],
    SimpleDataStatusMixin[DataStatus],
    LifecycleTimestamp,
    DataIdentifier,
):
    """Full session representation with a type/client-ID consistency check."""

    @model_validator(mode="after")
    def validate_type_client(self) -> Self:
        """Require a client ID for GROUP sessions and forbid one for INDIVIDUAL.

        Session types other than GROUP/INDIVIDUAL (if any) are not checked.

        Raises:
            ValueError: if ``client_id`` contradicts ``type``.
        """
        session_type = self.type
        has_client = self.client_id is not None
        if session_type is SessionTypeEnum.GROUP and not has_client:
            raise ValueError("Client ID can not be None for group MCU")
        if session_type is SessionTypeEnum.INDIVIDUAL and has_client:
            raise ValueError("Client ID must be None for individual MCU")
        return self
