import base64
import bz2
import functools
import json
import logging
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Callable, Generic, Optional, Type, TypeVar

import pydantic

from datahub.configuration.common import ConfigModel
from datahub.emitter.mce_builder import parse_ts_millis
from datahub.metadata.schema_classes import (
    DatahubIngestionCheckpointClass,
    IngestionCheckpointStateClass,
)

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Default upper bound on the length (in bytes) of a serialized checkpoint state;
# used as the default `max_allowed_state_size` of `CheckpointStateBase.to_bytes`.
DEFAULT_MAX_STATE_SIZE = 2**22  # 4 MiB (4,194,304 bytes)


class CheckpointStateBase(ConfigModel):
    """
    Common parent of all ingestion checkpoint state objects.

    NOTE: The pydantic based ConfigModel is used as the base so that field
    inclusion/exclusion during serialization comes for free, and so that
    individual sources can add their own validation.
    """

    version: str = pydantic.Field(default="1.0")
    serde: str = pydantic.Field(default="base85-bz2-json")

    def to_bytes(
        self,
        compressor: Callable[[bytes], bytes] = functools.partial(
            bz2.compress, compresslevel=9
        ),
        max_allowed_state_size: int = DEFAULT_MAX_STATE_SIZE,
    ) -> bytes:
        """
        Serialize this state according to `self.serde`.

        NOTE: Binary compression cannot be enabled yet, because the current MCPs
        encode the GenericAspect payload as Json, which has no bytes type. For V1
        the utf-8 encoding is used. For the same reason, double serialization
        (encoding version and serde separately from the binary state payload) is
        not possible until a binary content-type is supported for GenericAspect.

        :param compressor: bytes -> bytes callable applied to the JSON payload
            when the serde is "base85-bz2-json"; unused for "utf-8".
        :param max_allowed_state_size: largest permitted length of the encoded
            result, in bytes.
        :return: the encoded state.
        :raises ValueError: if the serde is "base85" (write support removed), if
            the serde is not recognized, or if the encoded result is longer than
            `max_allowed_state_size`.
        """

        serde_name = self.serde

        if serde_name == "base85":
            # The original base85 implementation used pickle, which would cause
            # issues with deserialization if we ever changed the state class definition.
            raise ValueError(
                "Cannot write base85 encoded bytes. Use base85-bz2-json instead."
            )

        if serde_name == "utf-8":
            payload = CheckpointStateBase._to_bytes_utf8(self)
        elif serde_name == "base85-bz2-json":
            payload = CheckpointStateBase._to_bytes_base85_json(self, compressor)
        else:
            raise ValueError(f"Unknown serde: {self.serde}")

        # Enforce the size cap on the final encoded form.
        if max_allowed_state_size < len(payload):
            raise ValueError(
                f"The state size has exceeded the max_allowed_state_size of {max_allowed_state_size}"
            )

        return payload

    @staticmethod
    def _to_bytes_utf8(model: ConfigModel) -> bytes:
        """
        Dump `model` to utf-8 encoded JSON, leaving out the `version` and
        `serde` fields.
        """
        raw_json = model.model_dump_json(exclude={"version", "serde"})
        # Round-trip through the stdlib json module so the output carries
        # Python's default whitespace. This only exists to keep tests stable
        # during the pydantic v2 migration and can go away afterwards.
        as_python = json.loads(raw_json)
        normalized_json = json.dumps(as_python)
        return normalized_json.encode("utf-8")

    @staticmethod
    def _to_bytes_base85_json(
        model: ConfigModel, compressor: Callable[[bytes], bytes]
    ) -> bytes:
        """
        Produce the utf-8 JSON form of `model`, run it through `compressor`,
        and base85-encode the result.
        """
        json_bytes = CheckpointStateBase._to_bytes_utf8(model)
        compressed = compressor(json_bytes)
        return base64.b85encode(compressed)

    def prepare_for_commit(self) -> None:
        """
        Hook for pre-commit work such as deduplication or custom compression
        across data. The base implementation does nothing.
        """


# Type variable for the concrete state class carried by a Checkpoint;
# constrained to CheckpointStateBase and its subclasses.
StateType = TypeVar("StateType", bound=CheckpointStateBase)


