import logging
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import boto3
from humanfriendly import format_timespan
from pydantic import Field, field_validator
from pyiceberg.catalog import BOTOCORE_SESSION, Catalog, load_catalog
from pyiceberg.catalog.glue import (
    GLUE_ACCESS_KEY_ID,
    GLUE_PROFILE_NAME,
    GLUE_REGION,
    GLUE_SECRET_ACCESS_KEY,
    GLUE_SESSION_TOKEN,
)
from pyiceberg.catalog.rest import RestCatalog
from pyiceberg.io import (
    AWS_ACCESS_KEY_ID,
    AWS_REGION,
    AWS_ROLE_ARN,
    AWS_SECRET_ACCESS_KEY,
    AWS_SESSION_TOKEN,
)
from pyiceberg.utils.properties import get_first_property_value
from requests.adapters import HTTPAdapter
from sortedcontainers import SortedList
from urllib3.util import Retry

from datahub.configuration.common import AllowDenyPattern, ConfigModel
from datahub.configuration.source_common import DatasetSourceConfigMixin
from datahub.ingestion.source.state.stale_entity_removal_handler import (
    StaleEntityRemovalSourceReport,
    StatefulStaleMetadataRemovalConfig,
)
from datahub.ingestion.source.state.stateful_ingestion_base import (
    StatefulIngestionConfigBase,
)
from datahub.ingestion.source_config.operation_config import (
    OperationConfig,
    is_profiling_enabled,
)
from datahub.utilities.lossy_collections import LossyList
from datahub.utilities.stats_collections import TopKDict, int_top_k_dict

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Defaults applied by IcebergSourceConfig.get_catalog() to the HTTP session of a REST
# catalog; both can be overridden through the catalog's optional "connection" settings.
# Default for "connection.timeout" (requests interprets adapter timeouts in seconds).
DEFAULT_REST_TIMEOUT = 120
# Base keyword arguments for urllib3's Retry; "connection.retry" entries are merged on top.
DEFAULT_REST_RETRY_POLICY = {"total": 3, "backoff_factor": 0.1}

# Catalog property naming the role to assume for a Glue catalog; it is looked up
# before AWS_ROLE_ARN in IcebergSourceConfig._custom_glue_catalog_handling().
GLUE_ROLE_ARN = "glue.role-arn"


class TimeoutHTTPAdapter(HTTPAdapter):
    """HTTP adapter that supplies a default timeout to requests sent without one.

    Accepts an optional ``timeout`` keyword argument on top of the regular
    ``HTTPAdapter`` constructor arguments.
    """

    def __init__(self, *args, **kwargs):
        # Strip our extra keyword before delegating to the parent constructor.
        # The attribute is only created when the caller actually passed a timeout.
        try:
            self.timeout = kwargs.pop("timeout")
        except KeyError:
            pass
        super().__init__(*args, **kwargs)

    def send(self, request, *args, **kwargs):
        """Send ``request``, filling in the adapter's timeout when the caller gave none."""
        caller_gave_no_timeout = kwargs.get("timeout") is None
        if caller_gave_no_timeout and hasattr(self, "timeout"):
            kwargs["timeout"] = self.timeout
        return super().send(request, *args, **kwargs)


class IcebergProfilingConfig(ConfigModel):
    # Settings that control whether Iceberg tables are profiled and which
    # per-column statistics are included.

    # Master switch. IcebergSourceConfig.is_profiling_enabled() additionally
    # consults `operation_config` before profiling is considered active.
    enabled: bool = Field(
        default=False,
        description="Whether profiling should be done.",
    )
    # Per-statistic toggles; all default to True.
    include_field_null_count: bool = Field(
        default=True,
        description="Whether to profile for the number of nulls for each column.",
    )
    include_field_min_value: bool = Field(
        default=True,
        description="Whether to profile for the min value of numeric columns.",
    )
    include_field_max_value: bool = Field(
        default=True,
        description="Whether to profile for the max value of numeric columns.",
    )
    # Passed to the imported is_profiling_enabled() helper to decide whether
    # profiling applies to the current run.
    operation_config: OperationConfig = Field(
        default_factory=OperationConfig,
        description="Experimental feature. To specify operation configs.",
    )
    # Stats we cannot compute without looking at data
    # include_field_mean_value: bool = True
    # include_field_median_value: bool = True
    # include_field_stddev_value: bool = True
    # include_field_quantiles: bool = False
    # include_field_distinct_value_frequencies: bool = False
    # include_field_histogram: bool = False
    # include_field_sample_values: bool = True


class IcebergSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin):
    # Override the stateful_ingestion config param with the Iceberg custom stateful ingestion config in the IcebergSourceConfig
    stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = Field(
        default=None, description="Iceberg Stateful Ingestion Config."
    )
    # The catalog configuration is using a dictionary to be open and flexible.  All the keys and values are handled by pyiceberg.  This will future-proof any configuration change done by pyiceberg.
    # Shape: {catalog_name: {property: value, ...}}.  The validators on this model accept the
    # deprecated name/type/config layout and require exactly one non-empty catalog entry.
    catalog: Dict[str, Dict[str, Any]] = Field(
        description="Catalog configuration where to find Iceberg tables.  Only one catalog specification is supported.  The format is the same as [pyiceberg's catalog configuration](https://py.iceberg.apache.org/configuration/), where the catalog name is specified as the object name and attributes are set as key-value pairs.",
    )
    # Allow/deny filters; both default to allowing everything.
    table_pattern: AllowDenyPattern = Field(
        default=AllowDenyPattern.allow_all(),
        description="Regex patterns for tables to filter in ingestion.",
    )
    namespace_pattern: AllowDenyPattern = Field(
        default=AllowDenyPattern.allow_all(),
        description="Regex patterns for namespaces to filter in ingestion.",
    )
    # Names of the Iceberg table properties that carry ownership information.
    user_ownership_property: Optional[str] = Field(
        default="owner",
        description="Iceberg table property to look for a `CorpUser` owner.  Can only hold a single user value.  If property has no value, no owner information will be emitted.",
    )
    group_ownership_property: Optional[str] = Field(
        default=None,
        description="Iceberg table property to look for a `CorpGroup` owner.  Can only hold a single group value.  If property has no value, no owner information will be emitted.",
    )
    # Profiling settings; is_profiling_enabled() combines `enabled` with `operation_config`.
    profiling: IcebergProfilingConfig = IcebergProfilingConfig()
    # NOTE(review): not read in the code visible here; presumably consumed by the
    # source implementation that processes tables. Confirm against the caller.
    processing_threads: int = Field(
        default=1, description="How many threads will be processing tables"
    )

    @field_validator("catalog", mode="before")
    @classmethod
    def handle_deprecated_catalog_format(cls, value):
        """Translate the deprecated catalog layout into the current one.

        The deprecated layout is a dict with the keys ``name``, ``type`` and
        ``config``. It is converted to ``{name: {"type": type, **config}}`` and
        a deprecation warning is logged. Any other input is returned untouched
        so that regular validation can accept or reject it.

        Args:
            value: Raw, not yet validated value of the ``catalog`` field.

        Returns:
            The catalog configuration in the current format, or ``value``
            unchanged when it does not look like the deprecated format.
        """
        # Once support for deprecated format is dropped, we can remove this validator.
        looks_deprecated = isinstance(value, dict) and all(
            key in value for key in ("name", "type", "config")
        )
        if not looks_deprecated:
            # In case the input is already the new format or is invalid
            return value

        logger.warning(
            "The catalog configuration format you are using is deprecated and will be removed in a future version. Please update to the new format.",
        )
        catalog_config = value["config"]
        if catalog_config is None:
            # A `config:` key left empty in YAML is parsed as None; treat it as
            # "no extra properties" instead of failing with a TypeError on `**None`.
            catalog_config = {}
        return {value["name"]: {"type": value["type"], **catalog_config}}

    @field_validator("catalog", mode="after")
    @classmethod
    def validate_catalog_size(cls, value):
        """Ensure the catalog mapping holds exactly one non-empty dictionary entry.

        Args:
            value: The ``catalog`` mapping of catalog name to its properties.

        Returns:
            ``value`` unchanged when it passes the checks.

        Raises:
            ValueError: If there is not exactly one catalog, or its
                configuration is empty or not a dictionary.
        """
        if len(value) != 1:
            raise ValueError("The catalog must contain exactly one entry.")

        # Exactly one item is present at this point, so it can be unpacked directly
        ((catalog_name, catalog_config),) = value.items()

        if catalog_config and isinstance(catalog_config, dict):
            return value

        raise ValueError(
            f"The catalog configuration for '{catalog_name}' must not be empty and should be a dictionary with at least one key-value pair."
        )

    def is_profiling_enabled(self) -> bool:
        """Tell whether profiling is switched on and allowed by its operation config."""
        profiling = self.profiling
        if not profiling.enabled:
            return False
        # Refers to the module-level helper imported at the top of the file, not this method.
        return is_profiling_enabled(profiling.operation_config)

    @staticmethod
    def _assumed_role_arn_to_role_arn(caller_arn: str) -> Optional[str]:
        """Convert an STS assumed-role ARN into the ARN of the underlying IAM role.

        ``arn:<partition>:sts::<account>:assumed-role/<role>/<session>`` becomes
        ``arn:<partition>:iam::<account>:role/<role>``. The partition is kept
        as-is, so ARNs outside the default ``aws`` partition are converted too.

        Args:
            caller_arn: ARN reported by STS for the current caller.

        Returns:
            The IAM role ARN, or None when ``caller_arn`` is not an assumed-role ARN.
        """
        if ":assumed-role/" not in caller_arn:
            return None
        # Drop the trailing session name, then rewrite the resource type and
        # the service segment (the first ":sts:" occurrence in the ARN).
        without_session = "/".join(caller_arn.split("/")[0:-1])
        return without_session.replace(":assumed-role/", ":role/").replace(
            ":sts:", ":iam:", 1
        )

    def _custom_glue_catalog_handling(self, catalog_config: Dict[str, Any]) -> None:
        """Assume the configured role for a Glue catalog and store the temporary credentials.

        When the catalog configuration names a role (``GLUE_ROLE_ARN`` first,
        then ``AWS_ROLE_ARN``), a boto3 session is built from the profile,
        region and credentials found in the configuration and used to call STS.
        Unless the caller is already running as the target role, the role is
        assumed and the resulting access key, secret key and session token are
        written into ``catalog_config`` under the Glue credential properties.
        Without a role in the configuration this method does nothing.

        Args:
            catalog_config: Properties of the single configured catalog. It is
                modified in place when a role is assumed.
        """
        role_to_assume = get_first_property_value(
            catalog_config, GLUE_ROLE_ARN, AWS_ROLE_ARN
        )
        if not role_to_assume:
            return

        logger.debug(
            "Recognized role ARN in glue catalog config, attempting to workaround pyiceberg limitation in role assumption for the glue client"
        )
        session = boto3.Session(
            profile_name=catalog_config.get(GLUE_PROFILE_NAME),
            region_name=get_first_property_value(
                catalog_config, GLUE_REGION, AWS_REGION
            ),
            botocore_session=catalog_config.get(BOTOCORE_SESSION),
            aws_access_key_id=get_first_property_value(
                catalog_config, GLUE_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID
            ),
            aws_secret_access_key=get_first_property_value(
                catalog_config, GLUE_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY
            ),
            aws_session_token=get_first_property_value(
                catalog_config, GLUE_SESSION_TOKEN, AWS_SESSION_TOKEN
            ),
        )

        sts_client = session.client("sts")
        identity = sts_client.get_caller_identity()
        logger.debug(
            f"Authenticated as {identity['Arn']}, attempting to assume a role: {role_to_assume}"
        )

        # Best effort only: a failed conversion must not prevent the assumption attempt.
        # NOTE(review): assumed-role ARNs appear to omit the IAM role path, so a target
        # role configured with a path would never compare equal here - confirm.
        current_role_arn = None
        try:
            current_role_arn = self._assumed_role_arn_to_role_arn(identity["Arn"])
            if current_role_arn is not None:
                logger.debug(f"Deducted current role: {current_role_arn}")
        except Exception as e:
            logger.warning(
                "We couldn't convert currently assumed role to 'role' format so that we could compare "
                f"it with the target role, will try to assume the target role nonetheless, exception: {e}"
            )

        if current_role_arn == role_to_assume:
            logger.debug(
                "Current role and the role we wanted to assume are the same, continuing without further assumption steps"
            )
            return

        logger.debug(f"Assuming the role {role_to_assume}")
        # below might fail if such duration is not allowed per policies
        try:
            response = sts_client.assume_role(
                RoleArn=role_to_assume,
                RoleSessionName="session",
                DurationSeconds=43200,
            )
        except sts_client.exceptions.ClientError:
            # Fallback to default duration
            response = sts_client.assume_role(
                RoleArn=role_to_assume, RoleSessionName="session"
            )
        logger.debug(f"Assumed role: {response['AssumedRoleUser']}")

        # Hand the temporary credentials to pyiceberg through the Glue properties.
        creds = response["Credentials"]
        catalog_config[GLUE_ACCESS_KEY_ID] = creds["AccessKeyId"]
        catalog_config[GLUE_SECRET_ACCESS_KEY] = creds["SecretAccessKey"]
        catalog_config[GLUE_SESSION_TOKEN] = creds["SessionToken"]

    def get_catalog(self) -> Catalog:
        """Build the pyiceberg catalog described by the `catalog` dictionary.

        Glue catalogs first go through the role-assumption workaround. REST
        catalogs get HTTP adapters with a retry policy and timeout mounted on
        their session, taken from the optional `connection` settings with the
        module defaults as fallback.

        Returns:
            Catalog: Iceberg catalog instance.

        Raises:
            ValueError: If no catalog configuration is present.
        """
        if not self.catalog:
            raise ValueError("No catalog configuration found")

        # Only the first (and, per validation, only) catalog entry is used
        catalog_name, catalog_config = next(iter(self.catalog.items()))
        logger.debug("Initializing the catalog %s", catalog_name)

        # workaround pyiceberg 0.10.0 issue with ignoring role assumption for glue catalog,
        # remove this code once pyiceberg is fixed, raised issue: https://github.com/apache/iceberg-python/issues/2747
        is_glue_catalog = catalog_config.get("type") == "glue"
        if is_glue_catalog:
            self._custom_glue_catalog_handling(catalog_config)

        catalog = load_catalog(name=catalog_name, **catalog_config)
        if not isinstance(catalog, RestCatalog):
            return catalog

        logger.debug(
            "Recognized REST catalog type being configured, attempting to configure HTTP Adapter for the session"
        )
        connection_settings = catalog_config.get("connection", {})

        # Start from the defaults and let user-provided retry settings win
        retry_policy: Dict[str, Any] = dict(DEFAULT_REST_RETRY_POLICY)
        retry_policy.update(connection_settings.get("retry", {}))
        retries = Retry(**retry_policy)
        logger.debug(f"Retry policy to be set: {retry_policy}")

        timeout = connection_settings.get("timeout", DEFAULT_REST_TIMEOUT)
        logger.debug(f"Timeout to be set: {timeout}")

        # Each scheme gets its own adapter instance
        for url_prefix in ("http://", "https://"):
            catalog._session.mount(
                url_prefix, TimeoutHTTPAdapter(timeout=timeout, max_retries=retries)
            )
        return catalog


