/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions.transformation.joint;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.VectorValue;
import java.util.logging.Logger;

/**
 * Joint transformation that standardizes the elements of its input: each element has the
 * mean of all elements subtracted and is then divided by {@code sqrt(variance + eps)},
 * where the variance is the population variance (sum of squared deviations divided by the
 * element count).
 *
 * <p>The stateless {@link #evaluate(Value)} / {@link #differentiate(Value)} pair recompute
 * the statistics on every call; {@link State} caches them between the forward and backward
 * steps.
 */
public class Normalization
implements Transformation {
    private static final Logger LOG = Logger.getLogger(Normalization.class.getName());
    /** Constant added to the variance before taking the square root used as the divisor. */
    static double eps = 1.0E-10;

    /**
     * @return the shared {@code Transformation.Singletons.normalization} instance
     */
    @Override
    public ActivationFcn replaceWithSingleton() {
        return Transformation.Singletons.normalization;
    }

    /**
     * Standardizes the elements of the given value.
     *
     * @param combinedInputs the value whose elements are iterated to obtain mean and variance
     * @return the result of {@code combinedInputs.apply(x -> (x - mean) / sqrt(var + eps))}
     */
    @Override
    public Value evaluate(Value combinedInputs) {
        // First pass: element count and mean.
        double sum = 0.0;
        int count = 0;
        for (Double summedInput : combinedInputs) {
            sum += summedInput.doubleValue();
            ++count;
        }
        double mean = sum / (double)count;

        // Second pass: population variance around that mean.
        double var = 0.0;
        for (Double input : combinedInputs) {
            double diff = input - mean;
            var += diff * diff;
        }
        var /= (double)count;

        double stdEps = Math.sqrt(var + eps);
        return combinedInputs.apply(x -> (x - mean) / stdEps);
    }

    /**
     * Computes one gradient entry per input element, using the same three-term formula as
     * {@link State#ingestTopGradient(Value)} with every upstream gradient entry taken as 1:
     * a direct term {@code 1 / stdEps}, a term routed through the variance, and a term
     * routed through the mean.
     *
     * @param combinedInputs the value whose elements are differentiated against
     * @return a {@link VectorValue} holding one entry per iterated element
     */
    @Override
    public Value differentiate(Value combinedInputs) {
        // Recompute the forward statistics (count, mean, population variance).
        double sum = 0.0;
        int count = 0;
        for (Double summedInput : combinedInputs) {
            sum += summedInput.doubleValue();
            ++count;
        }
        double mean = sum / (double)count;
        double var = 0.0;
        for (Double input : combinedInputs) {
            double diff = input - mean;
            var += diff * diff;
        }
        var /= (double)count;
        double stdEps = Math.sqrt(var + eps);

        // Contribution flowing through the variance term.
        double gradSigma = 0.0;
        for (Double x : combinedInputs) {
            gradSigma += x - mean;
        }
        gradSigma *= -0.5 * Math.pow(var + eps, -1.5);

        // Contribution flowing through the mean term.
        double invCount = 1.0 / (double)count;
        double gradMean = 0.0;
        for (Double x : combinedInputs) {
            gradMean += -1.0 / stdEps + gradSigma * invCount * 2.0 * (x - mean) * -1.0;
        }

        // Assemble the per-element result. The write index must advance with every element;
        // previously it stayed at 0, so only grad[0] was ever written and the rest remained 0.
        int i = 0;
        double[] grad = new double[count];
        for (Double x : combinedInputs) {
            grad[i++] = 1.0 / stdEps + gradSigma * invCount * 2.0 * (x - mean) + gradMean * invCount;
        }
        return new VectorValue(grad);
    }

    /**
     * Creates a fresh computation state bound to the normalization singleton.
     *
     * @param singleInput not used by this implementation
     * @return a new {@link State}
     */
    @Override
    public ActivationFcn.State getState(boolean singleInput) {
        return new State(Transformation.Singletons.normalization);
    }

    /**
     * Stateful variant of the normalization: {@link #evaluate()} stores the element count,
     * mean, variance and divisor so that {@link #ingestTopGradient(Value)} can reuse them.
     */
    public static class State
    extends Transformation.State {
        /** Mean of the input elements, set by {@link #evaluate()}. */
        double mean;
        /** Population variance of the input elements, set by {@link #evaluate()}. */
        double var;
        /** Number of input elements seen by {@link #evaluate()}. */
        int count;
        /** {@code sqrt(var + eps)}, the divisor applied in {@link #evaluate()}. */
        double stdEps;

        public State(Transformation transformation) {
            super(transformation);
        }

        /** Delegates to the superclass and then zeroes all cached statistics. */
        @Override
        public void invalidate() {
            super.invalidate();
            this.mean = 0.0;
            this.var = 0.0;
            this.count = 0;
            this.stdEps = 0.0;
        }

        /**
         * Computes and caches count, mean, variance and divisor of {@code this.input}, then
         * returns the standardized input.
         *
         * @return {@code this.input.apply(x -> (x - mean) / stdEps)}
         */
        @Override
        public Value evaluate() {
            double sum = 0.0;
            this.count = 0;
            for (Double summedInput : this.input) {
                sum += summedInput.doubleValue();
                ++this.count;
            }
            this.mean = sum / (double)this.count;
            this.var = 0.0;
            for (Double input : this.input) {
                double diff = input - this.mean;
                this.var += diff * diff;
            }
            this.var /= (double)this.count;
            this.stdEps = Math.sqrt(this.var + eps);
            return this.input.apply(x -> (x - this.mean) / this.stdEps);
        }

        /**
         * Combines the upstream gradient with the statistics cached by {@link #evaluate()}
         * and stores the result in {@code processedGradient} as a {@link VectorValue}.
         * Relies on {@code count}, {@code mean}, {@code var} and {@code stdEps} having been
         * set by a preceding {@link #evaluate()} call.
         *
         * @param topGradient upstream gradient, read at indices {@code 0 .. count - 1}
         */
        @Override
        public void ingestTopGradient(Value topGradient) {
            // Contribution flowing through the variance term.
            double gradSigma = 0.0;
            for (int j = 0; j < this.count; ++j) {
                gradSigma += topGradient.get(j) * (this.input.get(j) - this.mean);
            }
            gradSigma *= -0.5 * Math.pow(this.var + eps, -1.5);

            // Contribution flowing through the mean term.
            double invCount = 1.0 / (double)this.count;
            double gradMean = 0.0;
            for (int j = 0; j < this.count; ++j) {
                gradMean += topGradient.get(j) * (-1.0 / this.stdEps) + gradSigma * invCount * -2.0 * (this.input.get(j) - this.mean);
            }

            // Per-element gradient: direct term + variance term + mean term.
            double[] grad = new double[this.count];
            for (int j = 0; j < this.count; ++j) {
                grad[j] = topGradient.get(j) * (1.0 / this.stdEps) + gradSigma * invCount * 2.0 * (this.input.get(j) - this.mean) + gradMean * invCount;
            }
            this.processedGradient = new VectorValue(grad);
        }
    }
}

