/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.computation.training.strategies;

import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.inits.ValueInitializer;
import cz.cvut.fel.ida.learning.results.Progress;
import cz.cvut.fel.ida.learning.results.Result;
import cz.cvut.fel.ida.neural.networks.computation.iteration.actions.Evaluation;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.weights.WeightUpdater;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralModel;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralSample;
import cz.cvut.fel.ida.neural.networks.computation.training.optimizers.Optimizer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.Hyperparameters.LearnRateDecayStrategy;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.TrainingStrategy;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.debugging.NeuralDebugging;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.ListTrainer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.MiniBatchTrainer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.SequentialTrainer;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.QueryNeuron;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.exporting.Exporter;
import cz.cvut.fel.ida.utils.generic.Pair;
import java.util.ArrayList;
import java.util.List;

public class PythonTrainingStrategy
extends TrainingStrategy {
    transient List<NeuralSample> samplesSet;
    transient SequentialTrainer trainer;
    transient ListTrainer listTrainer;
    MiniBatchTrainer miniBatchTrainer;
    ListTrainer minibatchListTrainer;
    ValueInitializer valueInitializer;
    Evaluation evaluation;
    LearnRateDecayStrategy learnRateDecay;
    int epochCount = 0;

    public PythonTrainingStrategy(Settings settings, NeuralModel model, Optimizer optimizer, LearnRateDecayStrategy learnRateDecay) {
        super(settings, model);
        this.trainer = new SequentialTrainer(settings, optimizer, this.currentModel);
        this.listTrainer = this.trainer.new SequentialTrainer.SequentialListTrainer();
        this.valueInitializer = ValueInitializer.getInitializer(settings);
        this.evaluation = this.trainer.getEvaluation();
        this.miniBatchTrainer = new MiniBatchTrainer(settings, optimizer, this.currentModel, 0);
        this.minibatchListTrainer = this.miniBatchTrainer.new MiniBatchTrainer.MinibatchListTrainer();
        this.learnRateDecay = learnRateDecay;
    }

    public SequentialTrainer getTrainer() {
        return this.trainer;
    }

    public NeuralModel getCurrentModel() {
        return this.currentModel;
    }

    public void setSamples(List<NeuralSample> samples) {
        this.samplesSet = samples;
    }

    public void resetParameters() {
        if (this.learnRateDecay != null) {
            this.learnRateDecay.restart();
        }
        this.epochCount = 0;
        this.listTrainer.restart(this.settings);
        this.currentModel.resetWeights(this.valueInitializer);
    }

    @Override
    public Pair<NeuralModel, Progress> train() {
        return null;
    }

    @Override
    public void setupDebugger(NeuralDebugging neuralDebugger) {
    }

    public List<Result> learnSamples(int epochs, int minibatchSize) {
        return this.learnSamples(this.samplesSet, epochs, minibatchSize);
    }

    public List<Result> learnSamples(List<NeuralSample> samples, int epochs, int minibatchSize) {
        List<Result> results = null;
        if (epochs <= 0) {
            return new ArrayList<Result>();
        }
        ListTrainer trainer = this.listTrainer;
        if (minibatchSize > 1) {
            this.miniBatchTrainer.setMinibatchSize(minibatchSize);
            trainer = this.minibatchListTrainer;
        }
        for (int i = 0; i < epochs; ++i) {
            ++this.epochCount;
            if (this.learnRateDecay != null) {
                this.learnRateDecay.decay(this.epochCount);
            }
            results = trainer.learnEpoch(this.currentModel, samples);
        }
        return results;
    }

    public Result learnSample(NeuralSample sample) {
        this.trainer.invalidateSample(this.trainer.getInvalidation(), sample);
        Result result = this.trainer.evaluateSample(this.trainer.getEvaluation(), sample);
        WeightUpdater weightUpdater = this.trainer.backpropSample(this.trainer.getBackpropagation(), result, sample);
        this.trainer.updateWeights(this.currentModel, weightUpdater);
        return result;
    }

    public Value evaluateSample(NeuralSample sample) {
        this.trainer.invalidateSample(this.trainer.getInvalidation(), sample);
        return this.evaluation.evaluate((QueryNeuron)sample.query);
    }

    public List<Value> evaluateSamples(List<NeuralSample> samples, int minibatchSize) {
        ArrayList<Value> output = new ArrayList<Value>(samples.size());
        if (minibatchSize > 1) {
            this.miniBatchTrainer.setMinibatchSize(minibatchSize);
            for (Result result : this.minibatchListTrainer.evaluate(samples)) {
                output.add(result.getOutput());
            }
            return output;
        }
        for (NeuralSample sample : samples) {
            this.trainer.invalidateSample(this.trainer.getInvalidation(), sample);
            output.add(this.evaluation.evaluate((QueryNeuron)sample.query));
        }
        return output;
    }

    @Override
    public void export(Exporter exporter) {
    }

    @Override
    public String exportToJson() {
        return null;
    }
}

