from typing import Dict, List, Optional, Union

from galtea.application.services.session_service import SessionService
from galtea.domain.models.inference_result import InferenceResult, InferenceResultBase, InferenceResultUpdate
from galtea.infrastructure.clients.http_client import Client
from galtea.utils.string import build_query_params, is_valid_id
from galtea.utils.undefined import UNDEFINED, UndefinedType, filter_undefined


class InferenceResultService:
    """
    Service for managing Inference Results.
    An Inference Result is a single turn in a conversation between a user and the AI, including the input and output.
    """

    def __init__(self, client: Client, session_service: SessionService):
        """Initialize the InferenceResultService with the provided HTTP client.

        Args:
            client (Client): The HTTP client for making API requests.
            session_service (SessionService): Session service used by `list` to check
                whether the requested sessions exist when no inference results are returned.
        """
        self.__client: Client = client
        self.__session_service: SessionService = session_service

    def create(
        self,
        session_id: str,
        input: str,
        output: Optional[str] = None,
        retrieval_context: Optional[str] = None,
        latency: Optional[float] = None,
        usage_info: Optional[Dict[str, int]] = None,
        cost_info: Optional[Dict[str, float]] = None,
        conversation_simulator_version: Optional[str] = None,
    ) -> InferenceResult:
        """Create a new inference result log in a session.

        Args:
            session_id (str): The session ID to log the inference result to
            input (str): The input text/prompt
            output (str, optional): The generated output/response
            retrieval_context (Optional[str]): Context retrieved for RAG systems
            latency (float, optional): Latency in milliseconds since the model was called until
                the response was received.
            usage_info (dict[str, int], optional): Information about token usage during the
                model call. Its keys are passed as top-level fields of the payload.
                Possible keys include:
                - 'input_tokens': Number of input tokens sent to the model.
                - 'output_tokens': Number of output tokens generated by the model.
                - 'cache_read_input_tokens': Number of input tokens read from the cache.
            cost_info (dict[str, float], optional): Information about the cost per token during
                the model call. Its keys are passed as top-level fields of the payload.
                Possible keys include:
                - 'cost_per_input_token': Cost per input token sent to the model.
                - 'cost_per_output_token': Cost per output token generated by the model.
                - 'cost_per_cache_read_input_token': Cost per input token read from the cache.
            conversation_simulator_version (str, optional): The version of Galtea's conversation simulator
                used to generate the user message (input).
                This should only be provided if the input was generated using the conversation_simulator_service.

        Returns:
            InferenceResult: The created inference result log object

        Raises:
            ValueError: If `session_id` is not a valid ID.
        """
        # Same session ID check as `create_batch` and `list`, so bad IDs fail before any HTTP call.
        if not is_valid_id(session_id):
            raise ValueError("A valid session ID must be provided.")

        # Pydantic validates every field when the model is instantiated, so no
        # separate validation pass is needed afterwards.
        inference_result_base: InferenceResultBase = InferenceResultBase(
            session_id=session_id,
            actual_input=input,
            actual_output=output,
            retrieval_context=retrieval_context,
            latency=latency,
            **(cost_info or {}),
            **(usage_info or {}),
            conversation_simulator_version=conversation_simulator_version,
        )

        # None-valued fields are dropped from the payload; keys are sent by alias.
        response = self.__client.post(
            "inferenceResults",
            json=inference_result_base.model_dump(by_alias=True, exclude_none=True),
        )

        return InferenceResult(**response.json())

    def create_batch(
        self,
        session_id: str,
        conversation_turns: List[Dict[str, str]],
        conversation_simulator_version: Optional[str] = None,
    ) -> List[InferenceResult]:
        """Create a batch of inference results in a session in a single HTTP call.

        Args:
            session_id (str): The session ID to log the inference results to
            conversation_turns (list[dict[str, str]]): Historic of the past chat
                conversation turns from the user and the model. Each turn is a dictionary with
                "role" and "content" keys following the standard conversation format.
                For instance:
                - [
                    {"role": "user", "content": "what is the capital of France?"},
                    {"role": "assistant", "content": "Paris"},
                    {"role": "user", "content": "what is the population of that city?"},
                    {"role": "assistant", "content": "2M"}
                ]
            conversation_simulator_version (str, optional): The version of Galtea's conversation simulator
                used to generate the user messages (inputs).
                This should only be provided if using the conversation_simulator_service to generate user messages.
                It is always included in the request body (as null when not provided).

        Returns:
            List[InferenceResult]: List of created inference result objects

        Raises:
            ValueError: If `session_id` is not a valid ID, or if `conversation_turns`
                is not a non-empty list of dictionaries.
        """
        if not is_valid_id(session_id):
            raise ValueError("A valid session ID must be provided.")

        # Enforce the full contract promised by the error message, including the element type.
        if (
            not conversation_turns
            or not isinstance(conversation_turns, list)
            or not all(isinstance(turn, dict) for turn in conversation_turns)
        ):
            raise ValueError("conversation_turns must be a non-empty list of dictionaries.")

        response = self.__client.post(
            "inferenceResults/batch",
            json={
                "conversationTurns": conversation_turns,
                "sessionId": session_id,
                "conversationSimulatorVersion": conversation_simulator_version,
            },
        )
        inference_results = [InferenceResult(**result) for result in response.json()]
        return inference_results

    def list(
        self,
        session_id: Union[str, list[str]],
        sort_by_created_at: Optional[str] = None,
        offset: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> List[InferenceResult]:
        """List inference result logs for a session.

        Args:
            session_id (str | list[str]): The session ID or list of session IDs to get inference results from.
            sort_by_created_at (str, optional): Sort by created at. Valid values are 'asc' and 'desc'.
            offset (int, optional): Offset for pagination.
                This refers to the number of items to skip before starting to collect the result set.
                The default value is 0.
            limit (int, optional): Limit for pagination.
                This refers to the maximum number of items to collect in the result set.

        Returns:
            List[InferenceResult]: List of inference result log objects. Empty if the
                sessions exist but have no inference results.

        Raises:
            ValueError: If no session ID is given or any ID is invalid, if
                `sort_by_created_at` is not 'asc' or 'desc', or if the result is empty
                and a requested session does not exist.
        """
        # 1. Validate IDs filter parameters
        session_ids = [session_id] if isinstance(session_id, str) else session_id
        if not session_ids or not all(is_valid_id(sid) for sid in session_ids):
            raise ValueError("A valid session ID must be provided.")

        # 2. Validate sort parameters
        if sort_by_created_at is not None and sort_by_created_at not in ["asc", "desc"]:
            raise ValueError("Sort by created at must be 'asc' or 'desc'.")

        query_params = build_query_params(
            sessionIds=session_ids,
            offset=offset,
            limit=limit,
            sort=["createdAt", sort_by_created_at] if sort_by_created_at else None,
        )
        response = self.__client.get(f"inferenceResults?{query_params}")
        inference_results = [InferenceResult(**result) for result in response.json()]

        # An empty result is ambiguous: the sessions may have no results, or may not exist.
        # Look each session up to tell the two cases apart.
        if not inference_results:
            for sid in session_ids:
                session = self.__session_service.get(sid)
                if not session:
                    raise ValueError(f"Session with ID {sid} does not exist.")

        return inference_results

    def update(
        self,
        inference_result_id: str,
        actual_output: Union[str, None, UndefinedType] = UNDEFINED,
        actual_input: Union[str, None, UndefinedType] = UNDEFINED,
        retrieval_context: Union[str, None, UndefinedType] = UNDEFINED,
        latency: Union[float, None, UndefinedType] = UNDEFINED,
        input_tokens: Union[int, None, UndefinedType] = UNDEFINED,
        output_tokens: Union[int, None, UndefinedType] = UNDEFINED,
        cache_read_input_tokens: Union[int, None, UndefinedType] = UNDEFINED,
        tokens: Union[int, None, UndefinedType] = UNDEFINED,
        cost: Union[float, None, UndefinedType] = UNDEFINED,
        cost_per_input_token: Union[float, None, UndefinedType] = UNDEFINED,
        cost_per_output_token: Union[float, None, UndefinedType] = UNDEFINED,
        cost_per_cache_read_input_token: Union[float, None, UndefinedType] = UNDEFINED,
        conversation_simulator_version: Union[str, None, UndefinedType] = UNDEFINED,
        created_at: Union[str, None, UndefinedType] = UNDEFINED,
        deleted_at: Union[str, None, UndefinedType] = UNDEFINED,
        index: Union[int, None, UndefinedType] = UNDEFINED,
        session_id: Union[str, None, UndefinedType] = UNDEFINED,
    ) -> InferenceResult:
        """
        Update an existing inference result with agent output and metadata.

        Args:
            inference_result_id (str): The ID of the inference result to update.
            actual_output (str | None | Undefined): The generated output or response
                from the AI model for this turn.
            actual_input (str | None | Undefined): The input text or prompt for
                the inference result.
            retrieval_context (str | None | Undefined): The context retrieved by
                a RAG system, if applicable.
            latency (float | None | Undefined): The time in milliseconds from
                request to response.
            input_tokens (int | None | Undefined): Number of input tokens sent
                to the model.
            output_tokens (int | None | Undefined): Number of output tokens
                generated by the model.
            cache_read_input_tokens (int | None | Undefined): Number of input
                tokens read from the cache.
            tokens (int | None | Undefined): Total tokens used in the model call.
            cost (float | None | Undefined): The total cost associated with the
                model call.
            cost_per_input_token (float | None | Undefined): Cost per input token
                sent to the model.
            cost_per_output_token (float | None | Undefined): Cost per output
                token generated by the model.
            cost_per_cache_read_input_token (float | None | Undefined): Cost per
                input token read from the cache.
            conversation_simulator_version (str | None | Undefined): The version
                of Galtea's conversation simulator used to generate the user
                message.
            created_at (str | None | Undefined): Timestamp when the inference
                result was created.
            deleted_at (str | None | Undefined): Timestamp when the inference
                result was deleted.
            index (int | None | Undefined): The position of this inference result
                within the session.
            session_id (str | None | Undefined): The session to which this
                inference result belongs.

        Returns:
            InferenceResult: The updated inference result object.

        Raises:
            ValueError: If `inference_result_id` is not a valid ID.

        Note:
            The creditsUsed field cannot be modified through this method as it's
            controlled by the API to prevent unauthorized credit manipulation.

            Pass UNDEFINED (default) to exclude a field from the update.
            Pass None to explicitly set a field to null.
            Pass a value to update the field to that value.
        """
        if not is_valid_id(inference_result_id):
            raise ValueError("A valid inference result ID must be provided.")

        # Collect all updatable parameters, excluding self and inference_result_id
        params = {
            "actual_input": actual_input,
            "actual_output": actual_output,
            "cache_read_input_tokens": cache_read_input_tokens,
            "conversation_simulator_version": conversation_simulator_version,
            "cost": cost,
            "cost_per_cache_read_input_token": cost_per_cache_read_input_token,
            "cost_per_input_token": cost_per_input_token,
            "cost_per_output_token": cost_per_output_token,
            "created_at": created_at,
            "deleted_at": deleted_at,
            "index": index,
            "input_tokens": input_tokens,
            "latency": latency,
            "output_tokens": output_tokens,
            "retrieval_context": retrieval_context,
            "session_id": session_id,
            "tokens": tokens,
        }

        # Filter out undefined values, keeping None and actual values
        update_data = filter_undefined(params)

        # Validate with the Pydantic model. exclude_unset keeps only the fields that were
        # passed in, so explicit None values are sent as null and omitted fields are not sent.
        inference_result_update = InferenceResultUpdate(**update_data)
        payload = inference_result_update.model_dump(by_alias=True, exclude_unset=True)
        response = self.__client.patch(
            f"inferenceResults/{inference_result_id}",
            json=payload,
        )
        return InferenceResult(**response.json())

    def get(self, inference_result_id: str) -> InferenceResult:
        """Retrieve an inference result log by its ID.

        Args:
            inference_result_id (str): The ID of the inference result log to retrieve

        Returns:
            InferenceResult: The retrieved inference result log object

        Raises:
            ValueError: If `inference_result_id` is not a valid ID.
        """
        if not is_valid_id(inference_result_id):
            raise ValueError("A valid inference result ID must be provided.")

        response = self.__client.get(f"inferenceResults/{inference_result_id}")
        return InferenceResult(**response.json())

    def delete(self, inference_result_id: str) -> None:
        """Delete an inference result log by its ID.

        Args:
            inference_result_id (str): The ID of the inference result log to delete

        Raises:
            ValueError: If `inference_result_id` is not a valid ID.
        """
        if not is_valid_id(inference_result_id):
            raise ValueError("A valid inference result ID must be provided.")

        self.__client.delete(f"inferenceResults/{inference_result_id}")
