/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers;

import cz.cvut.fel.ida.learning.results.Result;
import cz.cvut.fel.ida.neural.networks.computation.iteration.actions.Backpropagation;
import cz.cvut.fel.ida.neural.networks.computation.iteration.actions.Evaluation;
import cz.cvut.fel.ida.neural.networks.computation.iteration.actions.IndependentNeuronProcessing;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.weights.WeightUpdater;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralModel;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralSample;
import cz.cvut.fel.ida.neural.networks.computation.training.optimizers.Optimizer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.debugging.NeuralDebugging;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralNetwork;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.QueryNeuron;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.exporting.Exportable;
import java.util.logging.Logger;

/**
 * Performs a single training step for one neural sample.
 *
 * <p>The step runs in this order: optional dropout, invalidation of cached neuron
 * states, forward evaluation, backpropagation, and finally a gradient step through
 * the configured {@link Optimizer}.
 */
public class Trainer implements Exportable {
    private static final Logger LOG = Logger.getLogger(Trainer.class.getName());

    /** Training configuration; read for the dropout and per-sample debugging switches. */
    protected Settings settings;
    /** Value handed to {@code initializeStatesCache} whenever a sample's network is invalidated. */
    int index;
    /** Count of weight updates performed so far; incremented before each optimizer step. */
    int iterationNumber;
    /**
     * Optimizer that applies gradient steps. Declared {@code transient}, so serializers
     * honouring that modifier skip it.
     */
    transient Optimizer optimizer;
    /**
     * Debugger invoked after every sample when {@code settings.debugSampleTraining} is set.
     * Neither constructor assigns it, so it must be set externally before debugging is enabled.
     */
    public NeuralDebugging neuralDebugger;

    /**
     * Creates a trainer bound to the given settings and optimizer, starting at iteration 0.
     *
     * @param settings  training configuration
     * @param optimizer optimizer used by {@link #updateWeights} and {@link #restart}
     */
    public Trainer(Settings settings, Optimizer optimizer) {
        this.optimizer = optimizer;
        this.settings = settings;
        this.iterationNumber = 0;
    }

    /**
     * Creates a trainer with no settings and no optimizer.
     * NOTE(review): presumably used by subclasses or deserialization — confirm.
     */
    public Trainer() {
    }

    /**
     * Trains on one sample: dropout (if enabled), invalidation, evaluation,
     * backpropagation, weight update and, optionally, debugging.
     *
     * @param neuralModel     model whose weights are updated
     * @param neuralSample    sample to learn from
     * @param dropouter       processing applied when dropout is active
     * @param invalidation    processing that invalidates the sample's neuron states
     * @param evaluation      forward evaluation action
     * @param backpropagation backward pass producing the weight updater
     * @return the evaluation result of the sample
     */
    protected Result learnFromSample(NeuralModel neuralModel, NeuralSample neuralSample, IndependentNeuronProcessing dropouter, IndependentNeuronProcessing invalidation, Evaluation evaluation, Backpropagation backpropagation) {
        // Dropout only runs in DROPOUT mode, and only with a strictly positive rate.
        final boolean applyDropout =
                settings.dropoutMode == Settings.DropoutMode.DROPOUT && settings.dropoutRate > 0.0;
        if (applyDropout) {
            dropoutSample(dropouter, neuralSample);
        }

        invalidateSample(invalidation, neuralSample);
        final Result evaluated = evaluateSample(evaluation, neuralSample);
        final WeightUpdater gradients = backpropSample(backpropagation, evaluated, neuralSample);
        updateWeights(neuralModel, gradients);

        if (settings.debugSampleTraining) {
            neuralDebugger.debug(neuralSample);
        }
        return evaluated;
    }

    /**
     * Applies the dropout processing to the sample's evidence network, starting from its query neuron.
     */
    void dropoutSample(IndependentNeuronProcessing dropouter, NeuralSample neuralSample) {
        dropouter.process(evidenceOf(neuralSample), queryNeuronOf(neuralSample).neuron);
    }

    /**
     * Re-initialises the states cache of the sample's evidence network with {@link #index}.
     * It then runs the invalidation processing from the query neuron.
     */
    public void invalidateSample(IndependentNeuronProcessing invalidation, NeuralSample neuralSample) {
        evidenceOf(neuralSample).initializeStatesCache(index);
        invalidation.process(evidenceOf(neuralSample), queryNeuronOf(neuralSample).neuron);
    }

    /**
     * Delegates to {@code evaluation.evaluate} for the given sample.
     *
     * @return the evaluation result
     */
    public Result evaluateSample(Evaluation evaluation, NeuralSample neuralSample) {
        final Result outcome = evaluation.evaluate(neuralSample);
        return outcome;
    }

    /**
     * Delegates to {@code backpropagation.backpropagate} with the sample and its evaluated result.
     *
     * @return the weight updater produced by the backward pass
     */
    public WeightUpdater backpropSample(Backpropagation backpropagation, Result evaluatedResult, NeuralSample neuralSample) {
        final WeightUpdater updater = backpropagation.backpropagate(neuralSample, evaluatedResult);
        return updater;
    }

    /**
     * Increments the iteration counter and asks the optimizer to perform a gradient step.
     * The new counter value is passed to the optimizer. Both actions happen while holding
     * this trainer's monitor.
     */
    public synchronized void updateWeights(NeuralModel model, WeightUpdater weightUpdater) {
        iterationNumber += 1;
        optimizer.performGradientStep(model, weightUpdater, iterationNumber);
    }

    /**
     * Restarts the optimizer with the current settings.
     * NOTE(review): {@link #iterationNumber} is not reset here — confirm this is intended.
     */
    public void restart() {
        optimizer.restart(settings);
    }

    /** Returns the sample's query, cast to {@link QueryNeuron}. */
    private static QueryNeuron queryNeuronOf(NeuralSample neuralSample) {
        return (QueryNeuron) neuralSample.query;
    }

    /** Returns the evidence of the sample's query neuron, cast to {@link NeuralNetwork}. */
    private static NeuralNetwork evidenceOf(NeuralSample neuralSample) {
        return (NeuralNetwork) queryNeuronOf(neuralSample).evidence;
    }
}

