# SPDX-FileCopyrightText: 2024 UL Research Institutes
# SPDX-License-Identifier: Apache-2.0
"""The request schemas describe the information that you need to provide when creating
new instances of the core types.

For example, requests do not have ``.id`` fields because these are assigned by the
platform when the resource is created. Similarly, if a resource depends on an
instance of another resource, the request will refer to the dependency by its ID,
while the core resource will include the full dependency object as a sub-resource.
The ``create`` endpoints take a request as input and return a full core resource
in response.
"""


from datetime import datetime
from typing import Any, Optional, Union

import pydantic

from .base import DyffDefaultSerializers
from .platform import (
    AnalysisBase,
    DatasetBase,
    DataView,
    DocumentationBase,
    EvaluationBase,
    InferenceServiceBase,
    InferenceSessionBase,
    Labeled,
    MethodBase,
    ModelSpec,
    ModuleBase,
    ReportBase,
    TagBase,
)
from .version import SchemaVersion


class DyffRequestDefaultValidators(DyffDefaultSerializers):
    """This must be the base class for *all* request models in the Dyff schema.

    Adds a root validator to ensure that all user-provided datetime fields have a
    timezone set. Timezones will be converted to UTC once the data enters the platform,
    but we allow requests to have non-UTC timezones for user convenience.
    """

    @pydantic.root_validator
    def _require_datetime_timezone_aware(cls, values):
        """Reject any top-level datetime field value that is naive.

        :param values: Mapping of field name to validated value.
        :returns: ``values`` unchanged.
        :raises ValueError: If a datetime value is naive.
        """
        # Only the model's own top-level values are inspected. Datetimes held
        # inside nested sub-models are checked only if those sub-models also
        # derive from this class.
        for k, v in values.items():
            # A datetime is "aware" only if it has a tzinfo *and* that tzinfo
            # returns a concrete offset. ``utcoffset()`` returns None in both
            # naive cases: no tzinfo, or a tzinfo whose utcoffset() is None.
            if isinstance(v, datetime) and v.utcoffset() is None:
                raise ValueError(f"{cls.__qualname__}.{k}: timezone not set")
        return values


class DyffRequestBase(SchemaVersion, DyffRequestDefaultValidators):
    # Combines SchemaVersion with the request validators/serializers.
    # Declares no fields of its own.
    ...


class DyffEntityCreateRequest(DyffRequestBase):
    # Base for requests that create an entity owned by an account.
    account: str = pydantic.Field(description="Account that owns the entity")


class AnalysisCreateRequest(DyffEntityCreateRequest, AnalysisBase):
    """An Analysis transforms Datasets, Evaluations, and Measurements into new
    Measurements or SafetyCases."""

    # The Method is referenced by its ID.
    method: str = pydantic.Field(description="Method ID")


class DatasetCreateRequest(DyffEntityCreateRequest, DatasetBase):
    ...


class DocumentationEditRequest(DyffRequestBase, DocumentationBase):
    # Derives from DyffRequestBase rather than DyffEntityCreateRequest, so it
    # does not inherit the ``account`` field.
    ...


class InferenceServiceCreateRequest(DyffEntityCreateRequest, InferenceServiceBase):
    # Optional: a service need not be backed by a Model.
    model: Union[str, None] = pydantic.Field(
        default=None, description="ID of Model backing the service, if applicable"
    )


class InferenceSessionCreateRequest(DyffEntityCreateRequest, InferenceSessionBase):
    # The InferenceService is referenced by its ID.
    inferenceService: str = pydantic.Field(description="InferenceService ID")


class InferenceSessionTokenCreateRequest(DyffRequestBase):
    # If given, ``expires`` must be timezone-aware; the root validator
    # inherited via DyffRequestBase rejects naive datetimes.
    expires: Union[datetime, None] = pydantic.Field(
        default=None,
        description="Expiration time of the token. Must be <= expiration time"
        " of session. Default: expiration time of session.",
    )