class TopTableTimings:
    """Lock-protected collection keeping the entries with the largest timings.

    Entries are dictionaries; only those carrying a ``"timing"`` key are
    stored. At most ``size`` entries are retained, ordered from the largest
    timing to the smallest.
    """

    _VALUE_FIELD: str = "timing"
    top_entites: SortedList
    _size: int

    def __init__(self, size: int = 10):
        def by_descending_timing(entry):
            # Negated so that the largest timing sorts first
            return -entry.get(self._VALUE_FIELD, 0)

        self._lock = threading.Lock()
        self._size = size
        self.top_entites = SortedList(key=by_descending_timing)

    def add(self, entity: Dict[str, Any]) -> None:
        """Store ``entity`` if it has a timing, evicting the smallest one on overflow."""
        if self._VALUE_FIELD in entity:
            with self._lock:
                self.top_entites.add(entity)
                over_capacity = len(self.top_entites) > self._size
                if over_capacity:
                    # The last element carries the smallest timing
                    self.top_entites.pop()

    def __str__(self) -> str:
        with self._lock:
            if not len(self.top_entites):
                return "no timings reported"
            retained = list(self.top_entites)
            return str(retained)


class TimingClass:
    """Lock-protected, sorted collection of timings with a summary string form."""

    times: SortedList

    def __init__(self):
        self._lock = threading.Lock()
        self.times = SortedList()

    def add_timing(self, t: float) -> None:
        """Record one timing value."""
        with self._lock:
            self.times.add(t)

    def __str__(self) -> str:
        def pretty(duration) -> str:
            return format_timespan(duration, detailed=True, max_units=3)

        with self._lock:
            count = len(self.times)
            if count == 0:
                return "no timings reported"

            total = sum(self.times)
            # The list is sorted, so its ends hold the minimum and the maximum
            summary = {
                "average_time": pretty(total / count),
                "min_time": pretty(self.times[0]),
                "max_time": pretty(self.times[-1]),
                # total_time does not provide correct information in case we run in more than 1 thread
                "total_time": pretty(total),
            }
            return str(summary)


