# SPDX-License-Identifier: Apache-2.0.
# Copyright (c) 2024 - 2025 Waldiez and contributors.
"""Waldiez Message Model."""

import json
from typing import Any, Optional

from pydantic import Field, model_validator
from typing_extensions import Annotated, Literal, Self

from ..common import WaldiezBase, check_function, update_dict

# The closed set of message kinds accepted by ``WaldiezChatMessage.type``.
WaldiezChatMessageType = Literal[
    "string",
    "method",
    "rag_message_generator",
    "none",
]
"""Possible types for the message."""

# Name and positional parameters of a user-provided callable message.
CALLABLE_MESSAGE = "callable_message"
CALLABLE_MESSAGE_ARGS = [
    "sender",
    "recipient",
    "context",
]
# Presumably (parameter type names, return type name) for the callable
# message above; the return type matches the generated RAG method's
# ``-> Union[dict[str, Any], str]`` annotation.
CALLABLE_MESSAGE_TYPES = (
    [
        "ConversableAgent",
        "ConversableAgent",
        "dict[str, Any]",
    ],
    "Union[dict[str, Any], str]",
)
# Same shape, but the sender is a ``RetrieveUserProxyAgent`` (RAG user).
CALLABLE_MESSAGE_RAG_WITH_CARRYOVER_TYPES = (
    [
        "RetrieveUserProxyAgent",
        "ConversableAgent",
        "dict[str, Any]",
    ],
    "Union[dict[str, Any], str]",
)


class WaldiezChatMessage(WaldiezBase):
    """
    Waldiez Message.

    A generic message with a type and content.

    If the type is `none`, there is no message content.
    If the type is `string`, the content is a plain text string.
    If the type is `method`, the content is the source code of a method.
    If the type is `rag_message_generator`, and the sender is a RAG user
        agent, the message will be generated by the
        `sender.message_generator` method.
    If `use_carryover` is set, the last carryover from the context is
        appended to the message (a method body doing so is generated).

    Attributes
    ----------
    type : WaldiezChatMessageType
        The type of the message:
        - string
        - method
        - rag_message_generator
        - none
        If the sender is a RAG user agent,
        and the type is `rag_message_generator`,
        the `{sender}.message_generator` method will be used.
    use_carryover : bool
        Whether to append the last carryover from the context
        to the message.
    content : Optional[str]
        The content of the message (string or method).
    context : dict[str, Any]
        Extra context of the message.
    """

    type: Annotated[
        WaldiezChatMessageType,
        Field(
            default="none",
            title="Type",
            description=(
                "The type of the message: "
                "`string`, `method`, "
                "`rag_message_generator` or `none`. "
                "If the sender is a RAG user agent, "
                "and the type is `rag_message_generator`, "
                "the `sender.message_generator` method will be used. "
                "If `use_carryover` is set, the last carryover "
                "from the context is appended to the message."
            ),
        ),
    ] = "none"
    use_carryover: Annotated[
        bool,
        Field(
            default=False,
            title="Use Carryover",
            description="Use the carryover from the context.",
        ),
    ] = False
    content: Annotated[
        Optional[str],
        Field(
            default=None,
            title="Content",
            description="The content of the message (string or method).",
        ),
    ] = None
    context: Annotated[
        dict[str, Any],
        Field(
            default_factory=dict,
            title="Context",
            description="Extra context of the message.",
        ),
    ] = {}

    # Private (non-field) attribute filled in by ``validate_content``.
    _content_body: Optional[str] = None

    def is_method(self) -> bool:
        """Check if the message is a method.

        Both `method` and `rag_message_generator` messages count as methods.

        Returns
        -------
        bool
            True if the message is a method, False otherwise.
        """
        return self.type in ("method", "rag_message_generator")

    @property
    def content_body(self) -> Optional[str]:
        """Get the content body computed during validation.

        Returns
        -------
        Optional[str]
            The string ``"None"`` for the `none` type, the method source
            for `method`, the text (or generated carryover method body)
            for `string`, a generated method body for
            `rag_message_generator`, or None if never set.
        """
        return self._content_body

    @model_validator(mode="after")
    def validate_context_vars(self) -> Self:
        """Try to detect bools, nulls and numbers in the context values.

        The conversion itself is delegated to ``update_dict``.

        Returns
        -------
        WaldiezChatMessage
            The validated instance.
        """
        self.context = update_dict(self.context)
        return self

    @model_validator(mode="after")
    def validate_content(self) -> Self:
        """Validate the content and compute the content body.

        Returns
        -------
        WaldiezChatMessage
            The validated instance.

        Raises
        ------
        ValueError
            If the type is `method` and there is no content.
        """
        content: Optional[str] = None
        # The branches are mutually exclusive: ``self.type`` is not
        # modified inside any of them.
        if self.type == "none":
            content = "None"
        elif self.type == "method":
            if not self.content:
                raise ValueError(
                    "The message content is required for the method type"
                )
            content = self.content
        elif self.type == "string":
            if not self.content:
                self.content = ""
            if self.use_carryover:
                # NOTE(review): this replaces the user's text with a
                # generated method body; re-validating a dump of this
                # instance would wrap that body again. Confirm it is never
                # round-tripped.
                self.content = get_last_carryover_method_content(
                    text_content=self.content,
                )
            content = self.content
        elif self.type == "rag_message_generator":
            if self.use_carryover:
                content = RAG_METHOD_WITH_CARRYOVER_BODY
                self.content = RAG_METHOD_WITH_CARRYOVER
            else:
                # NOTE(review): without ``use_carryover`` the body still
                # reads the carryover from the context; presumably the
                # exporter uses ``sender.message_generator`` directly in
                # this case. Confirm this branch is intentional.
                content = get_last_carryover_method_content(
                    text_content=self.content or "",
                )
        self._content_body = content
        return self

    def validate_method(
        self,
        function_name: str,
        function_args: list[str],
    ) -> str:
        """Validate a method.

        Parameters
        ----------
        function_name : str
            The method name.
        function_args : list[str]
            The expected method arguments.

        Returns
        -------
        str
            The validated method body.

        Raises
        ------
        ValueError
            If there is no content or the validation fails.
        """
        if not self.content:
            raise ValueError(
                "The message content is required for the method type"
            )
        # ``check_function`` returns (is_valid, error message or body).
        is_valid, error_or_body = check_function(
            code_string=self.content,
            function_name=function_name,
            function_args=function_args,
        )
        if not is_valid:
            raise ValueError(error_or_body)
        return error_or_body


