/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.computation.training.strategies.Hyperparameters;

import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.learning.results.Progress;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.Hyperparameters.RestartingStrategy;
import cz.cvut.fel.ida.setup.Settings;
import java.util.logging.Logger;

/**
 * Restarting strategy that ends the current restart once the loss stops improving.
 *
 * <p>Two running sums of the selected loss series are kept: {@link #avgLossPresent} for the most
 * recent epochs and {@link #avgLossPast} for the epochs preceding them (window lengths are derived
 * from {@link #timeSpan}). Despite the field names, these are sums built with
 * {@code incrementBy}; no division happens in this class. Once {@code 2 * timeSpan} epochs have
 * elapsed, the restart continues only while the past sum is greater than the present sum plus
 * {@link #minDelta}.
 */
public class DynamicRestartingStrategy
extends RestartingStrategy {
    private static final Logger LOG = Logger.getLogger(DynamicRestartingStrategy.class.getName());

    /** Result series the losses are read from; forced to ONLINETRAIN when validation is not possible. */
    Settings.DataSelection dataSelection;
    /** Window length in epochs; overwritten from {@code settings.earlyStoppingPatience} by the constructor. */
    int timeSpan = 100;
    /**
     * Divisor turning an epoch offset into an index offset for the true-train/validation series;
     * overwritten from {@code settings.resultsRecalculationEpochae}. Presumably those series are only
     * recomputed every {@code recalculation} epochs — confirm.
     */
    int recalculation = 10;
    /** Minimal margin by which the past sum must exceed the present sum to keep training. */
    Value minDelta = new ScalarValue(1.0E-4);
    /** Running loss sum over the older window (a sum, not an average). */
    Value avgLossPast;
    /** Running loss sum over the recent window (a sum, not an average). */
    Value avgLossPresent;
    /** Loss that most recently left the past window. */
    Value pastLoss0 = new ScalarValue(0.0);
    /** Loss that most recently moved from the present window into the past window. */
    Value pastLoss1 = new ScalarValue(0.0);
    /** Constant used to negate values via {@code times}. */
    Value minusOne = new ScalarValue(-1.0);

    /**
     * Creates the strategy from global settings.
     *
     * @param settings source of the window length, recalculation period and data selection
     * @param validationPossible when {@code false}, the online training loss is used regardless of
     *     {@code settings.dataSelection}
     */
    public DynamicRestartingStrategy(Settings settings, boolean validationPossible) {
        this.timeSpan = settings.earlyStoppingPatience;
        this.recalculation = settings.resultsRecalculationEpochae;
        this.dataSelection = !validationPossible ? Settings.DataSelection.ONLINETRAIN : settings.dataSelection;
    }

    /**
     * Decides whether the current restart should keep training.
     *
     * @param progress training progress; its epoch count and the result history of
     *     {@code currentRestart} are read
     * @return {@code true} to continue, {@code false} once the loss has plateaued
     */
    @Override
    public boolean continueRestart(Progress progress) {
        int epoch = progress.getEpochCount();
        if (epoch == 0) {
            return true;
        }
        if (epoch == 1) {
            // NOTE(review): both sums are seeded from the online-training error even when
            // dataSelection is TRUETRAIN/VALIDATION, and the present sum additionally receives this
            // epoch's loss below (likely counting the first loss twice for ONLINETRAIN) — confirm.
            this.avgLossPresent = progress.getCurrentOnlineTrainingResults().error.clone();
            this.avgLossPast = progress.getCurrentOnlineTrainingResults().error.clone();
        }
        Value presentLoss = this.getLoss(progress.currentRestart, 0);
        if (epoch < this.timeSpan) {
            // Phase 1: only the present window is being filled.
            this.avgLossPresent.incrementBy(presentLoss);
            return true;
        }
        if (epoch < 2 * this.timeSpan) {
            // Phase 2: the present sum keeps growing and the past sum starts collecting the loss
            // from timeSpan-1 steps back.
            // NOTE(review): unlike phase 3, pastLoss1 is not subtracted from the present sum here,
            // so the present sum keeps every loss since epoch 1 — confirm this is intended.
            this.avgLossPresent.incrementBy(presentLoss);
            this.pastLoss1 = this.getLoss(progress.currentRestart, this.timeSpan - 1);
            this.avgLossPast.incrementBy(this.pastLoss1);
            return true;
        }
        // Phase 3: slide both windows by one step. The loss 2*timeSpan-1 steps back leaves the past
        // window, the loss timeSpan-1 steps back moves from present to past, and the newest loss
        // enters the present window.
        this.pastLoss0 = this.getLoss(progress.currentRestart, 2 * this.timeSpan - 1);
        this.pastLoss1 = this.getLoss(progress.currentRestart, this.timeSpan - 1);
        this.avgLossPast.incrementBy(this.pastLoss0.times(this.minusOne));
        this.avgLossPast.incrementBy(this.pastLoss1);
        this.avgLossPresent.incrementBy(this.pastLoss1.times(this.minusOne));
        this.avgLossPresent.incrementBy(presentLoss);
        if (this.avgLossPast.greaterThan(this.avgLossPresent.plus(this.minDelta))) {
            return true;
        }
        // Report the two quantities that were actually compared; the message is built lazily so
        // nothing is concatenated unless FINE logging is enabled.
        final Value pastSum = this.avgLossPast;
        final Value presentSum = this.avgLossPresent;
        LOG.fine(() -> "Stopping this restart due to loss plateau: past loss " + String.valueOf(pastSum) + " vs. present: " + String.valueOf(presentSum));
        return false;
    }

    /**
     * Returns the loss recorded {@code stepsBack} entries before the latest entry of the series
     * selected by {@link #dataSelection}. For TRUETRAIN and VALIDATION, the offset is first divided
     * (integer division) by {@link #recalculation}. Any other selection falls back to the online
     * training series.
     *
     * @param restart the restart whose result history is read
     * @param stepsBack epoch offset from the newest entry (0 = newest)
     * @return the {@code error} value of the selected result entry
     * @throws ArithmeticException if {@code recalculation} is 0 and TRUETRAIN/VALIDATION is selected
     */
    Value getLoss(Progress.Restart restart, int stepsBack) {
        switch (this.dataSelection) {
            case TRUETRAIN: {
                int indexBack = stepsBack / this.recalculation;
                return restart.trueTrainingResults.get(restart.trueTrainingResults.size() - 1 - indexBack).error;
            }
            case VALIDATION: {
                int indexBack = stepsBack / this.recalculation;
                return restart.validationResults.get(restart.validationResults.size() - 1 - indexBack).error;
            }
            case ONLINETRAIN:
            default:
                return restart.onlineTrainingResults.get(restart.onlineTrainingResults.size() - 1 - stepsBack).error;
        }
    }

    /**
     * Hook invoked between restarts; this strategy performs no reset here.
     * NOTE(review): the window sums are only re-seeded in {@code continueRestart} at epoch 1 —
     * confirm that the epoch count restarts for each restart.
     */
    @Override
    public void nextRestart() {
    }
}

