/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers;

import cz.cvut.fel.ida.learning.results.Result;
import cz.cvut.fel.ida.neural.networks.computation.iteration.actions.Backpropagation;
import cz.cvut.fel.ida.neural.networks.computation.iteration.actions.Evaluation;
import cz.cvut.fel.ida.neural.networks.computation.iteration.actions.IndependentNeuronProcessing;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Dropouter;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Invalidator;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralModel;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralSample;
import cz.cvut.fel.ida.neural.networks.computation.training.optimizers.Optimizer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.debugging.NeuralDebugging;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.ListTrainer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.StreamTrainer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.Trainer;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.QueryNeuron;
import cz.cvut.fel.ida.setup.Settings;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
import java.util.stream.Stream;

/**
 * Trainer that learns from samples one at a time, reusing a single set of per-sample actions
 * (dropout, invalidation, evaluation and backpropagation).
 *
 * <p>Two front-ends are provided as inner classes: {@link SequentialListTrainer}, which iterates
 * a list in order, and {@link SequentialStreamTrainer}, which lazily maps over a supplied stream
 * (ordering and parallelism therefore follow that stream). Both front-ends skip samples that
 * have no query neuron.
 */
public class SequentialTrainer
extends Trainer {
    private static final Logger LOG = Logger.getLogger(SequentialTrainer.class.getName());
    transient IndependentNeuronProcessing dropout;
    transient IndependentNeuronProcessing invalidation;
    transient Evaluation evaluation;
    transient Backpropagation backpropagation;

    /**
     * Creates a trainer with index {@code -1}.
     *
     * @param settings    global settings passed to every action
     * @param optimizer   optimizer handed to the {@link Trainer} superclass
     * @param neuralModel model used to construct the backpropagation action
     */
    public SequentialTrainer(Settings settings, Optimizer optimizer, NeuralModel neuralModel) {
        this(settings, optimizer, neuralModel, -1);
    }

    /**
     * Creates a trainer whose actions are all constructed with the given {@code index}.
     *
     * @param settings    global settings passed to every action
     * @param optimizer   optimizer handed to the {@link Trainer} superclass
     * @param neuralModel model used to construct the backpropagation action
     * @param index       index stored on this trainer and passed to every action
     *                    (presumably identifies this trainer's slot in neuron state — TODO confirm)
     */
    public SequentialTrainer(Settings settings, Optimizer optimizer, NeuralModel neuralModel, int index) {
        super(settings, optimizer);
        this.index = index;
        this.evaluation = new Evaluation(settings, index);
        this.backpropagation = new Backpropagation(settings, neuralModel, index);
        this.invalidation = new IndependentNeuronProcessing(settings, new Invalidator(index));
        this.dropout = new IndependentNeuronProcessing(settings, new Dropouter(settings, index));
    }

    /**
     * No-arg constructor that leaves all action fields {@code null}.
     * NOTE(review): presumably for subclasses/deserialization — confirm against callers.
     */
    protected SequentialTrainer() {
    }

    /** @return the dropout action, or {@code null} if created via the no-arg constructor */
    public IndependentNeuronProcessing getDropout() {
        return this.dropout;
    }

    /** @return the invalidation action, or {@code null} if created via the no-arg constructor */
    public IndependentNeuronProcessing getInvalidation() {
        return this.invalidation;
    }

    /** @return the evaluation action, or {@code null} if created via the no-arg constructor */
    public Evaluation getEvaluation() {
        return this.evaluation;
    }

    /**
     * Replaces the evaluation action used by both front-ends.
     *
     * @param evaluation the new evaluation action
     */
    public void setEvaluation(Evaluation evaluation) {
        this.evaluation = evaluation;
    }

    /** @return the backpropagation action, or {@code null} if created via the no-arg constructor */
    public Backpropagation getBackpropagation() {
        return this.backpropagation;
    }

    /**
     * Checks whether a sample has a query neuron to train on, logging a warning if it does not.
     * Shared by both front-ends so they skip the same samples.
     *
     * @param neuralSample the sample to check
     * @return {@code true} if the sample can be trained on, {@code false} if it must be skipped
     */
    private boolean hasQueryNeuron(NeuralSample neuralSample) {
        if (((QueryNeuron) neuralSample.query).neuron == null) {
            LOG.warning("No query neuron - skipping backprop for this sample:" + neuralSample.toString());
            return false;
        }
        return true;
    }

    /** Stream-based front-end that trains lazily as the returned stream is consumed. */
    public class SequentialStreamTrainer
    implements StreamTrainer {
        /**
         * Returns a stream that trains on each sample when consumed. Samples without a query
         * neuron are filtered out, the same way {@link SequentialListTrainer#learnEpoch} skips them.
         *
         * @param neuralModel  model to train
         * @param sampleStream samples for this epoch
         * @return lazily-computed results, one per sample that has a query neuron
         */
        @Override
        public Stream<Result> learnEpoch(NeuralModel neuralModel, Stream<NeuralSample> sampleStream) {
            return sampleStream
                    .filter(SequentialTrainer.this::hasQueryNeuron)
                    .map(sample -> SequentialTrainer.this.learnFromSample(neuralModel, sample,
                            SequentialTrainer.this.dropout, SequentialTrainer.this.invalidation,
                            SequentialTrainer.this.evaluation, SequentialTrainer.this.backpropagation));
        }

        /**
         * Installs the debugger on the enclosing trainer.
         *
         * @param trainingDebugger debugger to use
         */
        @Override
        public void setupDebugger(NeuralDebugging trainingDebugger) {
            SequentialTrainer.this.neuralDebugger = trainingDebugger;
        }
    }

    /** List-based front-end that trains eagerly, in list order. */
    public class SequentialListTrainer
    implements ListTrainer {
        /**
         * Trains on each sample in order. Samples without a query neuron are skipped and logged.
         *
         * @param neuralModel model to train
         * @param sampleList  samples for this epoch
         * @return results in input order, excluding skipped samples
         */
        @Override
        public List<Result> learnEpoch(NeuralModel neuralModel, List<NeuralSample> sampleList) {
            List<Result> resultList = new ArrayList<>(sampleList.size());
            for (NeuralSample neuralSample : sampleList) {
                if (!SequentialTrainer.this.hasQueryNeuron(neuralSample)) {
                    continue;
                }
                resultList.add(SequentialTrainer.this.learnFromSample(neuralModel, neuralSample,
                        SequentialTrainer.this.dropout, SequentialTrainer.this.invalidation,
                        SequentialTrainer.this.evaluation, SequentialTrainer.this.backpropagation));
            }
            return resultList;
        }

        /**
         * Invalidates and then evaluates every sample without training.
         *
         * @param trainingSet samples to evaluate
         * @return one result per sample, in input order
         */
        @Override
        public List<Result> evaluate(List<NeuralSample> trainingSet) {
            List<Result> resultList = new ArrayList<>(trainingSet.size());
            for (NeuralSample neuralSample : trainingSet) {
                // Invalidate first so evaluation does not read state left over from a previous pass.
                SequentialTrainer.this.invalidateSample(SequentialTrainer.this.invalidation, neuralSample);
                Result result = SequentialTrainer.this.evaluateSample(SequentialTrainer.this.evaluation, neuralSample);
                LOG.finest(() -> String.valueOf(neuralSample) + " : " + String.valueOf(result));
                resultList.add(result);
            }
            return resultList;
        }

        /**
         * Restarts the enclosing trainer's optimizer.
         *
         * @param settings settings passed to the optimizer
         */
        @Override
        public void restart(Settings settings) {
            SequentialTrainer.this.optimizer.restart(settings);
        }

        /**
         * Installs the debugger on the enclosing trainer.
         *
         * @param trainingDebugger debugger to use
         */
        @Override
        public void setupDebugger(NeuralDebugging trainingDebugger) {
            SequentialTrainer.this.neuralDebugger = trainingDebugger;
        }
    }
}

