/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions.combination;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.functions.combination.Softmax;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.VectorValue;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;

/**
 * Sparsemax transformation, implemented on top of {@link Softmax}.
 *
 * <p>For an input vector {@code z} the output is {@code max(z - tau, 0)} element-wise, where the
 * threshold {@code tau} is derived from the largest entries of {@code z} so that the output sums
 * to one. Unlike softmax, entries at or below the threshold come out as exact zeros.
 */
public class Sparsemax extends Softmax {
    private static final Logger LOG = Logger.getLogger(Sparsemax.class.getName());

    /**
     * @return the shared {@code Transformation.Singletons.sparsemax} instance
     */
    @Override
    public Transformation replaceWithSingleton() {
        return Transformation.Singletons.sparsemax;
    }

    /**
     * Evaluates sparsemax over the given inputs.
     *
     * @param inputs the values to transform; see {@link #getProbabilities(List)} for accepted shapes
     * @return a {@link VectorValue} wrapping the sparsemax output
     */
    @Override
    public Value evaluate(List<Value> inputs) {
        return new VectorValue(this.getProbabilities(inputs));
    }

    /**
     * Computes the sparsemax output for a raw score vector.
     *
     * <p>The scores are sorted, and the support size {@code k} is the largest rank for which
     * {@code 1 + k * z_(k) > z_(1) + ... + z_(k)} holds ({@code z_(j)} being the j-th largest
     * score). The threshold is then {@code tau = (z_(1) + ... + z_(k) - 1) / k}.
     *
     * <p>All arithmetic is done on scores shifted by the maximum. Sparsemax is unchanged by adding
     * a constant to every score, and the shift keeps the support test exact for the top entry, so
     * a non-empty support is always found for finite input regardless of magnitude.
     *
     * @param input the raw scores; not modified
     * @return a new array of the same length holding the sparsemax output; an empty array for empty
     *         input; an array of {@code NaN} if the largest score is NaN or infinite, because the
     *         threshold is undefined in that case
     */
    @Override
    public double[] getProbabilities(double[] input) {
        final int length = input.length;
        final double[] out = new double[length];
        if (length == 0) {
            return out;
        }

        // Arrays.sort orders ascending, so ranks are read from the end of the array.
        final double[] ascending = Arrays.copyOf(input, length);
        Arrays.sort(ascending);
        final double max = ascending[length - 1];

        int supportSize = 0;
        double supportSum = 0.0;
        double runningSum = 0.0;
        for (int rank = 1; rank <= length; ++rank) {
            final double shifted = ascending[length - rank] - max;
            runningSum += shifted;
            if (1.0 + rank * shifted > runningSum) {
                supportSize = rank;
                // Prefix sum of the 'rank' largest shifted scores.
                supportSum = runningSum;
            }
        }

        if (supportSize == 0) {
            // For a finite maximum, rank 1 always passes (shifted == 0, and 1 > 0). Reaching this
            // point therefore means the maximum is NaN or infinite; report that as NaN instead of
            // dividing by a sentinel support size.
            Arrays.fill(out, Double.NaN);
            return out;
        }

        final double threshold = (supportSum - 1.0) / supportSize;
        for (int i = 0; i < length; ++i) {
            final double val = (input[i] - max) - threshold;
            out[i] = val > 0.0 ? val : 0.0;
        }
        return out;
    }

    /**
     * Extracts the raw scores from {@code inputs} and delegates to
     * {@link #getProbabilities(double[])}.
     *
     * <p>If {@code inputs} holds exactly one element and it is a {@link VectorValue}, its
     * {@code values} array is used directly. Otherwise every element is cast to
     * {@link ScalarValue} and its {@code value} is read.
     *
     * @param inputs a single vector value, or a list of scalar values
     * @return the sparsemax output for the extracted scores
     * @throws ClassCastException if the scalar path is taken and an element is not a
     *         {@link ScalarValue}
     */
    @Override
    public double[] getProbabilities(List<Value> inputs) {
        if (inputs.size() == 1 && inputs.get(0) instanceof VectorValue) {
            return this.getProbabilities(((VectorValue) inputs.get(0)).values);
        }
        final double[] scores = new double[inputs.size()];
        for (int i = 0; i < scores.length; ++i) {
            scores[i] = ((ScalarValue) inputs.get(i)).value;
        }
        return this.getProbabilities(scores);
    }

    /**
     * Builds the flattened derivative matrix for a sparsemax output.
     *
     * <p>With {@code n = exps.length}, the result has {@code n * n} entries and entry
     * {@code (i, j)} is stored at index {@code i * n + j}. Let {@code support} be the number of
     * strictly positive entries of {@code exps}. For every pair of entries that are both non-zero,
     * the value is {@code 1 - 1/support} when {@code i == j} and {@code -1/support} otherwise; all
     * remaining entries stay {@code 0}.
     *
     * @param exps a sparsemax output, as returned by {@link #getProbabilities(double[])}
     * @return the flattened {@code n * n} derivative matrix
     */
    @Override
    public double[] getGradient(double[] exps) {
        final int n = exps.length;

        int support = 0;
        for (final double p : exps) {
            if (p > 0.0) {
                ++support;
            }
        }

        final double[] diffs = new double[n * n];
        for (int row = 0; row < n; ++row) {
            if (exps[row] != 0.0) {
                final int rowOffset = row * n;
                for (int col = 0; col < n; ++col) {
                    if (exps[col] != 0.0) {
                        diffs[rowOffset + col] = row == col
                                ? 1.0 - 1.0 / (double) support
                                : -1.0 / (double) support;
                    }
                }
            }
        }
        return diffs;
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean isPermutationInvariant() {
        return false;
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isComplex() {
        return true;
    }

    /**
     * Creates the evaluation state bound to the sparsemax singleton.
     *
     * @param singleInput selects which state class is created
     * @return a {@code Softmax.TransformationState} when {@code singleInput} is {@code true},
     *         otherwise a {@code Softmax.CombinationState}
     */
    @Override
    public ActivationFcn.State getState(boolean singleInput) {
        return singleInput
                ? new Softmax.TransformationState(Transformation.Singletons.sparsemax)
                : new Softmax.CombinationState(Transformation.Singletons.sparsemax);
    }
}

