/*
 * Decompiled with CFR 0.152.
 */
package pycaret.preprocess;

import category_encoders.BaseEncoder;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.dmg.pmml.Field;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import pycaret.preprocess.FixImbalancer;
import pycaret.preprocess.RareCategoryGrouping;
import sklearn.Initializer;
import sklearn.InitializerUtil;
import sklearn.Transformer;
import sklearn.impute.SimpleImputer;

/**
 * Converter for a {@code pycaret} {@code TransformerWrapper} object.
 *
 * <p>The wrapper applies its nested {@code transformer} to the subset of input columns listed
 * in its {@code _include} attribute. The input column order is given by
 * {@code _feature_names_in}.
 */
public class TransformerWrapper
extends Initializer {
    public TransformerWrapper(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input columns declared by the {@code _feature_names_in} attribute.
     */
    public int getNumberOfFeatures() {
        return this.getFeatureNames().size();
    }

    /**
     * Delegates feature validation to the superclass, except for an empty feature list.
     *
     * <p>An empty list is passed through unchecked; {@link #encodeFeatures} handles that case by
     * selecting the columns named in {@code _feature_names_in} through the encoder.
     */
    public void checkFeatures(List<? extends Feature> features) {
        if (!features.isEmpty()) {
            super.checkFeatures(features);
        }
    }

    /**
     * Encodes this wrapper starting from an empty feature list, so that the input columns are
     * selected from the encoder by name.
     */
    public List<Feature> initializeFeatures(SkLearnEncoder encoder) {
        return this.encodeFeatures(Collections.emptyList(), encoder);
    }

    /**
     * Applies the wrapped transformer to the {@code _include} columns.
     *
     * <ul>
     *   <li>If {@code _train_only} is set, {@code features} is returned unchanged.</li>
     *   <li>If {@code features} is empty, the {@code _feature_names_in} columns are first
     *       selected through {@code encoder}.</li>
     *   <li>A {@code FixImbalancer} transformer leaves the features unchanged.</li>
     *   <li>For {@code SimpleImputer}, {@code RareCategoryGrouping} and {@code BaseEncoder}
     *       transformers, each included column is replaced in place by the group of features
     *       derived from it; all other columns are kept.</li>
     *   <li>For any other transformer, the result is the transformed included features, followed
     *       by the label feature when the encoder has a label and it is found in
     *       {@code features}.</li>
     * </ul>
     *
     * @param features the incoming features, positionally aligned with
     *     {@code _feature_names_in}; may be empty
     * @param encoder the encoder
     * @return the encoded features
     * @throws IllegalArgumentException if an {@code _include} column is not listed in
     *     {@code _feature_names_in}
     * @throws ClassCastException if the encoder's label is not a {@link ScalarLabel}
     */
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        List<String> featureNames = this.getFeatureNames();
        List<String> include = this.getInclude();
        Boolean trainOnly = this.getTrainOnly();
        Transformer transformer = this.getTransformer();

        // Steps flagged as train-only are skipped; features pass through unchanged.
        if (Boolean.TRUE.equals(trainOnly)) {
            return features;
        }
        if (features.isEmpty()) {
            features = InitializerUtil.selectFeatures(featureNames, features, encoder);
        }

        // Resolved before the FixImbalancer check to keep the original evaluation order
        // (including its failure on unknown columns).
        List<Feature> includeFeatures = selectIncludeFeatures(featureNames, include, features, encoder);

        if (transformer instanceof FixImbalancer) {
            return features;
        }

        List<Feature> transformedIncludeFeatures = transformer.encode(includeFeatures, encoder);

        boolean replaceFeatures = (transformer instanceof SimpleImputer)
            || (transformer instanceof RareCategoryGrouping)
            || (transformer instanceof BaseEncoder);
        if (replaceFeatures) {
            return replaceIncludeFeatures(featureNames, include, features, includeFeatures, transformedIncludeFeatures);
        }

        List<Feature> result = new ArrayList<Feature>(transformedIncludeFeatures);

        Label label = encoder.getLabel();
        if (label != null) {
            ScalarLabel scalarLabel = (ScalarLabel)label;
            Feature labelFeature = FeatureUtil.findLabelFeature(features, scalarLabel);
            if (labelFeature != null) {
                result.add(labelFeature);
            }
        }
        return result;
    }

    public List<String> getFeatureNames() {
        return this.getStringList("_feature_names_in");
    }

    public List<String> getExclude() {
        return this.getStringList("_exclude");
    }

    public List<String> getInclude() {
        return this.getStringList("_include");
    }

    public String getTargetName() {
        return this.getString("target_name_");
    }

    /**
     * Returns the {@code _train_only} flag, defaulting to {@code false} when absent.
     */
    public Boolean getTrainOnly() {
        return this.getOptionalBoolean("_train_only", Boolean.FALSE);
    }

    public Transformer getTransformer() {
        return (Transformer)this.get("transformer", Transformer.class);
    }

    /**
     * Resolves the feature for each {@code _include} column, in {@code include} order.
     *
     * <p>When {@code features} is non-empty, it is looked up positionally via
     * {@code featureNames}; otherwise each column is selected through the encoder.
     */
    private static List<Feature> selectIncludeFeatures(List<String> featureNames, List<String> include, List<Feature> features, SkLearnEncoder encoder) {
        List<Feature> result = new ArrayList<Feature>(include.size());
        for (String includeColumn : include) {
            Feature includeFeature;
            if (!features.isEmpty()) {
                includeFeature = features.get(indexOfColumn(featureNames, includeColumn));
            } else {
                includeFeature = InitializerUtil.selectFeature(includeColumn, features, encoder);
            }
            result.add(includeFeature);
        }
        return result;
    }

    /**
     * Replaces each included column's feature with the group of transformed features derived
     * from it, keeping all other columns in place, and returns the flattened list.
     */
    private static List<Feature> replaceIncludeFeatures(List<String> featureNames, List<String> include, List<Feature> features, List<Feature> includeFeatures, List<Feature> transformedIncludeFeatures) {
        // Consecutive transformed features sharing a field form the output of one input column.
        List<List<Feature>> transformedIncludeFeatureGroups = groupByField(transformedIncludeFeatures);
        ClassDictUtil.checkSize(includeFeatures, transformedIncludeFeatureGroups);

        // One slot per input column; untouched columns keep a single-element slot.
        List<List<Feature>> slots = new ArrayList<List<Feature>>(features.size());
        for (Feature feature : features) {
            slots.add(Collections.singletonList(feature));
        }
        for (int i = 0; i < include.size(); ++i) {
            int index = indexOfColumn(featureNames, include.get(i));
            slots.set(index, transformedIncludeFeatureGroups.get(i));
        }

        return slots.stream()
            .flatMap(List::stream)
            .collect(Collectors.toList());
    }

    /**
     * Returns the position of {@code column} in {@code featureNames}.
     *
     * @throws IllegalArgumentException if the column is not listed
     */
    private static int indexOfColumn(List<String> featureNames, String column) {
        int index = featureNames.indexOf(column);
        if (index < 0) {
            throw new IllegalArgumentException("Column " + column + " is not among the input feature names " + featureNames);
        }
        return index;
    }

    /**
     * Splits {@code features} into runs of consecutive features that share the same field.
     *
     * <p>Order is preserved. A field that reappears after a different one starts a new group.
     */
    private static List<List<Feature>> groupByField(List<Feature> features) {
        List<List<Feature>> result = new ArrayList<List<Feature>>();
        Object prevField = null;
        List<Feature> fieldFeatures = null;
        for (Feature feature : features) {
            Object field = feature.getField();
            // The null check guards the first iteration, where prevField has no meaningful value.
            if (fieldFeatures == null || !Objects.equals(field, prevField)) {
                fieldFeatures = new ArrayList<Feature>();
                result.add(fieldFeatures);
            }
            fieldFeatures.add(feature);
            prevField = field;
        }
        return result;
    }
}

