from __future__ import annotations

import logging
from functools import wraps
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Optional,
    Type,
    Union,
)

from great_expectations.compatibility.pyspark import functions as F
from great_expectations.compatibility.sqlalchemy import sqlalchemy as sa
from great_expectations.core.metric_domain_types import MetricDomainTypes
from great_expectations.core.metric_function_types import (
    MetricPartialFunctionTypes,
)
from great_expectations.execution_engine import (
    ExecutionEngine,
    PandasExecutionEngine,
    SparkDFExecutionEngine,
    SqlAlchemyExecutionEngine,
)
from great_expectations.expectations.metrics.metric_provider import (
    metric_partial,
)
from great_expectations.expectations.metrics.util import (
    get_dbms_compatible_metric_domain_kwargs,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Imported for static type checking only; the name is referenced in local variable
# annotations (e.g. ``sqlalchemy.quoted_name``) and is not bound at runtime by this guard.
if TYPE_CHECKING:
    from great_expectations.compatibility import sqlalchemy


def column_function_partial(  # noqa: C901, PLR0915 # FIXME CoP
    engine: Type[ExecutionEngine],
    partial_fn_type: Optional[MetricPartialFunctionTypes] = None,
    **kwargs,
):
    """Build a decorator that adapts a simplified column metric function to an engine.

    The decorated ``metric_fn`` is not called with the full metric-provider signature.
    Instead it receives the provider class, an engine-specific column object
    (``df[column_name]`` for Pandas, ``sa.column(column_name)`` for SqlAlchemy,
    ``F.col(column_name)`` for Spark), the metric value kwargs, and a few
    underscore-prefixed context arguments (``_metrics`` and, depending on the engine,
    ``_dialect``/``_table`` or ``_compute_domain_kwargs``).

    Args:
        engine: Execution engine class the metric function is written for. It must be
            a subclass of PandasExecutionEngine, SqlAlchemyExecutionEngine or
            SparkDFExecutionEngine (checked in that order).
        partial_fn_type: Kind of partial function being declared. When omitted, it
            defaults to MAP_SERIES for Pandas and to MAP_FN for SqlAlchemy and Spark.
            Spark additionally accepts WINDOW_FN.
        **kwargs: Extra keyword arguments forwarded to ``metric_partial``. If a
            "filter_column_isnull" entry is present it takes precedence over the
            attribute of the same name on the metric provider class.

    Returns:
        A decorator that wraps ``metric_fn`` into a function taking
        ``(cls, execution_engine, metric_domain_kwargs, metric_value_kwargs, metrics,
        runtime_configuration)`` and returning a tuple of
        ``(result, compute_domain_kwargs, accessor_domain_kwargs)``.

    Raises:
        ValueError: If ``partial_fn_type`` is not supported for the given engine, or
            if ``engine`` is not one of the supported execution engine classes.
    """
    domain_type = MetricDomainTypes.COLUMN

    def _null_filter_requested(cls):
        # The class-level attribute is always looked up; it only acts as the fallback
        # when the decorator was not given an explicit "filter_column_isnull" kwarg.
        class_setting = getattr(cls, "filter_column_isnull", False)
        return kwargs.get("filter_column_isnull", class_setting)

    def _filtered_domain(cls, execution_engine, metric_domain_kwargs, metrics):
        # Shared by the SqlAlchemy and Spark paths: null filtering is expressed as a
        # row condition on the domain kwargs before the compute domain is resolved.
        drop_nulls = _null_filter_requested(cls)
        compatible_kwargs = get_dbms_compatible_metric_domain_kwargs(
            metric_domain_kwargs=metric_domain_kwargs,
            batch_columns_list=metrics["table.columns"],
        )
        # No copy is made in the unfiltered case; per the original note, a differing
        # compute domain is copied by get_compute_domain itself.
        domain_kwargs = (
            execution_engine.add_column_row_condition(compatible_kwargs)
            if drop_nulls
            else compatible_kwargs
        )
        return execution_engine.get_compute_domain(
            domain_kwargs=domain_kwargs, domain_type=domain_type
        )

    # ------------------------------------------------------------------ Pandas
    if issubclass(engine, PandasExecutionEngine):
        pandas_fn_type = MetricPartialFunctionTypes(
            MetricPartialFunctionTypes.MAP_SERIES if partial_fn_type is None else partial_fn_type
        )
        if pandas_fn_type != MetricPartialFunctionTypes.MAP_SERIES:
            message = (
                "PandasExecutionEngine only supports "
                f'"{MetricPartialFunctionTypes.MAP_SERIES.value}" for '
                '"column_function_partial" "partial_fn_type" property.'
            )
            raise ValueError(message)

        def pandas_decorator(metric_fn: Callable):
            @metric_partial(
                engine=engine,
                partial_fn_type=pandas_fn_type,
                domain_type=domain_type,
                **kwargs,
            )
            @wraps(metric_fn)
            def pandas_partial(  # noqa: PLR0913 # FIXME CoP
                cls,
                execution_engine: PandasExecutionEngine,
                metric_domain_kwargs: dict,
                metric_value_kwargs: dict,
                metrics: Dict[str, Any],
                runtime_configuration: dict,
            ):
                compatible_kwargs = get_dbms_compatible_metric_domain_kwargs(
                    metric_domain_kwargs=metric_domain_kwargs,
                    batch_columns_list=metrics["table.columns"],
                )
                frame, compute_domain_kwargs, accessor_domain_kwargs = (
                    execution_engine.get_compute_domain(
                        domain_kwargs=compatible_kwargs, domain_type=domain_type
                    )
                )
                column_name: Union[str, sqlalchemy.quoted_name] = accessor_domain_kwargs["column"]

                # For Pandas the null filter is applied directly to the data.
                if _null_filter_requested(cls):
                    frame = frame[frame[column_name].notnull()]

                series = frame[column_name]
                values = metric_fn(cls, series, **metric_value_kwargs, _metrics=metrics)
                return values, compute_domain_kwargs, accessor_domain_kwargs

            return pandas_partial

        return pandas_decorator

    # -------------------------------------------------------------- SqlAlchemy
    if issubclass(engine, SqlAlchemyExecutionEngine):
        sql_fn_type = MetricPartialFunctionTypes(
            MetricPartialFunctionTypes.MAP_FN if partial_fn_type is None else partial_fn_type
        )
        if sql_fn_type not in (MetricPartialFunctionTypes.MAP_FN,):
            message = (
                "SqlAlchemyExecutionEngine only supports "
                f'"{MetricPartialFunctionTypes.MAP_FN.value}" for '
                '"column_function_partial" "partial_fn_type" property.'
            )
            raise ValueError(message)

        def sql_decorator(metric_fn: Callable):
            @metric_partial(
                engine=engine,
                partial_fn_type=sql_fn_type,
                domain_type=domain_type,
                **kwargs,
            )
            @wraps(metric_fn)
            def sql_partial(  # noqa: PLR0913 # FIXME CoP
                cls,
                execution_engine: SqlAlchemyExecutionEngine,
                metric_domain_kwargs: dict,
                metric_value_kwargs: dict,
                metrics: Dict[str, Any],
                runtime_configuration: dict,
            ):
                selectable, compute_domain_kwargs, accessor_domain_kwargs = (
                    _filtered_domain(cls, execution_engine, metric_domain_kwargs, metrics)
                )
                column_name: Union[str, sqlalchemy.quoted_name] = accessor_domain_kwargs["column"]

                dialect_module = execution_engine.dialect_module
                sql_column = sa.column(column_name)
                column_function = metric_fn(
                    cls,
                    sql_column,
                    **metric_value_kwargs,
                    _dialect=dialect_module,
                    _table=selectable,
                    _metrics=metrics,
                )
                return column_function, compute_domain_kwargs, accessor_domain_kwargs

            return sql_partial

        return sql_decorator

    # ------------------------------------------------------------------- Spark
    if issubclass(engine, SparkDFExecutionEngine):
        spark_fn_type = MetricPartialFunctionTypes(
            MetricPartialFunctionTypes.MAP_FN if partial_fn_type is None else partial_fn_type
        )
        supported_spark_types = (
            MetricPartialFunctionTypes.MAP_FN,
            MetricPartialFunctionTypes.WINDOW_FN,
        )
        if spark_fn_type not in supported_spark_types:
            message = (
                "SparkDFExecutionEngine only supports "
                f'"{MetricPartialFunctionTypes.MAP_FN.value}" and '
                f'"{MetricPartialFunctionTypes.WINDOW_FN.value}" for '
                '"column_function_partial" "partial_fn_type" property.'
            )
            raise ValueError(message)

        def spark_decorator(metric_fn: Callable):
            @metric_partial(
                engine=engine,
                partial_fn_type=spark_fn_type,
                domain_type=domain_type,
                **kwargs,
            )
            @wraps(metric_fn)
            def spark_partial(  # noqa: PLR0913 # FIXME CoP
                cls,
                execution_engine: SparkDFExecutionEngine,
                metric_domain_kwargs: dict,
                metric_value_kwargs: dict,
                metrics: Dict[str, Any],
                runtime_configuration: dict,
            ):
                # The resolved data object itself is not needed; the column is
                # referenced by name through F.col.
                _data, compute_domain_kwargs, accessor_domain_kwargs = (
                    _filtered_domain(cls, execution_engine, metric_domain_kwargs, metrics)
                )
                column_name: Union[str, sqlalchemy.quoted_name] = accessor_domain_kwargs["column"]

                spark_column = F.col(column_name)
                column_function = metric_fn(
                    cls,
                    column=spark_column,
                    **metric_value_kwargs,
                    _metrics=metrics,
                    _compute_domain_kwargs=compute_domain_kwargs,
                )
                return column_function, compute_domain_kwargs, accessor_domain_kwargs

            return spark_partial

        return spark_decorator

    # ------------------------------------------------------- Unsupported engine
    raise ValueError(  # noqa: TRY003, TRY004 # FIXME CoP
        'Unsupported engine for "column_function_partial" metric function decorator.'
    )