class EvaluationInferenceSessionRequest(
    DyffRequestDefaultValidators, InferenceSessionBase
):
    # Inline InferenceSession specification embedded in an EvaluationCreateRequest.
    # It derives from DyffRequestDefaultValidators, as every request model must,
    # so that its own datetime fields (if any, e.g. from InferenceSessionBase)
    # are checked for a timezone. The enclosing request's root validator only
    # inspects that request's top-level values, not the fields of this sub-model.
    inferenceService: str = pydantic.Field(description="InferenceService ID")


class EvaluationCreateRequest(DyffEntityCreateRequest, EvaluationBase):
    """A description of how to run an InferenceService on a Dataset to obtain a set of
    evaluation results."""

    # Exactly one of the next two fields must be given (enforced by the root
    # validator below): an inline session spec, or the ID of a running session.

    inferenceSession: Union[EvaluationInferenceSessionRequest, None] = pydantic.Field(
        default=None,
        description="Specification of the InferenceSession that will perform inference for the evaluation.",
    )

    inferenceSessionReference: Union[str, None] = pydantic.Field(
        default=None,
        description="The ID of a running inference session that will be used"
        " for the evaluation, instead of starting a new one.",
    )

    @pydantic.root_validator
    def check_session_exactly_one(cls, values):
        """Reject requests that provide both, or neither, of the session fields."""
        provided = [
            values.get(field_name) is not None
            for field_name in ("inferenceSession", "inferenceSessionReference")
        ]
        if provided.count(True) != 1:
            raise ValueError(
                "must specify exactly one of {inferenceSession, inferenceSessionReference}"
            )
        return values


class MethodCreateRequest(DyffEntityCreateRequest, MethodBase):
    # Adds nothing beyond the owning account and the MethodBase fields.
    ...


class ModelCreateRequest(DyffEntityCreateRequest, ModelSpec):
    # Adds nothing beyond the owning account and the ModelSpec fields.
    ...


class ModuleCreateRequest(DyffEntityCreateRequest, ModuleBase):
    # Adds nothing beyond the owning account and the ModuleBase fields.
    ...


class ReportCreateRequest(DyffEntityCreateRequest, ReportBase):
    """A Report transforms raw model outputs into some useful statistics.

    .. deprecated:: 0.8.0

        Report functionality has been refactored into the
        Method/Measurement/Analysis apparatus. Creation of new Reports is
        disabled.
    """

    # Each view is either a string (presumably a DataView ID -- confirm) or an
    # inline DataView; both are optional.

    datasetView: Union[str, DataView, None] = pydantic.Field(
        default=None,
        description="View of the input dataset required by the report (e.g., ground-truth labels).",
    )

    evaluationView: Union[str, DataView, None] = pydantic.Field(
        default=None,
        description="View of the evaluation output data required by the report.",
    )


class TagCreateRequest(DyffRequestBase, TagBase):
    # Request wrapper around TagBase; declares no extra fields.
    ...


class LabelUpdateRequest(DyffRequestBase, Labeled):
    # Request wrapper around Labeled; declares no extra fields.
    ...


# Note: Query requests, as they currently exist, don't work with Versioned,
# because fastapi assigns None to every field that the client doesn't specify.
# This is probably not important, because all query parameters will always be
# optional. It could become a problem if the semantics of a parameter name ever
# changed, so avoid changing them.
class DyffEntityQueryRequest(DyffRequestDefaultValidators):
    # Filters shared by every entity query; each one is optional.
    id: Union[str, None] = pydantic.Field(None)
    account: Union[str, None] = pydantic.Field(None)
    status: Union[str, None] = pydantic.Field(None)
    reason: Union[str, None] = pydantic.Field(None)
    labels: Union[str, None] = pydantic.Field(
        None, description="Labels dict represented as a JSON string."
    )

    # Both serializers default to emitting only explicitly-set fields; callers
    # may still pass ``exclude_unset=False`` to get every field.

    def dict(self, exclude_unset=True, **kwargs) -> dict[str, Any]:
        kwargs["exclude_unset"] = exclude_unset
        return super().dict(**kwargs)

    def json(self, exclude_unset=True, **kwargs) -> Any:
        kwargs["exclude_unset"] = exclude_unset
        return super().json(**kwargs)


