/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.computation.training.optimizers;

import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.neural.networks.computation.training.optimizers.Optimizer;
import cz.cvut.fel.ida.setup.Settings;
import java.util.Collection;
import java.util.logging.Logger;

/**
 * Adam optimizer (Kingma and Ba) with bias-corrected first and second moment estimates.
 *
 * <p>Per-weight optimizer state is not held by this class; it is read from and written back to
 * each {@link Weight}: {@code weight.momentum} holds the first-moment estimate and
 * {@code weight.velocity} the second (raw) moment estimate.
 *
 * <p>Sign convention: the momentum buffer accumulates the running mean of the <em>negated</em>
 * supplied gradient, and the update step negates it once more. The net effect is
 * {@code value += lr * mHat / (sqrt(vHat) + epsilon)}, where {@code mHat} is the bias-corrected
 * running mean of the supplied gradients. Weights therefore move along the supplied gradient.
 * NOTE(review): this presumes callers pass descent-direction (already negated) gradients;
 * confirm against the other optimizers in this package.
 */
public class Adam implements Optimizer {
    private static final Logger LOG = Logger.getLogger(Adam.class.getName());

    /** Step size. Read on every {@link #performGradientStep} call, so changes take effect on the next step. */
    public ScalarValue learningRate;
    /** Exponential decay rate of the first-moment estimate. */
    public final double beta1;
    /** Exponential decay rate of the second-moment estimate. */
    public final double beta2;
    /** Small constant added to the denominator to avoid division by zero. */
    public final double epsilon;

    /**
     * Creates an Adam optimizer with {@code beta1 = 0.9}, {@code beta2 = 0.999} and
     * {@code epsilon = 1e-8}.
     *
     * @param learningRate step size; must be a {@link ScalarValue}
     * @throws ClassCastException if {@code learningRate} is not a {@link ScalarValue}
     */
    public Adam(Value learningRate) {
        this(learningRate, 0.9, 0.999, 1.0E-8);
    }

    /**
     * Creates an Adam optimizer with explicit hyper-parameters.
     *
     * @param learningRate step size; must be a {@link ScalarValue}
     * @param i_beta1 decay rate of the first-moment estimate
     * @param i_beta2 decay rate of the second-moment estimate
     * @param i_epsilon denominator stabilizer
     * @throws ClassCastException if {@code learningRate} is not a {@link ScalarValue}
     */
    public Adam(Value learningRate, double i_beta1, double i_beta2, double i_epsilon) {
        this.learningRate = (ScalarValue) learningRate;
        this.beta1 = i_beta1;
        this.beta2 = i_beta2;
        this.epsilon = i_epsilon;
    }

    /**
     * Applies one Adam update to every weight in {@code updatedWeights}.
     *
     * @param updatedWeights weights to update; each one's {@code value}, {@code momentum} and
     *     {@code velocity} are rewritten in place
     * @param gradients gradients indexed by {@code weight.index}
     * @param iteration 1-based step counter used for bias correction; values below 1 are treated
     *     as 1, because the correction {@code 1 / (1 - beta^t)} is undefined at {@code t = 0}
     */
    @Override
    public void performGradientStep(Collection<Weight> updatedWeights, Value[] gradients, int iteration) {
        // At t = 0 the bias correction divides by zero and every updated weight becomes NaN;
        // negative t yields negative corrections. Clamp to the first valid step.
        int t = Math.max(iteration, 1);
        double fix1 = 1.0 / (1.0 - Math.pow(this.beta1, t));
        double fix2 = 1.0 / (1.0 - Math.pow(this.beta2, t));
        double lr = this.learningRate.value;
        double decay1 = 1.0 - this.beta1;
        double decay2 = 1.0 - this.beta2;
        for (Weight weight : updatedWeights) {
            double[] value = weight.value.getAsArray();
            double[] momentum = weight.momentum.getAsArray();
            double[] velocity = weight.velocity.getAsArray();
            double[] gradient = gradients[weight.index].getAsArray();
            for (int i = 0; i < value.length; ++i) {
                double grad = gradient[i];
                // Running mean of the negated gradient (see the class comment on the sign convention).
                momentum[i] = momentum[i] * this.beta1 - grad * decay1;
                // Running mean of the squared gradient.
                velocity[i] = velocity[i] * this.beta2 + grad * grad * decay2;
                // The leading -1.0 cancels the negation stored in the momentum buffer.
                value[i] = value[i] + momentum[i] * fix1 * (-1.0 / (Math.sqrt(velocity[i] * fix2) + this.epsilon)) * lr;
            }
            // NOTE(review): write-backs are kept because getAsArray() may return a copy; confirm.
            weight.value.setAsArray(value);
            weight.momentum.setAsArray(momentum);
            weight.velocity.setAsArray(velocity);
        }
    }

    /**
     * Does nothing: this optimizer keeps no per-step state of its own (the moment estimates are
     * stored on each {@link Weight}), so there is nothing to reset here.
     */
    @Override
    public void restart(Settings settings) {
    }
}

