/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.learning.crossvalidation.splitting;

import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.learning.LearningSample;
import cz.cvut.fel.ida.learning.crossvalidation.splitting.Splitter;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.generic.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * {@link Splitter} that tries to preserve the distribution of target values ("classes") in every
 * part it produces.
 *
 * <p>Samples are grouped by their {@code target} value (see {@link #getClasses(List)}), and each
 * group is spread over the folds / subsets separately.
 *
 * @param <T> the type of learning samples being split
 */
public class StratifiedSplitter<T extends LearningSample>
implements Splitter<T> {
    private static final Logger LOG = Logger.getLogger(StratifiedSplitter.class.getName());
    private final Settings settings;

    /**
     * @param settings supplies {@code shuffleBeforeFoldSplit} and the {@code random} generator used
     *                 by {@link #partition(Stream, int)}
     */
    public StratifiedSplitter(Settings settings) {
        this.settings = settings;
    }

    /**
     * Materializes the stream and optionally shuffles it (per {@code settings.shuffleBeforeFoldSplit}).
     * It then splits the samples into {@code foldCount} stratified folds via
     * {@link #partition(List, int)}. The input stream is consumed and closed.
     *
     * @param samples   samples to split; closed by this method
     * @param foldCount number of folds, must be at least 1
     * @return one stream per fold
     */
    @Override
    public List<Stream<T>> partition(Stream<T> samples, int foldCount) {
        List<T> collected;
        try (Stream<T> source = samples) {
            // Collectors.toList() gives no mutability guarantee, but shuffle modifies in place.
            collected = source.collect(Collectors.toCollection(ArrayList::new));
        }
        if (this.settings.shuffleBeforeFoldSplit) {
            Collections.shuffle(collected, this.settings.random);
        }
        List<List<T>> folds = this.partition(collected, foldCount);
        return folds.stream().map(Collection::stream).collect(Collectors.toList());
    }

    /**
     * Splits samples into {@code foldCount} folds so that every class is spread evenly over the folds.
     *
     * @param samples   samples to split
     * @param foldCount number of folds to create
     * @return {@code foldCount} mutable lists
     * @throws IllegalArgumentException if {@code foldCount < 1}
     */
    @Override
    public List<List<T>> partition(List<T> samples, int foldCount) {
        if (foldCount < 1) {
            throw new IllegalArgumentException("foldCount must be at least 1, got " + foldCount);
        }
        List<List<T>> folds = new ArrayList<>(foldCount);
        for (int i = 0; i < foldCount; ++i) {
            folds.add(new ArrayList<>());
        }
        this.distributeUniformly(this.getClasses(samples).values(), folds);
        return folds;
    }

    /**
     * Returns a class-stratified subset of roughly {@code appCount} samples (never more).
     *
     * <p>Non-empty input always yields at least one sample. Because each class is rounded
     * separately, the result may contain fewer than {@code appCount} samples.
     *
     * @param samples  samples to pick from
     * @param appCount requested subset size; capped at {@code samples.size()}, negatives treated as 0
     * @return a new mutable list; empty only when {@code samples} is empty
     */
    public List<T> getStratifiedSubset(List<T> samples, int appCount) {
        if (samples.isEmpty()) {
            LOG.warning("Cannot take a stratified subset of an empty sample list.");
            return new ArrayList<>();
        }
        int count = Math.max(0, appCount);
        if (count > samples.size()) {
            LOG.warning("Limiting samples to a greater number than there actually is!");
            count = samples.size();
        }
        double fraction = (double) count / (double) samples.size();
        List<T> subset = this.representativeSubset(this.getClasses(samples), fraction);
        if (subset.size() > count) {
            // Copy rather than expose a subList view of the intermediate list.
            subset = new ArrayList<>(subset.subList(0, count));
        }
        if (subset.isEmpty()) {
            subset.add(samples.get(0));
        }
        return subset;
    }

    /**
     * Appends the samples of every class to {@code folds} round-robin.
     *
     * <p>The fold index continues from where the previous class stopped instead of restarting at
     * fold 0 for each class. That way the per-class remainders do not all pile up in the first
     * folds, and (for initially empty folds) fold sizes differ by at most one.
     *
     * @param classes samples grouped by class
     * @param folds   fold lists to append to (mutated)
     * @throws IllegalArgumentException if {@code folds} is empty
     */
    public void distributeUniformly(Collection<List<T>> classes, List<List<T>> folds) {
        if (folds.isEmpty()) {
            throw new IllegalArgumentException("Cannot distribute samples into zero folds");
        }
        int foldCount = folds.size();
        int next = 0;
        for (List<T> classSamples : classes) {
            for (T sample : classSamples) {
                folds.get(next).add(sample);
                next = (next + 1) % foldCount;
            }
        }
    }

    /**
     * Takes the leading {@code round(size * percentage)} samples of each class (clamped to
     * {@code [0, size]}) and concatenates them in class order.
     *
     * @param classes    samples grouped by class
     * @param percentage fraction of each class to take, normally in {@code [0, 1]}
     * @return a new mutable list
     */
    public List<T> representativeSubset(Map<Value, List<T>> classes, double percentage) {
        // Supplier overload: the message is only built when INFO is actually logged.
        LOG.info(() -> "Calculated stratified subset from class distribution: " + String.valueOf(
                classes.entrySet().stream()
                        .map(entry -> String.valueOf(entry.getKey()) + ":" + entry.getValue().size())
                        .collect(Collectors.toList())));
        List<T> subset = new ArrayList<>();
        for (List<T> classSamples : classes.values()) {
            subset.addAll(classSamples.subList(0, takeCount(classSamples.size(), percentage)));
        }
        return subset;
    }

    /** Number of leading samples to take from a class of {@code size} samples, clamped to [0, size]. */
    private static int takeCount(int size, double percentage) {
        long rounded = Math.round((double) size * percentage); // NaN rounds to 0
        return (int) Math.max(0L, Math.min((long) size, rounded));
    }

    /**
     * Splits samples into a stratified "train" part (about {@code percentage} of each class) and a
     * "test" part containing every remaining sample, in the original order.
     *
     * <p>Samples are matched via {@code equals}/{@code hashCode} (as the former
     * {@code removeAll} call matched via {@code equals}). This assumes {@code hashCode} is
     * consistent with {@code equals} for the sample type.
     *
     * @param samples    samples to split
     * @param percentage fraction intended for the first part
     * @return pair of (train, test) lists
     */
    @Override
    public Pair<List<T>, List<T>> partition(List<T> samples, double percentage) {
        int split = (int) Math.round(percentage * (double) samples.size());
        if (percentage != 1.0 && (split == 1 || split == 0 || split == samples.size())) {
            LOG.warning("Problem with samples partitioning, there are too few samples to be splitted nicely: " + split + " out of " + samples.size() + " (split percentage = " + percentage + ")");
        }
        List<T> train = this.representativeSubset(this.getClasses(samples), percentage);
        // Hash lookup keeps this O(n); ArrayList.removeAll(train) was O(n * |train|).
        Set<T> trainSet = new HashSet<>(train);
        List<T> test = new ArrayList<>();
        for (T sample : samples) {
            if (!trainSet.contains(sample)) {
                test.add(sample);
            }
        }
        return new Pair<List<T>, List<T>>(train, test);
    }

    /**
     * Groups samples by their {@code target} value.
     *
     * @param samples samples to group
     * @return class value to samples, ordered by first occurrence; each list keeps input order
     */
    public Map<Value, List<T>> getClasses(List<T> samples) {
        Map<Value, List<T>> classes = new LinkedHashMap<>();
        for (T sample : samples) {
            classes.computeIfAbsent(sample.target, k -> new ArrayList<>()).add(sample);
        }
        return classes;
    }
}