class _AnalysisProductQueryRequest(DyffEntityQueryRequest):
    # Private base holding the filters shared by MeasurementQueryRequest and
    # SafetyCaseQueryRequest.
    method: Union[str, None] = pydantic.Field(None)
    methodName: Union[str, None] = pydantic.Field(None)
    dataset: Union[str, None] = pydantic.Field(None)
    evaluation: Union[str, None] = pydantic.Field(None)
    inferenceService: Union[str, None] = pydantic.Field(None)
    model: Union[str, None] = pydantic.Field(None)
    inputsAnyOf: Union[str, None] = pydantic.Field(None)


class AuditQueryRequest(DyffEntityQueryRequest):
    name: Union[str, None] = pydantic.Field(None)


class DatasetQueryRequest(DyffEntityQueryRequest):
    name: Union[str, None] = pydantic.Field(None)


class EvaluationQueryRequest(DyffEntityQueryRequest):
    # Filters on related resources; the ``*Name`` variants presumably match
    # by name rather than by ID (confirm against the query backend).
    dataset: Union[str, None] = pydantic.Field(None)
    inferenceService: Union[str, None] = pydantic.Field(None)
    inferenceServiceName: Union[str, None] = pydantic.Field(None)
    model: Union[str, None] = pydantic.Field(None)
    modelName: Union[str, None] = pydantic.Field(None)


class InferenceServiceQueryRequest(DyffEntityQueryRequest):
    name: Union[str, None] = pydantic.Field(None)
    model: Union[str, None] = pydantic.Field(None)
    modelName: Union[str, None] = pydantic.Field(None)


class InferenceSessionQueryRequest(DyffEntityQueryRequest):
    name: Union[str, None] = pydantic.Field(None)
    inferenceService: Union[str, None] = pydantic.Field(None)
    inferenceServiceName: Union[str, None] = pydantic.Field(None)
    model: Union[str, None] = pydantic.Field(None)
    modelName: Union[str, None] = pydantic.Field(None)

class MeasurementQueryRequest(_AnalysisProductQueryRequest):
    # Uses exactly the shared analysis-product filters.
    ...


class MethodQueryRequest(DyffEntityQueryRequest):
    name: Union[str, None] = pydantic.Field(None)
    outputKind: Union[str, None] = pydantic.Field(None)


class ModelQueryRequest(DyffEntityQueryRequest):
    name: Union[str, None] = pydantic.Field(None)


class ModuleQueryRequest(DyffEntityQueryRequest):
    name: Union[str, None] = pydantic.Field(None)


class ReportQueryRequest(DyffEntityQueryRequest):
    report: Union[str, None] = pydantic.Field(None)
    dataset: Union[str, None] = pydantic.Field(None)
    evaluation: Union[str, None] = pydantic.Field(None)
    inferenceService: Union[str, None] = pydantic.Field(None)
    model: Union[str, None] = pydantic.Field(None)


class SafetyCaseQueryRequest(_AnalysisProductQueryRequest):
    # Uses exactly the shared analysis-product filters.
    ...


# Public API of this module, kept in alphabetical order.
__all__ = [
    "AnalysisCreateRequest",
    "AuditQueryRequest",
    "DatasetCreateRequest",
    "DatasetQueryRequest",
    "DocumentationEditRequest",
    "DyffEntityCreateRequest",
    "DyffEntityQueryRequest",
    "DyffRequestBase",
    "DyffRequestDefaultValidators",
    "EvaluationCreateRequest",
    "EvaluationInferenceSessionRequest",
    "EvaluationQueryRequest",
    "InferenceServiceCreateRequest",
    "InferenceServiceQueryRequest",
    "InferenceSessionCreateRequest",
    "InferenceSessionQueryRequest",
    "InferenceSessionTokenCreateRequest",
    "LabelUpdateRequest",
    "MeasurementQueryRequest",
    "MethodCreateRequest",
    "MethodQueryRequest",
    "ModelCreateRequest",
    "ModelQueryRequest",
    "ModuleCreateRequest",
    "ModuleQueryRequest",
    "ReportCreateRequest",
    "ReportQueryRequest",
    "SafetyCaseQueryRequest",
    "TagCreateRequest",
]