@dataclass
class Checkpoint(Generic[StateType]):
    """
    Ingestion Run Checkpoint class. This is a more convenient abstraction for use in the python ingestion code,
    providing a strongly typed state object vs the opaque blob in the PDL, and the config persisted as the first-class
    ConfigModel object.
    """

    # Name of the ingestion job this checkpoint belongs to.
    job_name: str
    # Pipeline name; mapped to/from the aspect's `pipelineName`.
    pipeline_name: str
    # Run identifier; mapped to/from the aspect's `runId`.
    run_id: str
    # Strongly typed state, (de)serialized according to its `serde` field.
    state: StateType

    @classmethod
    def create_from_checkpoint_aspect(
        cls,
        job_name: str,
        checkpoint_aspect: Optional[DatahubIngestionCheckpointClass],
        state_class: Type[StateType],
    ) -> Optional["Checkpoint[StateType]"]:
        """
        Build a Checkpoint from a raw checkpoint aspect.

        :param job_name: job name to record on the resulting Checkpoint.
        :param checkpoint_aspect: the raw aspect, or None when there is none.
        :param state_class: concrete state class to deserialize the payload into.
        :return: the deserialized Checkpoint, or None if `checkpoint_aspect` is None.
        :raises ValueError: if the aspect's serde is "base85" (no longer supported)
            or is not a recognized value.
        :raises Exception: any error raised while decoding or validating the
            payload is logged and re-raised unchanged.
        """
        if checkpoint_aspect is None:
            return None

        try:
            serde = checkpoint_aspect.state.serde
            state_obj: StateType
            if serde == "utf-8":
                state_obj = Checkpoint._from_utf8_bytes(checkpoint_aspect, state_class)
            elif serde == "base85":
                raise ValueError(
                    "The base85 encoding for stateful ingestion has been removed for security reasons. "
                    "You may need to temporarily set `ignore_previous_checkpoint` to true to ignore the outdated checkpoint object."
                )
            elif serde == "base85-bz2-json":
                state_obj = Checkpoint._from_base85_json_bytes(
                    checkpoint_aspect,
                    bz2.decompress,
                    state_class,
                )
            else:
                raise ValueError(f"Unknown serde: {serde}")
        except Exception as e:
            logger.error(
                f"Failed to construct checkpoint class from checkpoint aspect: {e}"
            )
            # Bare `raise` re-raises the active exception as-is.
            raise

        # Construct the deserialized Checkpoint object from the raw aspect.
        checkpoint = cls(
            job_name=job_name,
            pipeline_name=checkpoint_aspect.pipelineName,
            run_id=checkpoint_aspect.runId,
            state=state_obj,
        )
        logger.info(
            f"Successfully constructed last checkpoint state for job {job_name} "
            f"with timestamp {parse_ts_millis(checkpoint_aspect.timestampMillis)}"
        )
        return checkpoint

    @staticmethod
    def _from_utf8_bytes(
        checkpoint_aspect: DatahubIngestionCheckpointClass,
        state_class: Type[StateType],
    ) -> StateType:
        """
        Deserialize a "utf-8" serde payload (utf-8 encoded JSON) into `state_class`.

        A missing (None) payload is treated as an empty state. The `version` and
        `serde` fields are not part of the payload; they are restored from the
        aspect's `formatVersion` and `serde`.
        """
        payload = checkpoint_aspect.state.payload
        state_as_dict = json.loads(payload.decode("utf-8")) if payload is not None else {}
        state_as_dict["version"] = checkpoint_aspect.state.formatVersion
        state_as_dict["serde"] = checkpoint_aspect.state.serde
        return state_class.parse_obj(state_as_dict)

    @staticmethod
    def _from_base85_json_bytes(
        checkpoint_aspect: DatahubIngestionCheckpointClass,
        decompressor: Callable[[bytes], bytes],
        state_class: Type[StateType],
    ) -> StateType:
        """
        Deserialize a "base85-bz2-json" serde payload into `state_class`.

        The payload is base85-decoded, passed through `decompressor`, and parsed
        as utf-8 JSON. A missing (None) payload is treated as an empty state. The
        `version` and `serde` fields are restored from the aspect's
        `formatVersion` and `serde`.
        """
        payload = checkpoint_aspect.state.payload
        if payload is None:
            # The empty-state fallback is already uncompressed JSON; it must not
            # be handed to the decompressor (bz2.decompress(b"{}") raises OSError).
            state_uncompressed = b"{}"
        else:
            state_uncompressed = decompressor(base64.b85decode(payload))
        state_as_dict = json.loads(state_uncompressed.decode("utf-8"))
        state_as_dict["version"] = checkpoint_aspect.state.formatVersion
        state_as_dict["serde"] = checkpoint_aspect.state.serde
        return state_class.parse_obj(state_as_dict)

    def to_checkpoint_aspect(
        self, max_allowed_state_size: int
    ) -> Optional[DatahubIngestionCheckpointClass]:
        """
        Serialize this checkpoint into a checkpoint aspect.

        This is best-effort: any exception raised while serializing the state or
        building the aspect (for example the state's `to_bytes` rejecting an
        oversized payload) is logged and None is returned instead.

        :param max_allowed_state_size: forwarded to the state's `to_bytes`.
        :return: the constructed aspect, or None on failure.
        """
        try:
            checkpoint_state = IngestionCheckpointStateClass(
                formatVersion=self.state.version,
                serde=self.state.serde,
                payload=self.state.to_bytes(
                    max_allowed_state_size=max_allowed_state_size
                ),
            )
            # `platformInstanceId` and `config` are written as empty strings here.
            checkpoint_aspect = DatahubIngestionCheckpointClass(
                timestampMillis=int(datetime.now(tz=timezone.utc).timestamp() * 1000),
                pipelineName=self.pipeline_name,
                platformInstanceId="",
                runId=self.run_id,
                config="",
                state=checkpoint_state,
            )
            return checkpoint_aspect
        except Exception as e:
            # The exception must go through a %s placeholder: passing it as a
            # bare extra argument makes the logging module fail while formatting
            # the record, and the actual error text is lost.
            logger.error(
                "Failed to construct the checkpoint aspect from checkpoint object: %s",
                e,
            )

        return None

    def prepare_for_commit(self) -> None:
        """Delegate pre-commit processing to the wrapped state object."""
        self.state.prepare_for_commit()
