# -*- coding: utf-8 -*-
# *******************************************************
#   ____                     _               _
#  / ___|___  _ __ ___   ___| |_   _ __ ___ | |
# | |   / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | |  __/ |_ _| | | | | | |
#  \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
#  Sign up for free at https://www.comet.com
#  Copyright (C) 2015-2024 Comet ML INC
#  This source code is licensed under the MIT license.
# *******************************************************
import copy
import logging

from ..config import MAXIMAL_VALUE_LENGTH
from ..convert_utils import convert_to_string_value
from ..logging_messages import SKLEARN_INTEGRATION_PIPELINE_ERROR
from ..messages import ParameterMessage
from ..monkey_patching import check_module

LOGGER = logging.getLogger(__name__)


def pre_process_params(params):
    """Return *params* without a ``random_state`` entry holding a ``RandomState``.

    If ``params["random_state"]`` is an object whose class is named
    ``RandomState``, a shallow copy of *params* without that key is returned
    and the caller's mapping is left untouched. In every other case (key
    absent, a plain seed such as an int or ``None``) the original object is
    returned unchanged.

    Any exception raised while inspecting *params* (e.g. *params* is not a
    container) is logged at INFO level and *params* is returned as-is.
    """
    try:
        if "random_state" not in params:
            return params

        # Work on a shallow copy so the caller's dict is never mutated.
        without_rng = copy.copy(params)
        rng_class_name = without_rng["random_state"].__class__.__name__
        if rng_class_name != "RandomState":
            return params

        del without_rng["random_state"]
        return without_rng
    except Exception as e:
        LOGGER.info(
            "failed to remove RandomState from sklearn object with error %s",
            e,
            exc_info=True,
        )

    return params


def _preprocess_pipelines_in_estimator_params(params):
    if params is not None and "steps" in params:
        for step in params["steps"]:
            step_name, _ = step
            if step_name in params:
                del params[step_name]
        del params["steps"]

    # the grid search
    if "estimator__steps" in params:
        params["estimator__steps"] = convert_to_string_value(
            params["estimator__steps"], max_length=MAXIMAL_VALUE_LENGTH
        )

    if "estimator" in params:
        params["estimator"] = convert_to_string_value(
            params["estimator"], max_length=MAXIMAL_VALUE_LENGTH
        )

    return params


def _log_estimator_params(experiment, estimator):
    if experiment.auto_param_logging:
        try:
            params = estimator.get_params()
            params = _preprocess_pipelines_in_estimator_params(params)
            processed_params = pre_process_params(params)
            if len(processed_params) == 0:
                LOGGER.debug(
                    "Skipping empty params %r from Estimator %r",
                    processed_params,
                    estimator,
                )
                return

            experiment._log_parameters(
                processed_params,
                framework="scikit-learn",
                source=ParameterMessage.source_autologger,
                flatten_nested=False,
            )
        except Exception:
            LOGGER.error("Failed to extract parameters from estimator", exc_info=True)


def fit_logger_before(experiment, original, *args, **kwargs):
    """Before-hook for a patched ``fit``: log the estimator's params.

    ``args[0]`` is presumably the estimator instance (the ``self`` of the
    patched ``fit``) -- TODO confirm against the module finder's calling
    convention. ``original`` and ``kwargs`` are unused.
    """
    estimator = args[0]
    _log_estimator_params(experiment, estimator)


def fit_logger_after(experiment, original, ret_val, *args, **kwargs):
    """After-hook for a patched ``fit``: log params of ``fit``'s return value.

    scikit-learn estimators conventionally return ``self`` from ``fit``, so
    *ret_val* is normally the fitted estimator. If it lacks ``get_params``,
    the resulting error is logged (not raised) by ``_log_estimator_params``.
    ``original``, ``args`` and ``kwargs`` are unused.
    """
    fitted_estimator = ret_val
    _log_estimator_params(experiment, fitted_estimator)


def _log_pipeline_params(experiment, pipeline):
    if not experiment.auto_param_logging:
        return

    try:
        params = pipeline.get_params()
        if params is not None and "steps" in params:
            for step in params["steps"]:
                step_name, step_mdl = step

                if not hasattr(step_mdl, "get_params"):
                    LOGGER.warning(
                        SKLEARN_INTEGRATION_PIPELINE_ERROR.format(step_name, step_mdl)
                    )
                    continue

                params = step_mdl.get_params()
                processed_params = pre_process_params(params)
                experiment._log_parameters(
                    processed_params,
                    prefix=step_name,
                    framework="scikit-learn",
                    source=ParameterMessage.source_autologger,
                    flatten_nested=False,
                )
    except Exception:
        LOGGER.error("Failed to extract parameters from Pipeline", exc_info=True)


