# Copyright 2024 Superlinked, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from beartype.typing import Generic, cast

from superlinked.framework.common.data_types import Json
from superlinked.framework.common.parser.data_parser import DataParser
from superlinked.framework.common.parser.exception import (
    MissingCreatedAtException,
    MissingIdException,
)
from superlinked.framework.common.parser.parsed_schema import (
    EventParsedSchema,
    ParsedSchema,
    ParsedSchemaField,
)
from superlinked.framework.common.schema.event_schema_object import (
    EventSchemaObject,
    SchemaReference,
)
from superlinked.framework.common.schema.id_schema_object import IdSchemaObjectT
from superlinked.framework.common.schema.schema_object import SFT, SchemaField
from superlinked.framework.common.util.dot_separated_path_util import (
    DotSeparatedPathUtil,
    ValuedDotSeparatedPath,
)


class JsonParser(Generic[IdSchemaObjectT], DataParser[IdSchemaObjectT, Json]):
    """
    JsonParser gets a `Json` object and using `str` based dot separated path mapping
    it transforms the `Json` to a desired schema.
    """

    def unmarshal(self, data: Json | list[Json]) -> list[ParsedSchema]:
        """
        Parses the given Json (or list of Json objects) into a list of ParsedSchema
        objects according to the defined schema and mapping.

        Args:
            data (Json | list[Json]): The Json representation of your data. A single
                Json object is treated as a one-element list.

        Returns:
            list[ParsedSchema]: A list of ParsedSchema objects that will be processed by the spaces.

        Raises:
            MissingIdException: If the id value of an item is not valid.
            MissingCreatedAtException: If this is an event data parser and the
                created_at value of an item is not valid.
        """

        if not isinstance(data, list):
            data = [data]
        return [self._unmarshal_single(json_data) for json_data in data]

    def _unmarshal_single(self, json_data: Json) -> EventParsedSchema | ParsedSchema:
        """
        Parses a single Json object into a ParsedSchema (or an EventParsedSchema
        when this parser is an event data parser).

        Fields whose parsed value is None are left out of the result.
        """
        # The id is validated first so that an invalid id is reported before
        # anything else is parsed.
        id_ = self.__ensure_id(json_data)
        parsed_fields: list[ParsedSchemaField] = []
        for field in self._schema._get_schema_fields():
            parsed_value = self._parse_schema_field_value(field, json_data)
            if parsed_value is not None:
                parsed_fields.append(
                    ParsedSchemaField.from_schema_field(field, parsed_value)
                )
        if self._is_event_data_parser:
            return EventParsedSchema(
                self._schema,
                id_,
                parsed_fields,
                self.__ensure_created_at(json_data),
            )

        return ParsedSchema(self._schema, id_, parsed_fields)

    def _marshal(
        self,
        parsed_schemas: list[ParsedSchema],
    ) -> list[Json]:
        """
        Converts a list of ParsedSchema objects back into a list of Json objects.
        You can use this functionality to check if your mapping was defined properly.

        Args:
            parsed_schemas (list[ParsedSchema]): ParsedSchema objects in a list that
                you can retrieve after unmarshalling your `Json`.

        Returns:
            list[Json]: List of Json representation of the schemas.
        """
        return [
            self.__construct_json(
                self.__get_all_fields_from_parsed_schema(parsed_schema)
            )
            for parsed_schema in parsed_schemas
        ]

    def __construct_json(self, parsed_schema_field: list[ParsedSchemaField]) -> Json:
        """
        Builds a Json object from the given parsed fields by pairing each field's
        mapped path with its value.
        """
        return DotSeparatedPathUtil.to_dict(
            [
                ValuedDotSeparatedPath(
                    self._get_path(field.schema_field),
                    field.value,
                )
                for field in parsed_schema_field
            ]
        )

    def __ensure_id(self, data: Json) -> str:
        """
        Reads the id from `data` and returns it as a `str`.

        Raises:
            MissingIdException: If the id value is not valid.
        """
        id_ = DotSeparatedPathUtil.get(data, self._id_name)
        if not self._is_id_value_valid(id_):
            raise MissingIdException(
                "The mandatory id field is missing from the input object."
            )
        return str(id_)

    def __ensure_created_at(self, data: Json) -> int:
        """
        Reads the created_at value from `data`.

        Raises:
            MissingCreatedAtException: If the created_at value is not valid.
        """
        created_at = DotSeparatedPathUtil.get(data, self._created_at_name)
        if not self._is_created_at_value_valid(created_at):
            raise MissingCreatedAtException(
                "The mandatory created_at field is missing from the input object."
            )
        return cast(int, created_at)

    def _parse_schema_field_value(
        self, field: SchemaField[SFT], data: Json
    ) -> SFT | None:
        """
        Looks up the value of `field` in `data` using the field's mapped path.

        Values of `SchemaReference` fields are converted to `str`. A None value
        is returned unchanged so that the caller can skip the field.
        """
        path: str = self._get_path(field)
        parsed_schema_field_value = DotSeparatedPathUtil.get(data, path)
        # Only stringify values that are actually present: str(None) would yield
        # the literal "None", which callers filtering on `is not None` would then
        # treat as a real reference id.
        if isinstance(field, SchemaReference) and parsed_schema_field_value is not None:
            parsed_schema_field_value = cast(SFT, str(parsed_schema_field_value))
        return parsed_schema_field_value

    def __get_all_fields_from_parsed_schema(
        self, parsed_schema: ParsedSchema
    ) -> list[ParsedSchemaField]:
        """
        Collects every field needed to rebuild the Json: the parsed fields, the
        id and, for event data parsers, the created_at field.
        """
        all_fields: list[ParsedSchemaField] = list(parsed_schema.fields)
        all_fields.append(
            ParsedSchemaField.from_schema_field(self._schema.id, parsed_schema.id_)
        )
        if self._is_event_data_parser:
            # Event parsers work with EventSchemaObject / EventParsedSchema, which
            # carry the created_at field; the casts only inform the type checker.
            all_fields.append(
                ParsedSchemaField.from_schema_field(
                    cast(EventSchemaObject, self._schema).created_at,
                    cast(EventParsedSchema, parsed_schema).created_at,
                )
            )
        return all_fields
