/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions.combination;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Aggregation;
import cz.cvut.fel.ida.algebra.functions.Combination;
import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.functions.transformation.joint.XMax;
import cz.cvut.fel.ida.algebra.values.MatrixValue;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.VectorValue;
import cz.cvut.fel.ida.utils.math.VectorUtils;
import java.util.List;
import java.util.logging.Logger;

public class Softmax
implements Transformation,
Combination,
XMax,
Aggregation {
    private static final Logger LOG = Logger.getLogger(Softmax.class.getName());
    private final int[] aggregableTerms;

    public Softmax() {
        this.aggregableTerms = null;
    }

    public Softmax(int[] aggregableTerms) {
        this.aggregableTerms = aggregableTerms;
    }

    @Override
    public Transformation replaceWithSingleton() {
        return Transformation.Singletons.softmax;
    }

    @Override
    public Value evaluate(Value combinedInputs) {
        if (combinedInputs instanceof VectorValue) {
            VectorValue inputVector = (VectorValue)combinedInputs;
            double[] probabilities = this.getProbabilities(inputVector.values);
            return new VectorValue(probabilities);
        }
        throw new ClassCastException("Trying to apply softmax on something else than a Vector...");
    }

    @Override
    public Value differentiate(Value summedInputs) {
        if (summedInputs instanceof VectorValue) {
            VectorValue inputVector = (VectorValue)summedInputs;
            double[] exps = this.getProbabilities(inputVector.values);
            double[] diffs = this.getGradient(exps);
            return new MatrixValue(diffs, exps.length, exps.length);
        }
        throw new ClassCastException("Trying to differentiate softmax on something else than a Vector...");
    }

    @Override
    public Value evaluate(List<Value> inputs) {
        double[] exps = this.getProbabilities(inputs);
        VectorValue output = new VectorValue(exps);
        return output;
    }

    @Override
    public double[] getGradient(double[] exps) {
        double[] diffs = new double[exps.length * exps.length];
        for (int i = 0; i < exps.length; ++i) {
            int tmpIndex = i * exps.length;
            for (int j = 0; j < exps.length; ++j) {
                diffs[tmpIndex + j] = i == j ? exps[i] * (1.0 - exps[j]) : -exps[i] * exps[j];
            }
        }
        return diffs;
    }

    @Override
    public double[] getProbabilities(double[] input) {
        int i;
        double max = VectorUtils.max(input);
        double expsum = 0.0;
        double[] exps = new double[input.length];
        for (i = 0; i < input.length; ++i) {
            double exp;
            exps[i] = exp = Math.exp(input[i] - max);
            expsum += exp;
        }
        i = 0;
        while (i < exps.length) {
            int n = i++;
            exps[n] = exps[n] / expsum;
        }
        return exps;
    }

    @Override
    public double[] getProbabilities(List<Value> inputs) {
        int i;
        if (inputs.size() == 1 && inputs.get(0) instanceof VectorValue) {
            return this.getProbabilities(((VectorValue)inputs.get((int)0)).values);
        }
        double max = this.getMax(inputs);
        double expsum = 0.0;
        double[] exps = new double[inputs.size()];
        for (i = 0; i < inputs.size(); ++i) {
            double exp;
            exps[i] = exp = Math.exp(((ScalarValue)inputs.get((int)i)).value - max);
            expsum += exp;
        }
        i = 0;
        while (i < exps.length) {
            int n = i++;
            exps[n] = exps[n] / expsum;
        }
        return exps;
    }

    private double getMax(List<Value> inputs) {
        double max = Double.NEGATIVE_INFINITY;
        for (Value value : inputs) {
            if (!(((ScalarValue)value).value > max)) continue;
            max = ((ScalarValue)value).value;
        }
        return max;
    }

    @Override
    public boolean isPermutationInvariant() {
        return false;
    }

    @Override
    public boolean isSplittable() {
        return this.aggregableTerms != null && this.aggregableTerms.length != 0;
    }

    @Override
    public int[] aggregableTerms() {
        return this.aggregableTerms;
    }

    @Override
    public Value differentiate(List<Value> inputs) {
        LOG.warning("Directly calculating derivative of SOFTMAX fcn");
        return Value.ONE;
    }

    @Override
    public boolean isComplex() {
        return true;
    }

    @Override
    public Transformation singleInputVersion() {
        return Transformation.Singletons.constantOne;
    }

    @Override
    public ActivationFcn.State getState(boolean singleInput) {
        if (singleInput) {
            return new TransformationState(Transformation.Singletons.softmax);
        }
        return new CombinationState(Transformation.Singletons.softmax);
    }

    public static class CombinationState
    extends Combination.InputArrayState {
        double[] probabilities;

        public CombinationState(Combination combination) {
            super(combination);
        }

        @Override
        public void invalidate() {
            super.invalidate();
            this.probabilities = null;
        }

        @Override
        public Value evaluate() {
            XMax xMax = (XMax)((Object)this.combination);
            this.probabilities = xMax.getProbabilities(this.accumulatedInputs);
            return new VectorValue(this.probabilities);
        }

        @Override
        public void ingestTopGradient(Value topGradient) {
            Value inputFcnDerivative = this.gradient();
            this.processedGradient = inputFcnDerivative.times(topGradient);
        }

        @Override
        public Value nextInputGradient() {
            return new ScalarValue(this.processedGradient.get(this.i++));
        }

        public Value gradient() {
            XMax xMax = (XMax)((Object)this.combination);
            double[] gradient = xMax.getGradient(this.probabilities);
            return new MatrixValue(gradient, this.probabilities.length, this.probabilities.length);
        }
    }

    public static class TransformationState
    extends Transformation.State {
        double[] probabilities;

        public TransformationState(Transformation transformation) {
            super(transformation);
        }

        @Override
        public void invalidate() {
            super.invalidate();
            this.probabilities = null;
        }

        @Override
        public Value evaluate() {
            XMax xMax = (XMax)((Object)this.transformation);
            this.probabilities = xMax.getProbabilities(((VectorValue)this.input).values);
            return new VectorValue(this.probabilities);
        }

        @Override
        public Value gradient() {
            XMax xMax = (XMax)((Object)this.transformation);
            double[] gradient = xMax.getGradient(this.probabilities);
            return new MatrixValue(gradient, this.probabilities.length, this.probabilities.length);
        }
    }
}