def pipeline_fit_logger_before(experiment, original, *args, **kwargs):
    """Before-hook for ``Pipeline.fit``: log each step's params.

    ``args[0]`` is presumably the ``Pipeline`` instance (the ``self`` of the
    patched ``fit``) -- TODO confirm against the module finder's calling
    convention. ``original`` and ``kwargs`` are unused.
    """
    target_pipeline = args[0]
    _log_pipeline_params(experiment, target_pipeline)


def pipeline_fit_logger_after(experiment, original, ret_val, *args, **kwargs):
    """After-hook for ``Pipeline.fit``: log each step's params of the result.

    ``Pipeline.fit`` returns ``self``, so *ret_val* is normally the fitted
    pipeline. ``original``, ``args`` and ``kwargs`` are unused.
    """
    fitted_pipeline = ret_val
    _log_pipeline_params(experiment, fitted_pipeline)


# (module path, "Class.method") pairs whose ``fit`` is wrapped by ``patch()``
# with ``fit_logger_before`` / ``fit_logger_after`` so the estimator's
# hyper-parameters are auto-logged. Entries are registered in list order.
#
# The list was generated by merging its previous version with the output of
# utils/generate_sklearn_entrypoints.py for these scikit-learn versions:
# * 0.20.4
# * 0.21.3
# * 0.22.2.post1
# * 0.23.2
# * 0.24.2
# * 1.0.2
# * 1.1.0
# * 1.3.0
# * 1.5.0
#
# Many classes appear under both an older module path (e.g.
# ``sklearn.tree.tree``) and a newer private one (e.g. ``sklearn.tree._classes``),
# presumably because scikit-learn relocated them between the versions above.
# Some entries (e.g. ``sklearn.frozen._frozen``) appear to have been added
# by hand for versions newer than those listed.
FIT_MODULES = [
    ("sklearn.linear_model._logistic", "LogisticRegressionCV.fit"),
    ("sklearn.model_selection._search", "GridSearchCV.fit"),
    ("sklearn.gaussian_process._gpr", "GaussianProcessRegressor.fit"),
    ("sklearn.preprocessing.data", "RobustScaler.fit"),
    ("sklearn.semi_supervised._label_propagation", "LabelSpreading.fit"),
    ("sklearn.feature_extraction._dict_vectorizer", "DictVectorizer.fit"),
    ("sklearn.covariance.shrunk_covariance_", "LedoitWolf.fit"),
    ("sklearn.preprocessing._data", "QuantileTransformer.fit"),
    ("sklearn.neighbors.regression", "KNeighborsRegressor.fit"),
    ("sklearn.linear_model.logistic", "LogisticRegression.fit"),
    ("sklearn.preprocessing._label", "LabelEncoder.fit"),
    ("sklearn.linear_model._coordinate_descent", "ElasticNetCV.fit"),
    ("sklearn.linear_model.coordinate_descent", "MultiTaskElasticNet.fit"),
    ("sklearn.linear_model._least_angle", "Lars.fit"),
    ("sklearn.neighbors.approximate", "GaussianRandomProjectionHash.fit"),
    ("sklearn.linear_model.randomized_l1", "RandomizedLogisticRegression.fit"),
    ("sklearn.preprocessing._function_transformer", "FunctionTransformer.fit"),
    ("sklearn.manifold._locally_linear", "LocallyLinearEmbedding.fit"),
    ("sklearn.cluster.k_means_", "KMeans.fit"),
    ("sklearn.tree.tree", "DecisionTreeRegressor.fit"),
    ("sklearn.linear_model._ridge", "RidgeClassifier.fit"),
    ("sklearn.decomposition._nmf", "_BaseNMF.fit"),
    ("sklearn.preprocessing.data", "Binarizer.fit"),
    ("sklearn.decomposition._nmf", "MiniBatchNMF.fit"),
    ("sklearn.preprocessing._data", "Normalizer.fit"),
    ("sklearn.feature_selection._from_model", "SelectFromModel.fit"),
    ("sklearn.linear_model.logistic", "LogisticRegressionCV.fit"),
    ("sklearn.neighbors._kde", "KernelDensity.fit"),
    ("sklearn.discriminant_analysis", "QuadraticDiscriminantAnalysis.fit"),
    ("sklearn.linear_model.ridge", "_BaseRidgeCV.fit"),
    ("sklearn.linear_model.stochastic_gradient", "SGDRegressor.fit"),
    ("sklearn.linear_model.coordinate_descent", "Lasso.fit"),
    ("sklearn.neighbors.nca", "NeighborhoodComponentsAnalysis.fit"),
    ("sklearn.cross_decomposition.pls_", "PLSSVD.fit"),
    ("sklearn.ensemble._gb", "GradientBoostingRegressor.fit"),
    ("sklearn.decomposition._pca", "PCA.fit"),
    ("sklearn.decomposition._dict_learning", "DictionaryLearning.fit"),
    ("sklearn.preprocessing._data", "RobustScaler.fit"),
    ("sklearn.semi_supervised._self_training", "SelfTrainingClassifier.fit"),
    ("sklearn.covariance._shrunk_covariance", "ShrunkCovariance.fit"),
    ("sklearn.linear_model.least_angle", "LassoLarsCV.fit"),
    ("sklearn.decomposition.dict_learning", "DictionaryLearning.fit"),
    ("sklearn.covariance._graph_lasso", "GraphicalLasso.fit"),
    ("sklearn.ensemble._gb", "GradientBoostingClassifier.fit"),
    ("sklearn.linear_model._least_angle", "LassoLars.fit"),
    ("sklearn.decomposition.incremental_pca", "IncrementalPCA.fit"),
    ("sklearn.svm._classes", "NuSVR.fit"),
    ("sklearn.linear_model.ridge", "RidgeCV.fit"),
    ("sklearn.cluster.hierarchical", "AgglomerativeClustering.fit"),
    ("sklearn.linear_model.coordinate_descent", "LassoCV.fit"),
    ("sklearn.ensemble.forest", "RandomForestRegressor.fit"),
    ("sklearn.linear_model._least_angle", "LarsCV.fit"),
    ("sklearn.manifold._t_sne", "TSNE.fit"),
    ("sklearn.manifold.mds", "MDS.fit"),
    ("sklearn.linear_model.coordinate_descent", "ElasticNet.fit"),
    ("sklearn.pipeline", "FeatureUnion.fit"),
    ("sklearn.tree._classes", "ExtraTreeRegressor.fit"),
    ("sklearn.gaussian_process._gpc", "_BinaryGaussianProcessClassifierLaplace.fit"),
    ("sklearn.isotonic", "IsotonicRegression.fit"),
    ("sklearn.linear_model._omp", "OrthogonalMatchingPursuit.fit"),
    ("sklearn.decomposition._kernel_pca", "KernelPCA.fit"),
    ("sklearn.feature_selection.univariate_selection", "SelectPercentile.fit"),
    ("sklearn.compose._column_transformer", "ColumnTransformer.fit"),
    ("sklearn.feature_selection._rfe", "RFE.fit"),
    ("sklearn.neighbors._nearest_centroid", "NearestCentroid.fit"),
    ("sklearn.ensemble.voting", "VotingClassifier.fit"),
    ("sklearn.multioutput", "MultiOutputRegressor.fit"),
    ("sklearn.dummy", "DummyRegressor.fit"),
    ("sklearn.cross_decomposition.cca_", "CCA.fit"),
    ("sklearn.ensemble.forest", "RandomTreesEmbedding.fit"),
    ("sklearn.impute", "SimpleImputer.fit"),
    ("sklearn.multiclass", "OneVsRestClassifier.fit"),
    ("sklearn.tree._classes", "DecisionTreeRegressor.fit"),
    ("sklearn.cluster._agglomerative", "AgglomerativeClustering.fit"),
    ("sklearn.decomposition._lda", "LatentDirichletAllocation.fit"),
    ("sklearn.linear_model.coordinate_descent", "MultiTaskLassoCV.fit"),
    ("sklearn.decomposition.sparse_pca", "SparsePCA.fit"),
    ("sklearn.linear_model.ridge", "Ridge.fit"),
    ("sklearn.kernel_approximation", "SkewedChi2Sampler.fit"),
    ("sklearn.cross_decomposition._pls", "PLSCanonical.fit"),
    ("sklearn.preprocessing.label", "MultiLabelBinarizer.fit"),
    ("sklearn.kernel_approximation", "PolynomialCountSketch.fit"),
    ("sklearn.linear_model._stochastic_gradient", "SGDOneClassSVM.fit"),
    ("sklearn.neighbors._unsupervised", "NearestNeighbors.fit"),
    ("sklearn.decomposition._dict_learning", "MiniBatchDictionaryLearning.fit"),
    ("sklearn.tree._classes", "ExtraTreeClassifier.fit"),
    ("sklearn.cluster._mean_shift", "MeanShift.fit"),
    ("sklearn.linear_model.ransac", "RANSACRegressor.fit"),
    ("sklearn.random_projection", "GaussianRandomProjection.fit"),
    ("sklearn.decomposition.truncated_svd", "TruncatedSVD.fit"),
    ("sklearn.cluster._birch", "Birch.fit"),
    ("sklearn.mixture._gaussian_mixture", "GaussianMixture.fit"),
    ("sklearn.linear_model._stochastic_gradient", "SGDClassifier.fit"),
    ("sklearn.preprocessing._encoders", "OrdinalEncoder.fit"),
    ("sklearn.neural_network.multilayer_perceptron", "MLPClassifier.fit"),
    ("sklearn.preprocessing._target_encoder", "TargetEncoder.fit"),
    ("sklearn.cross_decomposition.pls_", "PLSRegression.fit"),
    ("sklearn.linear_model._ridge", "_RidgeGCV.fit"),
    (
        "sklearn.model_selection._classification_threshold",
        "TunedThresholdClassifierCV.fit",
    ),
    ("sklearn.cluster._bicluster", "SpectralBiclustering.fit"),
    ("sklearn.decomposition.online_lda", "LatentDirichletAllocation.fit"),
    ("sklearn.ensemble.voting_classifier", "VotingClassifier.fit"),
    ("sklearn.linear_model._coordinate_descent", "LassoCV.fit"),
    (
        "sklearn.model_selection._classification_threshold",
        "FixedThresholdClassifier.fit",
    ),
    ("sklearn.feature_selection.variance_threshold", "VarianceThreshold.fit"),
    ("sklearn.linear_model._glm.glm", "TweedieRegressor.fit"),
    ("sklearn.covariance.graph_lasso_", "GraphLasso.fit"),
    ("sklearn.linear_model.least_angle", "LassoLarsIC.fit"),
    ("sklearn.linear_model._coordinate_descent", "MultiTaskElasticNetCV.fit"),
    ("sklearn.tree.tree", "ExtraTreeRegressor.fit"),
    ("sklearn.ensemble.forest", "ExtraTreesClassifier.fit"),
    ("sklearn.naive_bayes", "MultinomialNB.fit"),
    ("sklearn.multiclass", "OneVsOneClassifier.fit"),
    (
        "sklearn.ensemble._hist_gradient_boosting.gradient_boosting",
        "HistGradientBoostingRegressor.fit",
    ),
    ("sklearn.ensemble._forest", "ExtraTreesRegressor.fit"),
    ("sklearn.covariance.shrunk_covariance_", "OAS.fit"),
    ("sklearn.calibration", "_SigmoidCalibration.fit"),
    ("sklearn.preprocessing._discretization", "KBinsDiscretizer.fit"),
    ("sklearn.mixture.gaussian_mixture", "GaussianMixture.fit"),
    ("sklearn.decomposition.sparse_pca", "MiniBatchSparsePCA.fit"),
    ("sklearn.ensemble.weight_boosting", "AdaBoostClassifier.fit"),
    ("sklearn.svm._classes", "OneClassSVM.fit"),
    ("sklearn.dummy", "DummyClassifier.fit"),
    ("sklearn.feature_selection._univariate_selection", "SelectFdr.fit"),
    ("sklearn.linear_model.perceptron", "Perceptron.fit"),
    ("sklearn.svm.classes", "OneClassSVM.fit"),
    ("sklearn.gaussian_process._gpc", "GaussianProcessClassifier.fit"),
    ("sklearn.compose._target", "TransformedTargetRegressor.fit"),
    ("sklearn.cluster._bisect_k_means", "BisectingKMeans.fit"),
    ("sklearn.feature_selection._univariate_selection", "SelectFwe.fit"),
    ("sklearn.linear_model._coordinate_descent", "MultiTaskLassoCV.fit"),
    ("sklearn.preprocessing._label", "LabelBinarizer.fit"),
    ("sklearn.linear_model.omp", "OrthogonalMatchingPursuit.fit"),
    ("sklearn.decomposition._dict_learning", "SparseCoder.fit"),
    ("sklearn.neighbors._regression", "RadiusNeighborsRegressor.fit"),
    ("sklearn.linear_model._ridge", "_BaseRidgeCV.fit"),
    ("sklearn.linear_model.theil_sen", "TheilSenRegressor.fit"),
    ("sklearn.manifold._spectral_embedding", "SpectralEmbedding.fit"),
    ("sklearn.decomposition.fastica_", "FastICA.fit"),
    ("sklearn.linear_model.huber", "HuberRegressor.fit"),
    ("sklearn.decomposition.dict_learning", "SparseCoder.fit"),
    ("sklearn.neighbors.classification", "KNeighborsClassifier.fit"),
    ("sklearn.feature_extraction.dict_vectorizer", "DictVectorizer.fit"),
    ("sklearn.linear_model._omp", "OrthogonalMatchingPursuitCV.fit"),
    ("sklearn.naive_bayes", "BernoulliNB.fit"),
    ("sklearn.tree.tree", "ExtraTreeClassifier.fit"),
    ("sklearn.cluster._hdbscan.hdbscan", "HDBSCAN.fit"),
    ("sklearn.feature_selection._rfe", "RFECV.fit"),
    ("sklearn.cluster.affinity_propagation_", "AffinityPropagation.fit"),
    ("sklearn.feature_extraction.text", "HashingVectorizer.fit"),
    ("sklearn.gaussian_process.gpc", "_BinaryGaussianProcessClassifierLaplace.fit"),
    ("sklearn.preprocessing.imputation", "Imputer.fit"),
    ("sklearn.preprocessing.label", "LabelEncoder.fit"),
    ("sklearn.cluster._dbscan", "DBSCAN.fit"),
    ("sklearn.covariance.graph_lasso_", "GraphicalLassoCV.fit"),
    ("sklearn.feature_extraction.hashing", "FeatureHasher.fit"),
    ("sklearn.random_projection", "SparseRandomProjection.fit"),
    ("sklearn.svm._classes", "SVR.fit"),
    ("sklearn.linear_model._stochastic_gradient", "SGDRegressor.fit"),
    ("sklearn.covariance._empirical_covariance", "EmpiricalCovariance.fit"),
    ("sklearn.preprocessing.data", "KernelCenterer.fit"),
    ("sklearn.neighbors._classification", "RadiusNeighborsClassifier.fit"),
    ("sklearn.svm.classes", "NuSVR.fit"),
    ("sklearn.svm.classes", "SVR.fit"),
    ("sklearn.semi_supervised.label_propagation", "LabelSpreading.fit"),
    ("sklearn.ensemble.voting", "VotingRegressor.fit"),
    ("sklearn.preprocessing._polynomial", "PolynomialFeatures.fit"),
    ("sklearn.svm._classes", "LinearSVC.fit"),
    ("sklearn.feature_selection.rfe", "RFE.fit"),
    ("sklearn.cluster._kmeans", "MiniBatchKMeans.fit"),
    ("sklearn.feature_extraction.image", "PatchExtractor.fit"),
    ("sklearn.preprocessing.label", "LabelBinarizer.fit"),
    ("sklearn.feature_selection.univariate_selection", "SelectFwe.fit"),
    ("sklearn.covariance.elliptic_envelope", "EllipticEnvelope.fit"),
    ("sklearn.svm.classes", "LinearSVC.fit"),
    ("sklearn.neighbors.regression", "RadiusNeighborsRegressor.fit"),
    ("sklearn.linear_model._least_angle", "LassoLarsIC.fit"),
    ("sklearn.preprocessing._polynomial", "SplineTransformer.fit"),
    ("sklearn.preprocessing._data", "Binarizer.fit"),
    ("sklearn.feature_selection._univariate_selection", "SelectFpr.fit"),
    ("sklearn.feature_selection._univariate_selection", "SelectKBest.fit"),
    ("sklearn.covariance.empirical_covariance_", "EmpiricalCovariance.fit"),
    ("sklearn.preprocessing.data", "StandardScaler.fit"),
    ("sklearn.ensemble._hist_gradient_boosting.binning", "_BinMapper.fit"),
    ("sklearn.ensemble._bagging", "BaggingClassifier.fit"),
    ("sklearn.ensemble.bagging", "BaggingClassifier.fit"),
    ("sklearn.decomposition.kernel_pca", "KernelPCA.fit"),
    ("sklearn.manifold.locally_linear", "LocallyLinearEmbedding.fit"),
    ("sklearn.decomposition._sparse_pca", "SparsePCA.fit"),
    ("sklearn.cluster._bicluster", "SpectralCoclustering.fit"),
    ("sklearn.covariance._shrunk_covariance", "LedoitWolf.fit"),
    ("sklearn.decomposition.factor_analysis", "FactorAnalysis.fit"),
    ("sklearn.gaussian_process.gpr", "GaussianProcessRegressor.fit"),
    ("sklearn.impute", "MissingIndicator.fit"),
    ("sklearn.covariance.robust_covariance", "MinCovDet.fit"),
    ("sklearn.linear_model._ridge", "RidgeClassifierCV.fit"),
    ("sklearn.neighbors._graph", "KNeighborsTransformer.fit"),
    ("sklearn.impute._base", "MissingIndicator.fit"),
    ("sklearn.linear_model._coordinate_descent", "MultiTaskElasticNet.fit"),
    ("sklearn.decomposition._nmf", "NMF.fit"),
    ("sklearn.cluster.spectral", "SpectralClustering.fit"),
    ("sklearn.manifold.spectral_embedding_", "SpectralEmbedding.fit"),
    ("sklearn.covariance._graph_lasso", "BaseGraphicalLasso.fit"),
    ("sklearn.cross_decomposition._cca", "CCA.fit"),
    ("sklearn.ensemble.weight_boosting", "AdaBoostRegressor.fit"),
    ("sklearn.svm._classes", "LinearSVR.fit"),
    ("sklearn.cluster.bicluster", "SpectralBiclustering.fit"),
    ("sklearn.feature_selection._sequential", "SequentialFeatureSelector.fit"),
    ("sklearn.preprocessing.data", "MaxAbsScaler.fit"),
    ("sklearn.cluster.k_means_", "MiniBatchKMeans.fit"),
    ("sklearn.svm.classes", "LinearSVR.fit"),
    ("sklearn.preprocessing.data", "PolynomialFeatures.fit"),
    ("sklearn.kernel_approximation", "Nystroem.fit"),
    ("sklearn.svm._classes", "NuSVC.fit"),
    ("sklearn.cluster._spectral", "SpectralClustering.fit"),
    ("sklearn.ensemble._forest", "RandomForestRegressor.fit"),
    ("sklearn.cross_decomposition._pls", "PLSSVD.fit"),
    ("sklearn.linear_model._coordinate_descent", "Lasso.fit"),
    ("sklearn.decomposition.dict_learning", "MiniBatchDictionaryLearning.fit"),
    ("sklearn.feature_selection.univariate_selection", "SelectFpr.fit"),
    ("sklearn.linear_model.passive_aggressive", "PassiveAggressiveClassifier.fit"),
    ("sklearn.linear_model._least_angle", "LassoLarsCV.fit"),
    ("sklearn.svm.classes", "NuSVC.fit"),
    ("sklearn.linear_model._theil_sen", "TheilSenRegressor.fit"),
    ("sklearn.linear_model.least_angle", "Lars.fit"),
    ("sklearn.ensemble._weight_boosting", "AdaBoostClassifier.fit"),
    ("sklearn.linear_model._perceptron", "Perceptron.fit"),
    ("sklearn.linear_model._passive_aggressive", "PassiveAggressiveClassifier.fit"),
    ("sklearn.neighbors.nearest_centroid", "NearestCentroid.fit"),
    ("sklearn.multioutput", "MultiOutputClassifier.fit"),
    ("sklearn.linear_model.coordinate_descent", "MultiTaskLasso.fit"),
    ("sklearn.linear_model._huber", "HuberRegressor.fit"),
    ("sklearn.multioutput", "ClassifierChain.fit"),
    ("sklearn.ensemble._forest", "RandomTreesEmbedding.fit"),
    ("sklearn.linear_model._coordinate_descent", "ElasticNet.fit"),
    ("sklearn.linear_model.bayes", "BayesianRidge.fit"),
    ("sklearn.naive_bayes", "CategoricalNB.fit"),
    ("sklearn.decomposition._factor_analysis", "FactorAnalysis.fit"),
    ("sklearn.covariance.shrunk_covariance_", "ShrunkCovariance.fit"),
    ("sklearn.neural_network.rbm", "BernoulliRBM.fit"),
    ("sklearn.ensemble.bagging", "BaggingRegressor.fit"),
    ("sklearn.cluster.hierarchical", "FeatureAgglomeration.fit"),
    ("sklearn.linear_model.ridge", "_RidgeGCV.fit"),
    ("sklearn.multiclass", "_ConstantPredictor.fit"),
    ("sklearn.decomposition._fastica", "FastICA.fit"),
    ("sklearn.decomposition._sparse_pca", "MiniBatchSparsePCA.fit"),
    ("sklearn.feature_extraction.text", "TfidfTransformer.fit"),
    ("sklearn.gaussian_process.gpc", "GaussianProcessClassifier.fit"),
    ("sklearn.linear_model.base", "LinearRegression.fit"),
    ("sklearn.feature_extraction.text", "TfidfVectorizer.fit"),
    ("sklearn.manifold.t_sne", "TSNE.fit"),
    ("sklearn.linear_model.randomized_l1", "RandomizedLasso.fit"),
    ("sklearn.preprocessing._data", "PolynomialFeatures.fit"),
    ("sklearn.linear_model.omp", "OrthogonalMatchingPursuitCV.fit"),
    ("sklearn.kernel_ridge", "KernelRidge.fit"),
    ("sklearn.tree._classes", "DecisionTreeClassifier.fit"),
    ("sklearn.neighbors._classification", "KNeighborsClassifier.fit"),
    ("sklearn.neural_network._multilayer_perceptron", "MLPRegressor.fit"),
    ("sklearn.cluster._kmeans", "KMeans.fit"),
    ("sklearn.neighbors._graph", "RadiusNeighborsTransformer.fit"),
    ("sklearn.linear_model.bayes", "ARDRegression.fit"),
    ("sklearn.cluster.optics_", "OPTICS.fit"),
    ("sklearn.decomposition._sparse_pca", "_BaseSparsePCA.fit"),
    ("sklearn.cross_decomposition._pls", "CCA.fit"),
    ("sklearn.impute._knn", "KNNImputer.fit"),
    ("sklearn.preprocessing.data", "Normalizer.fit"),
    ("sklearn.decomposition._incremental_pca", "IncrementalPCA.fit"),
    ("sklearn.covariance.graph_lasso_", "GraphLassoCV.fit"),
    ("sklearn.manifold._mds", "MDS.fit"),
    ("sklearn.feature_extraction._hash", "FeatureHasher.fit"),
    ("sklearn.feature_selection.rfe", "RFECV.fit"),
    ("sklearn.cluster._affinity_propagation", "AffinityPropagation.fit"),
    ("sklearn.ensemble._iforest", "IsolationForest.fit"),
    (
        "sklearn.model_selection._classification_threshold",
        "BaseThresholdClassifier.fit",
    ),
    ("sklearn.cluster._agglomerative", "FeatureAgglomeration.fit"),
    ("sklearn.neural_network._multilayer_perceptron", "MLPClassifier.fit"),
    ("sklearn.neighbors.kde", "KernelDensity.fit"),
    ("sklearn.feature_selection._univariate_selection", "GenericUnivariateSelect.fit"),
    ("sklearn.ensemble._voting", "VotingRegressor.fit"),
    ("sklearn.calibration", "CalibratedClassifierCV.fit"),
    ("sklearn.multioutput", "RegressorChain.fit"),
    ("sklearn.preprocessing.data", "MinMaxScaler.fit"),
    ("sklearn.semi_supervised._label_propagation", "LabelPropagation.fit"),
    ("sklearn.cross_decomposition._pls", "PLSRegression.fit"),
    ("sklearn.cross_decomposition.pls_", "PLSCanonical.fit"),
    ("sklearn.neighbors.unsupervised", "NearestNeighbors.fit"),
    ("sklearn.linear_model.ridge", "RidgeClassifier.fit"),
    ("sklearn.covariance._elliptic_envelope", "EllipticEnvelope.fit"),
    ("sklearn.neighbors.lof", "LocalOutlierFactor.fit"),
    ("sklearn.linear_model._bayes", "BayesianRidge.fit"),
    ("sklearn.linear_model._ridge", "RidgeCV.fit"),
    ("sklearn.feature_selection._univariate_selection", "SelectPercentile.fit"),
    ("sklearn.mixture._bayesian_mixture", "BayesianGaussianMixture.fit"),
    ("sklearn.ensemble._voting", "VotingClassifier.fit"),
    ("sklearn.covariance._robust_covariance", "MinCovDet.fit"),
    ("sklearn.preprocessing._encoders", "OneHotEncoder.fit"),
    ("sklearn.ensemble._forest", "ExtraTreesClassifier.fit"),
    ("sklearn.cluster.bicluster", "SpectralCoclustering.fit"),
    ("sklearn.feature_selection._variance_threshold", "VarianceThreshold.fit"),
    ("sklearn.feature_selection.univariate_selection", "SelectFdr.fit"),
    ("sklearn.linear_model._coordinate_descent", "MultiTaskLasso.fit"),
    ("sklearn.linear_model._glm.glm", "PoissonRegressor.fit"),
    ("sklearn.ensemble._weight_boosting", "AdaBoostRegressor.fit"),
    ("sklearn.neighbors.approximate", "LSHForest.fit"),
    ("sklearn.linear_model._bayes", "ARDRegression.fit"),
    ("sklearn.svm._classes", "SVC.fit"),
    ("sklearn.preprocessing.data", "PowerTransformer.fit"),
    ("sklearn.manifold.isomap", "Isomap.fit"),
    ("sklearn.mixture.bayesian_mixture", "BayesianGaussianMixture.fit"),
    ("sklearn.svm.classes", "SVC.fit"),
    # Also listed in PIPELINE_FIT_MODULES; patch() registers both the
    # estimator-level and the per-step pipeline hooks for Pipeline.fit.
    ("sklearn.pipeline", "Pipeline.fit"),
    ("sklearn.feature_selection.univariate_selection", "GenericUnivariateSelect.fit"),
    ("sklearn.covariance._shrunk_covariance", "OAS.fit"),
    ("sklearn.preprocessing._data", "MinMaxScaler.fit"),
    ("sklearn.preprocessing._label", "MultiLabelBinarizer.fit"),
    ("sklearn.tree.tree", "DecisionTreeClassifier.fit"),
    ("sklearn.cluster.mean_shift_", "MeanShift.fit"),
    ("sklearn.manifold._isomap", "Isomap.fit"),
    (
        "sklearn.ensemble._hist_gradient_boosting.gradient_boosting",
        "HistGradientBoostingClassifier.fit",
    ),
    ("sklearn.linear_model.coordinate_descent", "MultiTaskElasticNetCV.fit"),
    ("sklearn.linear_model._ridge", "Ridge.fit"),
    ("sklearn.preprocessing._data", "KernelCenterer.fit"),
    ("sklearn.semi_supervised.label_propagation", "LabelPropagation.fit"),
    ("sklearn.feature_extraction.text", "CountVectorizer.fit"),
    ("sklearn.discriminant_analysis", "LinearDiscriminantAnalysis.fit"),
    ("sklearn.ensemble.forest", "RandomForestClassifier.fit"),
    ("sklearn.cluster.dbscan_", "DBSCAN.fit"),
    ("sklearn.ensemble.iforest", "IsolationForest.fit"),
    ("sklearn.ensemble.gradient_boosting", "GradientBoostingRegressor.fit"),
    ("sklearn.cluster._optics", "OPTICS.fit"),
    ("sklearn.preprocessing._data", "PowerTransformer.fit"),
    ("sklearn.linear_model.coordinate_descent", "ElasticNetCV.fit"),
    ("sklearn.linear_model._base", "LinearRegression.fit"),
    ("sklearn.neural_network.multilayer_perceptron", "MLPRegressor.fit"),
    ("sklearn.impute._base", "SimpleImputer.fit"),
    ("sklearn.preprocessing._data", "StandardScaler.fit"),
    ("sklearn.linear_model.stochastic_gradient", "SGDClassifier.fit"),
    ("sklearn.decomposition.pca", "PCA.fit"),
    ("sklearn.feature_selection.univariate_selection", "SelectKBest.fit"),
    ("sklearn.linear_model.ridge", "RidgeClassifierCV.fit"),
    ("sklearn.neighbors.classification", "RadiusNeighborsClassifier.fit"),
    ("sklearn.ensemble.gradient_boosting", "GradientBoostingClassifier.fit"),
    ("sklearn.feature_selection.from_model", "SelectFromModel.fit"),
    ("sklearn.multiclass", "OutputCodeClassifier.fit"),
    ("sklearn.covariance._graph_lasso", "GraphicalLassoCV.fit"),
    ("sklearn.ensemble.forest", "ExtraTreesRegressor.fit"),
    ("sklearn.naive_bayes", "ComplementNB.fit"),
    ("sklearn.linear_model.passive_aggressive", "PassiveAggressiveRegressor.fit"),
    ("sklearn.linear_model._glm.glm", "GeneralizedLinearRegressor.fit"),
    ("sklearn.ensemble._stacking", "StackingClassifier.fit"),
    ("sklearn.neighbors._nca", "NeighborhoodComponentsAnalysis.fit"),
    ("sklearn.kernel_approximation", "AdditiveChi2Sampler.fit"),
    ("sklearn.linear_model._glm.glm", "_GeneralizedLinearRegressor.fit"),
    ("sklearn.linear_model._passive_aggressive", "PassiveAggressiveRegressor.fit"),
    ("sklearn.decomposition._truncated_svd", "TruncatedSVD.fit"),
    ("sklearn.decomposition.nmf", "NMF.fit"),
    ("sklearn.naive_bayes", "GaussianNB.fit"),
    ("sklearn.cluster.birch", "Birch.fit"),
    ("sklearn.neighbors._lof", "LocalOutlierFactor.fit"),
    ("sklearn.preprocessing.data", "QuantileTransformer.fit"),
    ("sklearn.neural_network._rbm", "BernoulliRBM.fit"),
    ("sklearn.ensemble._bagging", "BaggingRegressor.fit"),
    ("sklearn.linear_model._quantile", "QuantileRegressor.fit"),
    ("sklearn.neighbors._regression", "KNeighborsRegressor.fit"),
    ("sklearn.preprocessing._data", "MaxAbsScaler.fit"),
    ("sklearn.linear_model._logistic", "LogisticRegression.fit"),
    ("sklearn.linear_model.least_angle", "LassoLars.fit"),
    ("sklearn.covariance.graph_lasso_", "GraphicalLasso.fit"),
    ("sklearn.linear_model._ransac", "RANSACRegressor.fit"),
    ("sklearn.ensemble._forest", "RandomForestClassifier.fit"),
    ("sklearn.model_selection._search", "RandomizedSearchCV.fit"),
    ("sklearn.linear_model.least_angle", "LarsCV.fit"),
    ("sklearn.linear_model._glm.glm", "GammaRegressor.fit"),
    ("sklearn.ensemble._stacking", "StackingRegressor.fit"),
    ("sklearn.kernel_approximation", "RBFSampler.fit"),
    ("sklearn.frozen._frozen", "FrozenEstimator.fit"),
]


