from __future__ import annotations

import logging
from functools import wraps
from typing import (
    Any,
    Callable,
    Dict,
    Type,
)

from great_expectations.compatibility.pyspark import functions as F
from great_expectations.compatibility.sqlalchemy import sqlalchemy as sa
from great_expectations.core.metric_domain_types import MetricDomainTypes
from great_expectations.core.metric_function_types import (
    MetricPartialFunctionTypes,
)
from great_expectations.execution_engine import (
    ExecutionEngine,
    PandasExecutionEngine,
    SparkDFExecutionEngine,
    SqlAlchemyExecutionEngine,
)
from great_expectations.expectations.metrics.metric_provider import (
    metric_partial,
)
from great_expectations.expectations.metrics.util import (
    get_dbms_compatible_metric_domain_kwargs,
)

logger = logging.getLogger(__name__)


def _pandas_column_accessor(data: Any, column_name: str) -> Any:
    """Select ``column_name`` from the compute-domain data by indexing it."""
    return data[column_name]


def _sqlalchemy_column_accessor(_data: Any, column_name: str) -> Any:
    """Build a SQLAlchemy column expression for ``column_name``; the data object is unused."""
    # ``sa`` is dereferenced only here, i.e. when a metric is computed -- not at decoration time.
    return sa.column(column_name)


def _spark_column_accessor(_data: Any, column_name: str) -> Any:
    """Build a Spark column expression for ``column_name``; the data object is unused."""
    # ``F`` is dereferenced only here, i.e. when a metric is computed -- not at decoration time.
    return F.col(column_name)


def _resolve_column_pair_partial_fn_type(
    engine_label: str,
    partial_fn_type: MetricPartialFunctionTypes | None,
    supported_fn_type: MetricPartialFunctionTypes,
) -> MetricPartialFunctionTypes:
    """Default and validate ``partial_fn_type`` against the one type an engine supports.

    Args:
        engine_label: Engine name used verbatim at the start of the error message.
        partial_fn_type: Requested partial function type, or None to use ``supported_fn_type``.
        supported_fn_type: The only partial function type accepted for the engine.

    Returns:
        The resolved ``MetricPartialFunctionTypes`` member.

    Raises:
        ValueError: If the resolved type differs from ``supported_fn_type`` (the enum constructor
            call may also raise for a value it does not recognize).
    """
    if partial_fn_type is None:
        partial_fn_type = supported_fn_type

    # Normalize through the enum constructor, exactly as each engine branch used to do inline.
    partial_fn_type = MetricPartialFunctionTypes(partial_fn_type)
    if partial_fn_type != supported_fn_type:
        raise ValueError(  # noqa: TRY003 # FIXME CoP
            f'{engine_label} only supports "{supported_fn_type.value}" for '
            '"column_pair_function_partial" "partial_fn_type" property.'
        )

    return partial_fn_type


