/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers;

import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.learning.results.Result;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.weights.WeightUpdater;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralModel;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralSample;
import cz.cvut.fel.ida.neural.networks.computation.training.optimizers.Optimizer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.debugging.NeuralDebugging;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.ListTrainer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.SequentialTrainer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.StreamTrainer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.Trainer;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.generic.Utilities;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * Trainer that processes samples in minibatches.
 *
 * <p>Every sample of a minibatch is evaluated and backpropagated in parallel by its own
 * {@link SequentialTrainer} "slot" (slot {@code i} handles the i-th sample of the batch). The
 * per-slot weight updates are then summed, and a single gradient step is performed for the whole
 * minibatch. Training is driven through {@link MinibatchListTrainer} or {@link MinibatchStreamTrainer}.
 */
public class MiniBatchTrainer
extends Trainer {
    private static final Logger LOG = Logger.getLogger(MiniBatchTrainer.class.getName());
    /** Maximal number of samples per minibatch. */
    int minibatchSize;
    /** Per-slot trainers; the list only ever grows, so it may hold more slots than {@link #minibatchSize}. */
    List<SequentialTrainer> trainers;
    NeuralModel neuralModel;

    private MiniBatchTrainer() {
    }

    /**
     * Creates a minibatch trainer with one {@link SequentialTrainer} slot per minibatch position.
     *
     * @param settings      global settings (dropout, debugging flags, ...)
     * @param optimizer     optimizer performing the gradient step after each minibatch
     * @param neuralModel   model whose weights are trained
     * @param minibatchSize maximal number of samples per minibatch
     */
    public MiniBatchTrainer(Settings settings, Optimizer optimizer, NeuralModel neuralModel, int minibatchSize) {
        super(settings, optimizer);
        this.minibatchSize = minibatchSize;
        this.neuralModel = neuralModel;
        this.trainers = new ArrayList<SequentialTrainer>(minibatchSize);
        for (int i = 0; i < minibatchSize; ++i) {
            this.trainers.add(new SequentialTrainer(settings, optimizer, neuralModel, i));
        }
    }

    /**
     * Changes the minibatch size, creating additional trainer slots if needed.
     * Existing slots are kept when the size shrinks.
     *
     * @param minibatchSize new maximal number of samples per minibatch
     */
    public void setMinibatchSize(int minibatchSize) {
        this.minibatchSize = minibatchSize;
        this.ensureTrainers(minibatchSize);
    }

    /** Grows the pool of per-slot trainers so that at least {@code count} slots exist; never shrinks it. */
    private void ensureTrainers(int count) {
        for (int i = this.trainers.size(); i < count; ++i) {
            this.trainers.add(new SequentialTrainer(this.settings, this.optimizer, this.neuralModel, i));
        }
    }

    /**
     * Evaluates and backpropagates all samples of one minibatch in parallel, sums the resulting
     * weight updates and performs a single gradient step with them.
     *
     * @param neuralModel model providing {@code maxWeightIndex} for sizing the update array
     * @param sampleList  the samples of this minibatch
     * @return the evaluation results, in the same order as {@code sampleList}
     */
    private List<Result> minibatchParallelLearn(NeuralModel neuralModel, List<NeuralSample> sampleList) {
        int size = sampleList.size();
        if (size > this.minibatchSize) {
            LOG.severe("Minibatch size mismatch");
        }
        // Every sample needs its own slot; without this, trainers.get(i) throws for oversized batches.
        this.ensureTrainers(size);
        List<Result> results = IntStream.range(0, size).parallel()
                .mapToObj(i -> this.evaluateAndBackprop(this.trainers.get(i), sampleList.get(i)))
                .collect(Collectors.toList());

        HashSet<Weight> updatedWeights = new HashSet<Weight>();
        Value[] weightUpdates = new Value[neuralModel.maxWeightIndex + 1];
        for (int slot = 0; slot < size; ++slot) {
            WeightUpdater weightUpdater = this.trainers.get(slot).backpropagation.weightUpdater;
            updatedWeights.addAll(weightUpdater.updatedWeightsOnly);
            accumulateUpdates(weightUpdates, weightUpdater.weightUpdates);
        }
        this.optimizer.performGradientStep(updatedWeights, weightUpdates, this.iterationNumber);
        return results;
    }

    /**
     * Adds one slot's {@code updates} element-wise into {@code sum}.
     * The first contribution for an index is cloned, so the subsequent in-place
     * {@code incrementBy} never mutates a slot trainer's own update values.
     */
    private static void accumulateUpdates(Value[] sum, Value[] updates) {
        int length = Math.min(sum.length, updates.length);
        for (int j = 0; j < length; ++j) {
            Value update = updates[j];
            if (update == null) {
                continue;
            }
            if (sum[j] == null) {
                sum[j] = update.clone();
            } else {
                sum[j].incrementBy(update);
            }
        }
    }

    /**
     * Evaluates (without backpropagation or weight updates) all samples of one minibatch in parallel.
     *
     * @param minibatch the samples to evaluate
     * @return the evaluation results, in the same order as {@code minibatch}
     */
    private List<Result> minibatchParallelEvaluate(List<NeuralSample> minibatch) {
        int size = minibatch.size();
        if (size > this.minibatchSize) {
            LOG.severe("Minibatch size mismatch");
        }
        this.ensureTrainers(size);
        return IntStream.range(0, size).parallel().mapToObj(i -> {
            SequentialTrainer trainer = this.trainers.get(i);
            NeuralSample sample = minibatch.get(i);
            // Evaluation must not train: only invalidate stale values and run the forward pass.
            trainer.invalidateSample(trainer.invalidation, sample);
            return trainer.evaluateSample(trainer.evaluation, sample);
        }).collect(Collectors.toList());
    }

    /**
     * Runs one sample through the given slot trainer: optional dropout, invalidation, evaluation and
     * backpropagation (the weight updates stay in the slot's {@code WeightUpdater}).
     *
     * @param trainer      the slot trainer owning the visitors used for this sample
     * @param neuralSample the sample to process
     * @return the evaluation result of the sample
     */
    private Result evaluateAndBackprop(SequentialTrainer trainer, NeuralSample neuralSample) {
        if (this.settings.dropoutMode == Settings.DropoutMode.DROPOUT && this.settings.dropoutRate > 0.0) {
            trainer.dropoutSample(trainer.dropout, neuralSample);
        }
        trainer.invalidateSample(trainer.invalidation, neuralSample);
        Result result = trainer.evaluateSample(trainer.evaluation, neuralSample);
        trainer.backpropSample(trainer.backpropagation, result, neuralSample);
        if (this.settings.debugSampleTraining) {
            // setupDebugger() only configures this minibatch trainer, not the slot trainers, so fall back to it.
            // NOTE(review): this may run concurrently from several threads - confirm the debugger tolerates that.
            if (trainer.neuralDebugger != null) {
                trainer.neuralDebugger.debug(neuralSample);
            } else if (this.neuralDebugger != null) {
                this.neuralDebugger.debug(neuralSample);
            }
        }
        return result;
    }

    /** Iterates over consecutive sublists (views) of at most {@code minibatchSize} samples. */
    public class MiniBatchIterator
    implements Iterator<List<NeuralSample>> {
        List<NeuralSample> sampleList;
        /** Start index of the next minibatch. */
        int i = 0;

        public MiniBatchIterator(List<NeuralSample> sampleList) {
            this.sampleList = sampleList;
        }

        @Override
        public boolean hasNext() {
            return this.i < this.sampleList.size();
        }

        /**
         * @return the next minibatch as a {@link List#subList} view of the underlying list
         * @throws NoSuchElementException if all samples were already returned
         * @throws IllegalStateException  if the minibatch size is not positive (would never advance)
         */
        @Override
        public List<NeuralSample> next() {
            if (!this.hasNext()) {
                throw new NoSuchElementException("No more minibatches");
            }
            int batchSize = MiniBatchTrainer.this.minibatchSize;
            if (batchSize <= 0) {
                throw new IllegalStateException("Minibatch size must be positive, got " + batchSize);
            }
            int end = Math.min(this.i + batchSize, this.sampleList.size());
            List<NeuralSample> neuralSamples = this.sampleList.subList(this.i, end);
            this.i = end;
            return neuralSamples;
        }
    }

    /** Stream-based driver: splits the (sequential) sample stream into minibatches and learns them lazily. */
    public class MinibatchStreamTrainer
    implements StreamTrainer {
        @Override
        public Stream<Result> learnEpoch(NeuralModel neuralModel, Stream<NeuralSample> sampleStream) {
            ++MiniBatchTrainer.this.iterationNumber;
            if (sampleStream.isParallel()) {
                LOG.severe("The input sampleStream is parallel, but the training must perform sequential gradient steps!");
            }
            // Batches are produced sequentially (parallel = false) so gradient steps happen one after another;
            // note the returned stream is lazy - learning happens only as it is consumed.
            Stream<List> minibatchStream = StreamSupport.stream(new Utilities.BatchSpliterator(sampleStream.spliterator(), MiniBatchTrainer.this.minibatchSize), false);
            return minibatchStream
                    .map(batch -> MiniBatchTrainer.this.minibatchParallelLearn(neuralModel, (List<NeuralSample>) batch))
                    .flatMap(Collection::stream);
        }

        @Override
        public void setupDebugger(NeuralDebugging trainingDebugger) {
            MiniBatchTrainer.this.neuralDebugger = trainingDebugger;
        }
    }

    /** List-based driver: learns or evaluates a list of samples minibatch by minibatch. */
    public class MinibatchListTrainer
    implements ListTrainer {
        @Override
        public List<Result> learnEpoch(NeuralModel neuralModel, List<NeuralSample> sampleList) {
            ++MiniBatchTrainer.this.iterationNumber;
            ArrayList<Result> resultList = new ArrayList<Result>(sampleList.size());
            MiniBatchIterator miniBatchIterator = new MiniBatchIterator(sampleList);
            while (miniBatchIterator.hasNext()) {
                List<NeuralSample> minibatch = miniBatchIterator.next();
                resultList.addAll(MiniBatchTrainer.this.minibatchParallelLearn(neuralModel, minibatch));
            }
            return resultList;
        }

        @Override
        public List<Result> evaluate(List<NeuralSample> trainingSet) {
            ArrayList<Result> resultList = new ArrayList<Result>(trainingSet.size());
            MiniBatchIterator miniBatchIterator = new MiniBatchIterator(trainingSet);
            while (miniBatchIterator.hasNext()) {
                List<NeuralSample> minibatch = miniBatchIterator.next();
                resultList.addAll(MiniBatchTrainer.this.minibatchParallelEvaluate(minibatch));
            }
            return resultList;
        }

        @Override
        public void restart(Settings settings) {
            MiniBatchTrainer.this.optimizer.restart(settings);
        }

        @Override
        public void setupDebugger(NeuralDebugging trainingDebugger) {
            MiniBatchTrainer.this.neuralDebugger = trainingDebugger;
        }
    }
}

