/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions.aggregation;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Aggregation;
import cz.cvut.fel.ida.algebra.functions.Combination;
import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.VectorValue;
import cz.cvut.fel.ida.utils.generic.Pair;
import java.util.List;
import java.util.logging.Logger;

/**
 * Minimum function, usable either as an aggregation over several inputs or as a transformation of
 * a single {@link VectorValue}.
 *
 * <ul>
 *   <li>Aggregation mode ({@link AggregationState}): the smallest input, according to
 *       {@link Value#greaterThan}, is passed through unchanged, and only that input receives the
 *       top gradient.</li>
 *   <li>Transformation mode ({@link TransformationState}): the output is a vector of the input's
 *       length that is zero everywhere except at the position of the smallest element, which
 *       keeps its value. Only that position receives gradient.</li>
 * </ul>
 */
public class Minimum
implements Aggregation {
    private static final Logger LOG = Logger.getLogger(Minimum.class.getName());

    /**
     * @return the shared {@code Aggregation.Singletons.minimum} instance
     */
    @Override
    public Minimum replaceWithSingleton() {
        return Aggregation.Singletons.minimum;
    }

    /**
     * Returns the smallest of the given inputs.
     *
     * <p>A later input replaces the current minimum only when {@code min.greaterThan(value)}
     * holds. Assuming {@code greaterThan} is a strict comparison, ties therefore keep the
     * earliest input.
     *
     * @param inputs values to aggregate; must be non-empty
     * @return the minimal input instance itself (not a copy)
     * @throws IndexOutOfBoundsException if {@code inputs} is empty
     */
    @Override
    public Value evaluate(List<Value> inputs) {
        Value min = inputs.get(0);
        for (int i = 1; i < inputs.size(); ++i) {
            Value value = inputs.get(i);
            if (min.greaterThan(value)) {
                min = value;
            }
        }
        return min;
    }

    /**
     * Returns {@link Value#ONE} regardless of the inputs.
     *
     * <p>NOTE(review): routing the gradient to the minimal input only appears to be handled by
     * {@link AggregationState#nextInputGradient()}. Confirm how callers combine this value.
     */
    @Override
    public Value differentiate(List<Value> inputs) {
        return Value.ONE;
    }

    /**
     * Keeps only the smallest element of a vector.
     *
     * @param combinedInputs the input; only {@link VectorValue} is supported
     * @return a vector of the same length and orientation, zero everywhere except at the position
     *     of the smallest element, which keeps its value; {@code null} (with a logged error) for
     *     any other input type
     */
    @Override
    public Value evaluate(Value combinedInputs) {
        if (combinedInputs instanceof VectorValue) {
            VectorValue vector = (VectorValue) combinedInputs;
            double[] minValue = (double[]) this.getMinValue(vector.values).s;
            // Keep the input's orientation, as TransformationState.evaluate does.
            return new VectorValue(minValue, vector.rowOrientation);
        }
        LOG.severe("Cannot calculate Min from other than VectorValue");
        return null;
    }

    /**
     * Derivative of {@link #evaluate(Value)}: a one-hot vector with 1 at the position of the
     * smallest element and 0 elsewhere.
     *
     * <p>The one-hot is built from the arg-min index rather than from the sign of the minimal
     * value, so negative and zero minima still get a non-zero derivative.
     *
     * @param combinedInputs the input; only {@link VectorValue} is supported
     * @return the one-hot vector with the input's orientation (all zeros for an empty vector);
     *     {@code null} (with a logged error) for any other input type
     */
    @Override
    public Value differentiate(Value combinedInputs) {
        if (!(combinedInputs instanceof VectorValue)) {
            LOG.severe("Cannot calculate Min derivative from other than VectorValue");
            return null;
        }
        VectorValue vector = (VectorValue) combinedInputs;
        double[] oneHot = new double[vector.values.length];
        int minIndex = (Integer) this.getMinValue(vector.values).r;
        if (minIndex >= 0) {
            oneHot[minIndex] = 1.0;
        }
        return new VectorValue(oneHot, vector.rowOrientation);
    }

    /**
     * Finds the smallest element of {@code input}.
     *
     * <p>NaN entries are skipped, because {@code <} is false for NaN. The first occurrence wins
     * on ties. If no element compares below {@code +Infinity} (all entries are NaN or
     * {@code +Infinity}), index 0 is used so that a non-empty input always yields a valid
     * position.
     *
     * @param input the values to scan
     * @return a pair of (index of the minimum, array of {@code input.length} that is zero except
     *     at that index, where it holds the minimum); for an empty input, the index is -1 and the
     *     array is empty
     */
    public Pair<Integer, double[]> getMinValue(double[] input) {
        double min = Double.POSITIVE_INFINITY;
        int minIndex = -1;
        for (int i = 0; i < input.length; ++i) {
            if (input[i] < min) {
                min = input[i];
                minIndex = i;
            }
        }
        if (minIndex < 0 && input.length > 0) {
            // Nothing beat the sentinel (all NaN / +Infinity): fall back to the first position
            // instead of indexing with -1.
            minIndex = 0;
            min = input[0];
        }
        double[] result = new double[input.length];
        if (minIndex >= 0) {
            result[minIndex] = min;
        }
        return new Pair<Integer, double[]>(minIndex, result);
    }

    /**
     * @return {@code true}; the minimal value does not depend on the order of the inputs
     */
    @Override
    public boolean isPermutationInvariant() {
        return true;
    }

    /**
     * Creates a fresh evaluation state bound to the shared minimum singleton.
     *
     * @param singleInput {@code true} to operate on the elements of a single {@link VectorValue}
     *     input, {@code false} to aggregate several separate inputs
     * @return a {@link TransformationState} or an {@link AggregationState}, respectively
     */
    @Override
    public ActivationFcn.State getState(boolean singleInput) {
        if (singleInput) {
            return new TransformationState(Aggregation.Singletons.minimum);
        }
        return new AggregationState(Aggregation.Singletons.minimum);
    }

    /**
     * State for the single-input mode: keeps only the smallest element of the input vector and
     * routes the gradient to that element.
     */
    public static class TransformationState
    extends Transformation.State {
        /** Position of the minimal element from the last {@link #evaluate()}; -1 if none. */
        int minIndex = -1;

        public TransformationState(Transformation transformation) {
            super(transformation);
        }

        @Override
        public void invalidate() {
            super.invalidate();
            this.minIndex = -1;
        }

        /**
         * Computes the minimum-masked copy of the input vector and remembers the arg-min.
         * The cast assumes {@code this.input} is a {@link VectorValue}.
         */
        @Override
        public Value evaluate() {
            VectorValue inputVector = (VectorValue) this.input;
            Pair<Integer, double[]> minValue =
                    Aggregation.Singletons.minimum.getMinValue(inputVector.values);
            this.minIndex = (Integer) minValue.r;
            return new VectorValue((double[]) minValue.s, inputVector.rowOrientation);
        }

        /**
         * Builds the input gradient from {@code getForm()} of the input. Presumably that is a
         * zero vector of the same shape (TODO confirm). Only the arg-min position is filled, with
         * the matching entry of {@code topGradient}.
         */
        @Override
        public void ingestTopGradient(Value topGradient) {
            VectorValue gradient = ((VectorValue) this.input).getForm();
            if (this.minIndex >= 0) {
                gradient.set(this.minIndex, topGradient.get(this.minIndex));
            }
            this.processedGradient = gradient;
        }
    }

    /**
     * State for the multi-input mode: tracks the smallest value seen and its position, and
     * routes the top gradient only to that input.
     */
    public static class AggregationState
    extends Aggregation.State {
        /** Position (0-based, in cumulate order) of the current minimum; -1 if none yet. */
        int minIndex = -1;
        /** Counter of inputs seen; reused to replay inputs during backpropagation. */
        int currentIndex = 0;

        public AggregationState(Combination combination) {
            super(combination);
        }

        /**
         * Keeps {@code value} if it is the first input, or if the current minimum is
         * {@code greaterThan} it, and records its position.
         */
        @Override
        public void cumulate(Value value) {
            if (this.combinedInputs == null || this.combinedInputs.greaterThan(value)) {
                this.combinedInputs = value;
                this.minIndex = this.currentIndex;
            }
            ++this.currentIndex;
        }

        /** @return the minimal input seen so far, or {@code null} if nothing was cumulated */
        @Override
        public Value evaluate() {
            return this.combinedInputs;
        }

        @Override
        public void invalidate() {
            this.minIndex = -1;
            this.currentIndex = 0;
            this.combinedInputs = null;
        }

        /** @return a single-element array with the position of the minimal input */
        @Override
        public int[] getInputMask() {
            return new int[]{this.minIndex};
        }

        /** Stores the top gradient and rewinds the counter for {@link #nextInputGradient()}. */
        @Override
        public void ingestTopGradient(Value topGradient) {
            this.currentIndex = 0;
            this.processedGradient = topGradient;
        }

        /**
         * Returns the gradient for the next input in cumulate order: the stored top gradient for
         * the minimal input and {@link Value#ZERO} for every other input.
         */
        @Override
        public Value nextInputGradient() {
            if (this.minIndex == this.currentIndex++) {
                return this.processedGradient;
            }
            return Value.ZERO;
        }
    }
}