@dataclass
class IcebergSourceReport(StaleEntityRemovalSourceReport):
    """Report of an Iceberg ingestion run: counters plus per-table timing statistics."""

    tables_scanned: int = 0
    entities_profiled: int = 0
    filtered: LossyList[str] = field(default_factory=LossyList)
    load_table_timings: TimingClass = field(default_factory=TimingClass)
    processing_table_timings: TimingClass = field(default_factory=TimingClass)
    profiling_table_timings: TimingClass = field(default_factory=TimingClass)
    tables_load_timings: TopTableTimings = field(default_factory=TopTableTimings)
    tables_profile_timings: TopTableTimings = field(default_factory=TopTableTimings)
    tables_process_timings: TopTableTimings = field(default_factory=TopTableTimings)
    listed_namespaces: int = 0
    total_listed_tables: int = 0
    tables_listed_per_namespace: TopKDict[str, int] = field(
        default_factory=int_top_k_dict
    )

    @staticmethod
    def _record_table_timing(
        aggregate: TimingClass,
        slowest: TopTableTimings,
        t: float,
        table_name: str,
        table_metadata_location: str,
    ) -> None:
        """Add ``t`` to the aggregate timings and offer the table to the top-timings list."""
        aggregate.add_timing(t)
        slowest.add(
            {"table": table_name, "timing": t, "metadata_file": table_metadata_location}
        )

    def report_listed_tables_for_namespace(
        self, namespace: str, no_tables: int
    ) -> None:
        """Store the table count of ``namespace`` and add it to the overall total."""
        self.tables_listed_per_namespace[namespace] = no_tables
        self.total_listed_tables = self.total_listed_tables + no_tables

    def report_no_listed_namespaces(self, amount: int) -> None:
        """Set the number of listed namespaces."""
        self.listed_namespaces = amount

    def report_table_scanned(self, name: str) -> None:
        """Count one more scanned table; ``name`` is accepted but not used."""
        self.tables_scanned = self.tables_scanned + 1

    def report_dropped(self, ent_name: str) -> None:
        """Remember ``ent_name`` as filtered out."""
        self.filtered.append(ent_name)

    def report_table_load_time(
        self, t: float, table_name: str, table_metadata_location: str
    ) -> None:
        """Record how long loading the given table took."""
        self._record_table_timing(
            self.load_table_timings,
            self.tables_load_timings,
            t,
            table_name,
            table_metadata_location,
        )

    def report_table_processing_time(
        self, t: float, table_name: str, table_metadata_location: str
    ) -> None:
        """Record how long processing the given table took."""
        self._record_table_timing(
            self.processing_table_timings,
            self.tables_process_timings,
            t,
            table_name,
            table_metadata_location,
        )

    def report_table_profiling_time(
        self, t: float, table_name: str, table_metadata_location: str
    ) -> None:
        """Record how long profiling the given table took."""
        self._record_table_timing(
            self.profiling_table_timings,
            self.tables_profile_timings,
            t,
            table_name,
            table_metadata_location,
        )
