"""Utilities."""

from __future__ import annotations

import abc
import typing
from typing import Any

import datasets
import pydantic
import pydantic_core.core_schema


class PydanticToHFDatasets(abc.ABC):
    """Collection of utilities for converting Pydantic models (types and instances) to HF's `datasets.Dataset`.

    `model_cls_to_features` derives the schema from a model class and `model_to_dict` converts a model instance
    into a row. Both are driven by the declared field annotations and apply the same rules:

      - str / int / float / bool => "string" / "int32" / "float32" / "bool" values (row value passed through)
      - Nested Pydantic model => nested features / nested dict
      - list[T] / tuple[T] => sequence of converted items
      - Other parametrized list / tuple forms (e.g. tuple[str, int]) => sequence of strings
      - dict[str, T] => sequence of {"key": str, "value": T} entries
      - Everything else (other dicts, Union / Optional, unknown types) => string, produced with `str(value)`

    A value of `None` is always emitted as `None`.
    """

    # `datasets.Value` dtype names for the Python primitives that are stored natively.
    # NOTE(review): `int` and `float` map to 32-bit dtypes; confirm that is wide enough for the stored data.
    _PRIMITIVE_DTYPES: typing.ClassVar[dict[type, str]] = {
        str: "string",
        int: "int32",
        float: "float32",
        bool: "bool",
    }

    @classmethod
    def model_cls_to_features(cls, entity_type: type[pydantic.BaseModel]) -> datasets.Features:
        """Given a Pydantic model class, build a `datasets.Features` mapping that matches its fields.

        :param entity_type: Pydantic model class whose declared fields define the schema.
        :return: `datasets.Features` instance for use in HF `datasets.Dataset`.
        """
        # field_info.annotation is e.g. str, list[str], MyNestedModel, etc.
        return datasets.Features(
            {
                field_name: cls._annotation_to_values(field_info.annotation)  # type: ignore[arg-type]
                for field_name, field_info in entity_type.model_fields.items()
            }
        )

    @classmethod
    def _is_model_cls(cls, annotation: Any, origin: Any) -> bool:
        """Tell whether an annotation is a plain (non-generic-alias) Pydantic model class.

        :param annotation: Annotation to inspect.
        :param origin: Result of `typing.get_origin(annotation)`.
        :return: True if `annotation` is a subclass of `pydantic.BaseModel`.
        """
        # Parametrized generics such as `list[int]` can pass `isinstance(..., type)` on some Python versions and
        # would then make `issubclass` raise, so anything with a generic origin is rejected up front.
        return origin is None and isinstance(annotation, type) and issubclass(annotation, pydantic.BaseModel)

    @classmethod
    def _annotation_to_values(
        cls, annotation: pydantic_core.core_schema.ModelField | type
    ) -> datasets.Value | datasets.Sequence | datasets.Features:
        """Convert a type annotation (e.g. str, list[int], MyNestedModel) to a Hugging Face `datasets` feature.

        Handles:
          - Basic python types (str, int, float, bool)
          - Lists/tuples (e.g. list[str], tuple[int]; other parametrized forms => sequence of strings)
          - dict[str, ...] => Sequence of { "key": str, "value": ... }; other dicts => string
          - Nested Pydantic BaseModel => nested `datasets.Features`
          - Union/Optional => fallback to string
          - Catch-all fallback => string

        :param annotation: Annotation to convert.
        :return: `datasets.Value`, `datasets.Sequence` or (for nested models) `datasets.Features` instance
            generated from specified annotation.
        """
        origin = typing.get_origin(annotation)
        args = typing.get_args(annotation)

        # 1) Nested Pydantic model => recursively build its features
        if cls._is_model_cls(annotation, origin):
            return cls.model_cls_to_features(annotation)  # type: ignore[arg-type]

        # 2) list[...] or tuple[...]
        if origin in (list, tuple):
            if len(args) == 1:
                # e.g. list[str], tuple[int] => sequence of the converted item type
                return datasets.Sequence(cls._annotation_to_values(args[0]))
            # e.g. tuple[str, int] => items are stored as strings
            return datasets.Sequence(datasets.Value("string"))

        # 3) dict[...]
        if origin is dict:
            if len(args) == 2 and args[0] is str:
                # dict[str, T] => sequence of key-value pairs
                return datasets.Sequence(
                    feature=datasets.Features(
                        {"key": datasets.Value("string"), "value": cls._annotation_to_values(args[1])}
                    )
                )
            # Untyped or non-string keys => the whole dict is stored as one string
            return datasets.Value("string")

        # 4) Primitives; everything else (Union/Optional included) falls back to string.
        #    Identity comparison avoids hashing the annotation, which is not guaranteed to be hashable.
        dtype = next(
            (name for py_type, name in cls._PRIMITIVE_DTYPES.items() if annotation is py_type),
            "string",
        )
        return datasets.Value(dtype)

    @classmethod
    def model_to_dict(cls, model: pydantic.BaseModel | None) -> Any:
        """Given a Pydantic model instance, return a Python object (dict, list, etc.) usable as a dataset row.

        Matches the Hugging Face Features schema produced by `model_cls_to_features`.
        Handles:
          - BaseModel subclasses (recursively)
          - Lists / tuples
          - dict[str, X] => list of {"key": str, "value": X}
          - Primitives
          - Union / fallback => string

        :param model: Entity to convert.
        :return: `None` for `None`, a dict keyed by field name for a model instance, and the input unchanged
            for anything else.
        """
        # 0) Nothing to convert
        if model is None:
            return None

        # 1) Actual Pydantic model instance: convert each declared field using its annotation.
        #    Field definitions are read from the class; Pydantic deprecates reading them via the instance.
        if isinstance(model, pydantic.BaseModel):
            return {
                field_name: cls._convert_value_for_dataset(getattr(model, field_name), field_info.annotation)
                for field_name, field_info in type(model).model_fields.items()
            }

        # 2) Not a model: this function is meant for model instances, so the input is returned untouched.
        return model  # type: ignore[unreachable]

    @classmethod
    def _convert_value_for_dataset(cls, value: Any, annotation: Any) -> Any:
        """Recursively convert a value (with its declared annotation) to something that fits the HF dataset row format.

        Parallel to `_annotation_to_values`: every branch is selected by the annotation so that the produced
        value has the shape of the feature generated for that same annotation.

        :param value: Value to convert.
        :param annotation: Type annotation of value.
        :return: Converted value.
        """
        # None stays None regardless of the annotation
        if value is None:
            return None

        # 1) Direct primitive annotation => value is passed through as-is
        if any(annotation is py_type for py_type in cls._PRIMITIVE_DTYPES):
            return value

        origin = typing.get_origin(annotation)
        args = typing.get_args(annotation)

        # 2) list[...] or tuple[...]
        if origin in (list, tuple):
            if not isinstance(value, (list, tuple)):
                # The actual data isn't a list/tuple
                return str(value)
            if len(args) == 1:
                # e.g. list[str], list[SomeSubModel]
                return [cls._convert_value_for_dataset(item, args[0]) for item in value]
            # e.g. tuple[str, int] => every item is stored as a string
            return [str(item) for item in value]

        # 3) dict[str, X] => list of { "key": str, "value": X }
        if origin is dict:
            if isinstance(value, dict) and len(args) == 2 and args[0] is str:
                return [
                    {"key": str(key), "value": cls._convert_value_for_dataset(item, args[1])}
                    for key, item in value.items()
                ]
            # Not a dict, untyped, or non-string keys => the whole value is stored as one string
            return str(value)

        # 4) Nested Pydantic model. Keyed on the annotation: a model instance sitting behind an annotation
        #    that maps to a string feature (e.g. Optional[SubModel]) must become a string, not a dict.
        if cls._is_model_cls(annotation, origin) and isinstance(value, pydantic.BaseModel):
            return cls.model_to_dict(value)

        # 5) Union / Optional and any other annotation => string
        return str(value)