# (module path, "Class.method") targets that receive the per-step pipeline
# hooks (pipeline_fit_logger_before / pipeline_fit_logger_after) in patch().
# Pipeline.fit is also in FIT_MODULES, so it gets both sets of hooks.
PIPELINE_FIT_MODULES = [
    ("sklearn.pipeline", "Pipeline.fit"),
]


def patch(module_finder):
    """Register the scikit-learn auto-logging hooks with *module_finder*.

    Calls ``check_module("sklearn")`` first. Every target in
    ``PIPELINE_FIT_MODULES`` then gets the per-step pipeline hooks, and every
    target in ``FIT_MODULES`` gets the estimator hooks. For each target a
    before-hook is registered, then an after-hook, in list order, with
    pipeline targets before estimator targets.
    """
    check_module("sklearn")

    hook_groups = (
        (PIPELINE_FIT_MODULES, pipeline_fit_logger_before, pipeline_fit_logger_after),
        (FIT_MODULES, fit_logger_before, fit_logger_after),
    )
    for targets, before_hook, after_hook in hook_groups:
        for module_name, method_path in targets:
            module_finder.register_before(module_name, method_path, before_hook)
            module_finder.register_after(module_name, method_path, after_hook)


# Runs at import time of this integration module; patch() also calls it.
# NOTE(review): check_module is defined in ..monkey_patching and its effect is
# not visible here -- presumably it verifies/records that sklearn is available
# for auto-logging; confirm against that module.
check_module("sklearn")
