from __future__ import annotations

from datetime import datetime
from functools import wraps
from typing import Any, ClassVar, Dict, Optional, Type, Union
from uuid import UUID, uuid4

from pydantic import BaseModel, Field, validator
from typing_extensions import Self

from kelvin.message.krn import KRN, KRNAssetDataStream
from kelvin.message.msg_type import KMessageType, KMessageTypeData, KMessageTypePrimitive
from kelvin.message.utils import from_rfc3339_timestamp, to_rfc3339_timestamp

# Strict-mode switch read by ``Message.__new__``: when True, building a message
# whose type cannot be determined from its keyword arguments raises
# ``ValueError("Missing message type")`` instead of falling back to a generic class.
FAIL_ON_UNKNOWN_TYPES: bool = False


class Message(BaseModel):
    """Base kelvin message.

    Instantiating a class without its own ``_TYPE`` dispatches to the registered
    subclass that matches the message type found in the keyword arguments (see
    ``__new__``). Legacy v1 (``data_type`` key) and v0 (``_`` header) layouts are
    converted to the current layout in ``__init__``.
    """

    # Registry of concrete message classes keyed by message type; populated by
    # ``__init_subclass__`` for every subclass that sets ``_TYPE``.
    _MESSAGE_TYPES: ClassVar[Dict[KMessageType, Type["Message"]]] = {}
    # Message type handled by a concrete subclass; ``None`` on generic classes.
    _TYPE: ClassVar[Optional[KMessageType]] = None

    id: UUID = Field(default_factory=uuid4)
    type: Optional[KMessageType] = None
    trace_id: Optional[str] = None
    source: Optional[KRN] = None
    timestamp: datetime = Field(default_factory=lambda: datetime.now().astimezone())
    resource: Optional[KRN] = None

    payload: Any

    class Config:
        underscore_attrs_are_private = True
        json_encoders = {
            datetime: to_rfc3339_timestamp,
            KRN: KRN.encode,
            KMessageType: KMessageType.encode,
        }

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Register subclasses that declare a ``_TYPE`` in ``_MESSAGE_TYPES``."""
        # Accept and forward class keyword arguments so cooperative bases still run.
        super().__init_subclass__(**kwargs)
        if cls._TYPE:
            Message._MESSAGE_TYPES[cls._TYPE] = cls

    def __new__(cls, **kwargs: Any) -> Message:  # pyright: ignore
        """Allocate the message, dispatching to the registered subclass for its type.

        A class with ``_TYPE`` set always builds an instance of itself. Otherwise
        the type is sniffed from ``kwargs`` and looked up in ``_MESSAGE_TYPES``.

        Raises
        ------
        ValueError
            If no type can be determined and ``FAIL_ON_UNKNOWN_TYPES`` is True.
        """
        if cls._TYPE:
            return super().__new__(cls)

        msg_type = cls._get_msg_type_from_payload(**kwargs)
        if msg_type is None and FAIL_ON_UNKNOWN_TYPES is True:
            raise ValueError("Missing message type") from None

        registered = Message._MESSAGE_TYPES.get(msg_type)  # type: ignore
        # Python only calls __init__ on the returned object when it is an
        # instance of ``cls``. Fall back to ``cls`` (not ``Message``) so that
        # intermediate subclasses without ``_TYPE`` are still initialised.
        if registered is None or not issubclass(registered, cls):
            registered = cls
        return super().__new__(registered)

    def __init__(self, **kwargs: Any) -> None:  # pyright: ignore
        """
        Create a kelvin Message.

        Legacy messages are converted first: kwargs containing ``data_type`` are
        treated as v1, kwargs containing a ``_`` header as v0. On a subclass with
        ``_TYPE`` set, ``type`` is always forced to that value.

        Parameters
        ----------
        id : UUID or str, optional
            UUID of the message. Optional, auto generated if not provided.
        type : KMessageType
            Message Type
        trace_id : str, optional
            Optional trace id.
        source : KRN, optional
            Identifies the source of the message.
        timestamp : datetime or str, optional
            Timestamp for the message (RFC 3339 strings are parsed). Defaults to
            the current local time.
        resource : KRN, optional
            Sets a resource that the message relates to.
        payload : Any
            Payload of the message. Specific for each message sub type.
        """

        new_kwargs = kwargs
        if kwargs.get("data_type"):
            new_kwargs = self._convert_message_v1(**kwargs)
        elif kwargs.get("_"):
            new_kwargs = self._convert_message_v0(**kwargs)

        if self._TYPE:
            new_kwargs["type"] = self._TYPE

        super().__init__(**new_kwargs)

    @validator("timestamp", pre=True, always=True)
    def default_timestamp(cls, v: Union[str, datetime]) -> datetime:
        """Parse RFC 3339 timestamp strings; other values pass through unchanged."""
        if isinstance(v, str):
            return from_rfc3339_timestamp(v)

        return v

    # NOTE: @wraps copies BaseModel's __doc__ onto these overrides at runtime;
    # the docstrings below document the changed defaults for source readers.
    @wraps(BaseModel.dict)
    def dict(
        self,
        by_alias: bool = True,
        exclude_none: bool = True,
        exclude_unset: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Generate a dictionary representation of the model (aliases on, None fields dropped)."""

        return super().dict(by_alias=by_alias, exclude_none=exclude_none, exclude_unset=exclude_unset, **kwargs)

    @wraps(BaseModel.json)
    def json(
        self,
        by_alias: bool = True,
        exclude_none: bool = True,
        exclude_unset: bool = False,
        **kwargs: Any,
    ) -> str:
        """Generate a JSON string representation of the model (aliases on, None fields dropped)."""

        return super().json(by_alias=by_alias, exclude_none=exclude_none, exclude_unset=exclude_unset, **kwargs)

    def encode(self) -> bytes:
        """Encode the message as UTF-8 JSON bytes."""
        return bytes(self.json(), "utf-8")

    @classmethod
    def decode(cls, data: bytes) -> Self:
        """Decode a message from raw JSON (the result may be a registered subclass)."""
        return cls.parse_raw(data)

    @staticmethod
    def _convert_message_v1(**kwargs: Any) -> Dict[str, Any]:
        """Convert a legacy v1 message (identified by ``data_type``) into ``Message`` kwargs."""
        result: Dict[str, Any] = {}

        # Only forward id/timestamp when present. These fields are not Optional,
        # so an explicit None would skip their default factories and fail
        # validation instead of producing a generated id / current time.
        for key in ("id", "timestamp"):
            value = kwargs.get(key)
            if value is not None:
                result[key] = value

        asset = kwargs.get("asset_name")
        metric = kwargs.get("name")
        if asset and metric:
            result["resource"] = KRNAssetDataStream(asset, metric)  # type: ignore

        result["type"] = KMessageTypePrimitive(icd=str(kwargs.get("data_type")))

        source = kwargs.get("source")
        if source:
            result["source"] = "krn:wl:" + str(source)

        result["payload"] = kwargs.get("payload")

        return result

    @staticmethod
    def _convert_message_v0(**kwargs: Any) -> Dict[str, Any]:
        """Convert a legacy v0 message into ``Message`` kwargs.

        v0 messages carry their metadata in a ``_`` header dict (``asset_name``,
        ``name``, ``type``, ``source``, ``time_of_validity`` in nanoseconds, ``id``).
        Every other top-level key becomes the payload.
        """
        from datetime import timedelta, timezone

        result: Dict[str, Any] = {}

        header = kwargs.pop("_")

        asset = header.get("asset_name") or ""
        metric = header.get("name") or ""
        # resource should not have empty asset but kelvin-app uses v0 messages with no asset
        result["resource"] = KRNAssetDataStream(asset, metric)

        # In v0 the ICD lives in the header (where _get_msg_type_from_payload
        # also reads it). A top-level "data_type" is never truthy on this path.
        icd = header.get("type")
        if icd:
            result["type"] = KMessageTypePrimitive(icd=str(icd))

        source = header.get("source")
        if source:
            if isinstance(source, dict):
                source = source.get("node_name", "") + "/" + source.get("workload_name", "")
            result["source"] = "krn:wl:" + str(source)

        timestamp_ns = header.get("time_of_validity")
        if timestamp_ns is not None:
            # Split with integer arithmetic. Dividing a present-day nanosecond
            # epoch (~1.7e18) by 1e9 as a float loses sub-microsecond precision
            # and can shift the result by a microsecond.
            seconds, nanos = divmod(timestamp_ns, 1_000_000_000)
            result["timestamp"] = (
                datetime.fromtimestamp(seconds, tz=timezone.utc) + timedelta(microseconds=nanos // 1000)
            ).astimezone()

        msg_id = header.get("id")
        if msg_id:
            result["id"] = msg_id

        # the remaining kwargs are payload
        result["payload"] = kwargs

        return result

    @staticmethod
    def _get_msg_type_from_payload(**kwargs: Any) -> Optional[KMessageType]:
        """Sniff the message type from raw constructor kwargs.

        Checks, in order, ``type`` (v2), ``data_type`` (v1) and ``_.type`` (v0).
        Returns ``None`` when none of them is set.
        """
        # An explicit ``type=None`` must not turn into the string "None".
        raw_type = kwargs.get("type")
        v2_type = str(raw_type) if raw_type is not None else ""
        if v2_type:
            return KMessageType.from_string(v2_type)

        header = kwargs.get("_") or {}
        icd = kwargs.get("data_type") or header.get("type")
        if icd:
            return KMessageTypeData(primitive="object", icd=icd)

        return None
