/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions.combination;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Combination;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.VectorValue;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * Combination that forms the "cross sum" of its inputs: for every way of picking one element
 * (as produced by iterating each input {@link Value}) from each input, the picked elements are
 * added together, and all such sums are returned as a single {@link VectorValue}.
 *
 * <p>The output is ordered with the first input varying slowest and the last input varying
 * fastest, e.g. inputs {@code [a0, a1]} and {@code [b0, b1, b2]} give
 * {@code [a0+b0, a0+b1, a0+b2, a1+b0, a1+b1, a1+b2]}. Because this order depends on the order of
 * the inputs, the combination is not permutation invariant.
 */
public class CrossSum
implements Combination {
    private static final Logger LOG = Logger.getLogger(CrossSum.class.getName());

    /** Returns the shared {@code crossSum} singleton instance. */
    @Override
    public Combination replaceWithSingleton() {
        return Combination.Singletons.crossSum;
    }

    /**
     * Computes the cross sum of {@code inputs}.
     *
     * @param inputs the values to combine; the list is not modified
     * @return a vector with one entry per combination of picked elements; for an empty input
     *     list this is the single-entry vector {@code [0.0]}
     */
    @Override
    public Value evaluate(List<Value> inputs) {
        ArrayList<Double> outputVector = new ArrayList<Double>();
        this.combinationsRecursive(outputVector, 0.0, inputs, 0);
        return new VectorValue(outputVector);
    }

    /**
     * Appends to {@code output} every sum {@code partialSum + x_depth + ... + x_last}, where each
     * {@code x_k} ranges over the elements of {@code values.get(k)}, in row-major order.
     *
     * @param output receives the finished sums
     * @param partialSum sum of the elements already picked from inputs {@code 0..depth-1}
     * @param values all inputs; read only
     * @param depth index of the input to pick from next
     */
    private void combinationsRecursive(List<Double> output, double partialSum, List<Value> values, int depth) {
        if (depth == values.size()) {
            output.add(partialSum);
            return;
        }
        for (Double next : values.get(depth)) {
            // Pass the extended sum down rather than adding and later subtracting it in place:
            // (s + x) - x is not always s in floating point, so the in-place form drifts.
            this.combinationsRecursive(output, partialSum + next, values, depth + 1);
        }
    }

    /** Always {@code true} for CrossSum. */
    @Override
    public boolean isComplex() {
        return true;
    }

    /** Always {@code false}: the output order depends on the order of the inputs. */
    @Override
    public boolean isPermutationInvariant() {
        return false;
    }

    /**
     * Creates a fresh {@link State} bound to the {@code crossSum} singleton. The
     * {@code singleInput} flag is ignored.
     */
    @Override
    public ActivationFcn.State getState(boolean singleInput) {
        return new State(Combination.Singletons.crossSum);
    }

    /**
     * Incremental evaluation state for {@link CrossSum}. Besides the inputs accumulated by the
     * superclass, it keeps {@link #mapping}, which records for every output position the element
     * index picked from each input. This lets the top gradient be routed back to the inputs.
     */
    public static class State
    extends Combination.InputArrayState {
        /**
         * Interning pool so that States with identical input shapes share one mapping array.
         * NOTE(review): this is a plain static HashMap that is never cleared and is not
         * synchronized; confirm States are never initialized concurrently.
         */
        private static final Map<Mapping, Mapping> cache = new HashMap<Mapping, Mapping>();
        /** {@code mapping[out][in]} = element index of input {@code in} summed into output {@code out}. */
        public int[][] mapping;
        /** Write cursor into {@link #mapping}, used while {@link #combinations} fills it. */
        int cross = 0;
        /** Per-input gradient accumulators, allocated lazily by {@link #initEval}. */
        List<Value> inputGradients;

        public State(Combination combination) {
            super(combination);
        }

        /**
         * Delegates to the superclass, then (re)builds {@link #mapping} for the accumulated
         * inputs. On the first call it also allocates one gradient accumulator per input.
         */
        @Override
        public Value initEval(List<Value> values) {
            Value eval = super.initEval(values);
            if (this.accumulatedInputs == null || this.accumulatedInputs.isEmpty()) {
                LOG.severe("CrossSum State not  initialized correctly");
            }
            this.initMapping(this.accumulatedInputs);
            if (this.inputGradients == null) {
                // NOTE(review): allocated only once; presumes input shapes do not change
                // between initEval calls. getForm() is presumably a zeroed value of the same shape.
                this.inputGradients = new ArrayList<Value>(this.accumulatedInputs.size());
                for (int j = 0; j < this.accumulatedInputs.size(); ++j) {
                    this.inputGradients.add(((Value)this.accumulatedInputs.get(j)).getForm());
                }
            }
            return eval;
        }

        /** Resets the superclass state and zeroes every input gradient accumulator. */
        @Override
        public void invalidate() {
            super.invalidate();
            // inputGradients is only allocated by initEval; nothing to zero before that.
            if (this.inputGradients == null) {
                return;
            }
            for (Value inputGradient : this.inputGradients) {
                inputGradient.zero();
            }
        }

        /** Evaluates the cross sum of the currently accumulated inputs. */
        @Override
        public Value evaluate() {
            return Combination.Singletons.crossSum.evaluate(this.accumulatedInputs);
        }

        /**
         * Routes the gradient of every output position back to the input elements summed into
         * it. Each output is a plain sum, so every contributing element receives the output's
         * gradient unchanged.
         */
        @Override
        public void ingestTopGradient(Value topGradient) {
            this.processedGradient = topGradient;
            for (int out = 0; out < this.mapping.length; ++out) {
                double grad = this.processedGradient.get(out);
                int[] picked = this.mapping[out];
                for (int inputIdx = 0; inputIdx < this.inputGradients.size(); ++inputIdx) {
                    this.inputGradients.get(inputIdx).increment(picked[inputIdx], grad);
                }
            }
        }

        /** Returns the gradient of the next input, advancing the inherited cursor {@code i}. */
        @Override
        public Value nextInputGradient() {
            return this.inputGradients.get(this.i++);
        }

        /**
         * Builds {@link #mapping} for {@code inputValues}: one row per output position, in the
         * same order as {@link CrossSum#evaluate}, holding the element index taken from each
         * input. If an identical mapping is already cached, the cached array is reused.
         *
         * @param inputValues the inputs whose {@code size()} dimensions define the element counts
         * @throws ArithmeticException if an element count or the number of combinations
         *     overflows an {@code int}
         */
        public void initMapping(List<Value> inputValues) {
            int crossSize = 1;
            int[] sizes = new int[inputValues.size()];
            for (int i = 0; i < inputValues.size(); ++i) {
                int elementCount = 1;
                for (int dim : inputValues.get(i).size()) {
                    elementCount = Math.multiplyExact(elementCount, dim);
                }
                sizes[i] = elementCount;
                crossSize = Math.multiplyExact(crossSize, elementCount);
            }
            this.mapping = new int[crossSize][inputValues.size()];
            // Reset the write cursor: initMapping runs on every initEval, and a stale cursor
            // would index past the end of the freshly allocated mapping.
            this.cross = 0;
            this.combinations(0, new int[sizes.length], sizes);
            Mapping wrap = new Mapping(this.mapping);
            Mapping cached = cache.get(wrap);
            if (cached == null) {
                cache.put(wrap, wrap);
            } else {
                // Identical content: share the cached array instead of keeping a duplicate.
                this.mapping = cached.mapping;
            }
        }

        /**
         * Recursively enumerates every index tuple over {@code sizes} in row-major order (first
         * input slowest) and writes each tuple as the next row of {@link #mapping}.
         *
         * @param input index of the input whose element index is chosen at this level
         * @param current scratch tuple being filled
         * @param sizes element count of each input
         */
        private void combinations(int input, int[] current, int[] sizes) {
            if (input == sizes.length) {
                System.arraycopy(current, 0, this.mapping[this.cross], 0, sizes.length);
                ++this.cross;
                return;
            }
            for (int index = 0; index < sizes[input]; ++index) {
                current[input] = index;
                this.combinations(input + 1, current, sizes);
            }
        }

        /**
         * Wrapper around an {@code int[][]} that gives the array content-based equality, so
         * mappings can be used as hash keys. The hash is computed lazily and cached, so the
         * wrapped array must not be modified after it is first hashed.
         */
        public static class Mapping {
            int[][] mapping;
            int hashcode = -1;

            public Mapping(int[][] mapping) {
                this.mapping = mapping;
            }

            @Override
            public int hashCode() {
                if (this.hashcode != -1) {
                    return this.hashcode;
                }
                this.hashcode = Arrays.deepHashCode(this.mapping);
                return this.hashcode;
            }

            @Override
            public boolean equals(Object obj) {
                if (this == obj) {
                    return true;
                }
                if (!(obj instanceof Mapping)) {
                    return false;
                }
                return Arrays.deepEquals(this.mapping, ((Mapping) obj).mapping);
            }
        }
    }
}