def _build_column_pair_function_wrapper(
    engine: Type[ExecutionEngine],
    partial_fn_type: MetricPartialFunctionTypes,
    domain_type: MetricDomainTypes,
    column_accessor: Callable[[Any, str], Any],
    decorator_kwargs: Dict[str, Any],
):
    """Create the decorator that adapts a simplified ``metric_fn`` to the full metric signature.

    Args:
        engine: Execution engine class forwarded to ``metric_partial``.
        partial_fn_type: Already-validated partial function type forwarded to ``metric_partial``.
        domain_type: Metric domain type forwarded to ``metric_partial`` and ``get_compute_domain``.
        column_accessor: Callable ``(data, column_name)`` producing the engine-specific column
            object handed to ``metric_fn``.
        decorator_kwargs: Extra keyword arguments forwarded to ``metric_partial``.

    Returns:
        A decorator accepting ``metric_fn`` and returning the ``metric_partial``-wrapped function.
    """

    def wrapper(metric_fn: Callable):
        @metric_partial(
            engine=engine,
            partial_fn_type=partial_fn_type,
            domain_type=domain_type,
            **decorator_kwargs,
        )
        @wraps(metric_fn)
        def inner_func(  # noqa: PLR0913 # FIXME CoP
            cls,
            execution_engine: ExecutionEngine,
            metric_domain_kwargs: dict,
            metric_value_kwargs: dict,
            metrics: Dict[str, Any],
            runtime_configuration: dict,
        ):
            metric_domain_kwargs = get_dbms_compatible_metric_domain_kwargs(
                metric_domain_kwargs=metric_domain_kwargs,
                batch_columns_list=metrics["table.columns"],
            )

            (
                data,
                compute_domain_kwargs,
                accessor_domain_kwargs,
            ) = execution_engine.get_compute_domain(
                domain_kwargs=metric_domain_kwargs, domain_type=domain_type
            )

            # Look up both names before building any column object so that a missing key
            # fails before the accessor or ``metric_fn`` runs.
            # noinspection PyPep8Naming
            column_A_name = accessor_domain_kwargs["column_A"]
            # noinspection PyPep8Naming
            column_B_name = accessor_domain_kwargs["column_B"]

            result = metric_fn(
                cls,
                column_accessor(data, column_A_name),
                column_accessor(data, column_B_name),
                **metric_value_kwargs,
                _metrics=metrics,
            )
            return result, compute_domain_kwargs, accessor_domain_kwargs

        return inner_func

    return wrapper


def column_pair_function_partial(
    engine: Type[ExecutionEngine],
    partial_fn_type: MetricPartialFunctionTypes | None = None,
    **kwargs,
):
    """Provides engine-specific support for authoring a metric_fn with a simplified signature.

    A metric function that is decorated as a column_pair_function_partial will be called with
    ``cls``, the two engine-specific column objects for "column_A" and "column_B", the metric
    value kwargs, and the resolved metrics passed as ``_metrics``.

    Args:
        engine: The `ExecutionEngine` class for which the metric function is being declared.
        partial_fn_type: The partial function type to register. Defaults to the single type the
            engine supports: MAP_SERIES for `PandasExecutionEngine`, MAP_FN for
            `SqlAlchemyExecutionEngine` and `SparkDFExecutionEngine`.
        **kwargs: Additional keyword arguments forwarded to `metric_partial`.

    Returns:
        A decorator which wraps the metric function. The wrapped function returns a tuple of
        (metric_fn result, compute_domain_kwargs, accessor_domain_kwargs).

    Raises:
        ValueError: If `partial_fn_type` is not the type supported by the engine, or if `engine`
            is not a subclass of one of the three supported execution engines.
    """
    domain_type = MetricDomainTypes.COLUMN_PAIR

    # (engine base class, label used in the error message, supported type, column accessor).
    # Order matters: the first matching base class wins, as in an if/elif chain.
    engine_specs = (
        (
            PandasExecutionEngine,
            "PandasExecutionEngine",
            MetricPartialFunctionTypes.MAP_SERIES,
            _pandas_column_accessor,
        ),
        (
            SqlAlchemyExecutionEngine,
            "SqlAlchemyExecutionEngine",
            MetricPartialFunctionTypes.MAP_FN,
            _sqlalchemy_column_accessor,
        ),
        (
            SparkDFExecutionEngine,
            "SparkDFExecutionEngine",
            MetricPartialFunctionTypes.MAP_FN,
            _spark_column_accessor,
        ),
    )

    for engine_base, engine_label, supported_fn_type, column_accessor in engine_specs:
        if issubclass(engine, engine_base):
            resolved_fn_type = _resolve_column_pair_partial_fn_type(
                engine_label=engine_label,
                partial_fn_type=partial_fn_type,
                supported_fn_type=supported_fn_type,
            )
            return _build_column_pair_function_wrapper(
                engine=engine,
                partial_fn_type=resolved_fn_type,
                domain_type=domain_type,
                column_accessor=column_accessor,
                decorator_kwargs=kwargs,
            )

    raise ValueError(  # noqa: TRY003 # FIXME CoP
        'Unsupported engine for "column_pair_function_partial" metric function decorator.'
    )