def get_last_carryover_method_content(text_content: str) -> str:
    """Get the last carryover method content.

    Build the body (docstring and statements, without the ``def`` line)
    of a callable message taking ``(sender, recipient, context)``.

    The generated code does the following:

    - reads the ``carryover`` entry of ``context``;
    - keeps the last item if the carryover is a list, with an empty list
      becoming an empty string;
    - takes the ``content`` key if the carryover is a dict;
    - falls back to an empty string for anything that is not a string.

    Parameters
    ----------
    text_content : str
        Text content before the carryover. When non-empty, the generated
        code returns this text followed by the carryover. Otherwise it
        returns the carryover alone.

    Returns
    -------
    str
        The last carryover method content, indented by four spaces so it
        can directly follow a ``def ...:`` line.
    """
    method_content = '''
    """Get the message to send using the last carryover.

    Parameters
    ----------
    sender : ConversableAgent
        The source agent.
    recipient : ConversableAgent
        The target agent.
    context : dict[str, Any]
        The context.

    Returns
    -------
    Union[dict[str, Any], str]
        The message to send using the last carryover.
    """
    carryover = context.get("carryover", "")
    if isinstance(carryover, list):
        carryover = carryover[-1] if carryover else ""
    if not isinstance(carryover, str):
        if isinstance(carryover, list):
            carryover = carryover[-1] if carryover else ""
        elif isinstance(carryover, dict):
            carryover = carryover.get("content", "")
    if not isinstance(carryover, str):
        carryover = ""'''
    if text_content:
        # Emit the text as an escaped string literal. Pasting it raw between
        # quotes breaks (or injects into) the generated code when it contains
        # quotes, backslashes or newlines. For plain text this yields the
        # same double-quoted literal as before.
        text_literal = json.dumps(text_content, ensure_ascii=False)
        method_content += f"""
    final_message = {text_literal} + carryover
    return final_message
"""
    else:
        method_content += """
    return carryover
"""
    return method_content


# Body (docstring + statements, indented for a ``def`` line) of a callable
# message. It calls ``sender.message_generator`` and appends the last
# carryover from the context, if any. The generated code takes the last item
# of a list carryover, treating an empty list as no carryover rather than
# indexing into it.
RAG_METHOD_WITH_CARRYOVER_BODY = '''
    """Get the message using the RAG message generator method.

    Parameters
    ----------
    sender : RetrieveUserProxyAgent
        The source agent.
    recipient : ConversableAgent
        The target agent.
    context : dict[str, Any]
        The context.

    Returns
    -------
    Union[dict[str, Any], str]
        The message to send using the last carryover.
    """
    carryover = context.get("carryover", "")
    if isinstance(carryover, list):
        carryover = carryover[-1] if carryover else ""
    if not isinstance(carryover, str):
        if isinstance(carryover, list):
            carryover = carryover[-1] if carryover else ""
        elif isinstance(carryover, dict):
            carryover = carryover.get("content", "")
    if not isinstance(carryover, str):
        carryover = ""
    message = sender.message_generator(sender, recipient, context)
    if carryover:
        message += carryover
    return message
'''
# Full source of the ``callable_message`` function using the body above.
RAG_METHOD_WITH_CARRYOVER = (
    "def callable_message(\n"
    "    sender: RetrieveUserProxyAgent,\n"
    "    recipient: ConversableAgent,\n"
    "    context: dict[str, Any],\n"
    ") -> Union[dict[str, Any], str]:"
    f"{RAG_METHOD_WITH_CARRYOVER_BODY}"
)
