/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions.combination;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Combination;
import cz.cvut.fel.ida.algebra.values.Value;
import java.util.List;
import java.util.logging.Logger;

/**
 * Combination that multiplies its inputs together, in input order.
 *
 * <p>The first input is cloned and then successively combined with each following input via
 * {@link Value#times(Value)}. Input order matters (see {@link #isPermutationInvariant()}).
 */
public class Product
implements Combination {
    private static final Logger LOG = Logger.getLogger(Product.class.getName());

    /**
     * Returns the shared {@code Combination.Singletons.product} instance in place of this object.
     */
    @Override
    public Combination replaceWithSingleton() {
        return Combination.Singletons.product;
    }

    /**
     * Multiplies all inputs left to right: {@code inputs[0] * inputs[1] * ... * inputs[n-1]}.
     *
     * <p>The first input is cloned before multiplication so the caller's value is not used as the
     * accumulator. A single input yields a clone of that input.
     *
     * @param inputs the values to multiply; must be non-empty
     * @return the ordered product of the inputs
     * @throws IndexOutOfBoundsException if {@code inputs} is empty
     */
    @Override
    public Value evaluate(List<Value> inputs) {
        Value mult = inputs.get(0).clone();
        for (int i = 1; i < inputs.size(); ++i) {
            mult = mult.times(inputs.get(i));
        }
        return mult;
    }

    /**
     * Returns {@code false}: the result depends on input order.
     * NOTE(review): presumably because {@code Value.times} is a non-commutative (e.g. matrix)
     * product for some operand types; confirm.
     */
    @Override
    public boolean isPermutationInvariant() {
        return false;
    }

    /**
     * Returns {@code true}.
     * NOTE(review): the meaning of "complex" is defined by {@link Combination}; confirm there.
     */
    @Override
    public boolean isComplex() {
        return true;
    }

    /**
     * Creates a fresh {@link State} bound to the shared product singleton.
     *
     * @param singleInput ignored; the same state type is used regardless of input count
     * @return a new product state
     */
    @Override
    public ActivationFcn.State getState(boolean singleInput) {
        return new State(Combination.Singletons.product);
    }

    /**
     * Per-evaluation state for {@link Product}: accumulates inputs, evaluates their ordered product,
     * and hands out the gradient with respect to each input in turn.
     */
    public static class State
    extends Combination.InputArrayState {
        public State(Combination combination) {
            super(combination);
        }

        /** Evaluates the ordered product of the inputs accumulated so far. */
        @Override
        public Value evaluate() {
            return Combination.Singletons.product.evaluate(this.accumulatedInputs);
        }

        /**
         * Stores the gradient arriving from the output; per-input gradients are derived from it
         * lazily in {@link #nextInputGradient()}.
         */
        @Override
        public void ingestTopGradient(Value topGradient) {
            this.processedGradient = topGradient;
        }

        /**
         * Returns the gradient with respect to the next input, then advances the inherited input
         * counter {@code i}.
         */
        @Override
        public Value nextInputGradient() {
            return this.derivativeFrom(this.i++, this.processedGradient);
        }

        /**
         * Computes the gradient of the product with respect to the input at {@code index}.
         *
         * <p>Inputs before {@code index} are multiplied into {@code left} and inputs after it into
         * {@code right}, both in their original order. These are then combined with
         * {@code topGradient}. When the product has a single input (nothing on either side), the
         * output equals that input, so the top gradient is passed through unchanged (as a clone, to
         * avoid aliasing the stored gradient).
         *
         * @param index position of the input to differentiate with respect to
         * @param topGradient gradient flowing into this product's output
         * @return gradient with respect to input {@code index}
         */
        public Value derivativeFrom(int index, Value topGradient) {
            int size = this.accumulatedInputs.size();
            Value left = null;
            Value right = null;
            // 'k' rather than 'i' to avoid shadowing the inherited input counter this.i.
            for (int k = 0; k < size; ++k) {
                if (k == index) {
                    continue;
                }
                Value input = (Value)this.accumulatedInputs.get(k);
                if (k < index) {
                    left = left == null ? input : left.times(input);
                } else {
                    right = right == null ? input : right.times(input);
                }
            }
            if (left == null && right == null) {
                // Single-input product: identity mapping, gradient passes straight through.
                // Previously this fell into the branch below and dereferenced a null 'right'.
                return topGradient.clone();
            }
            if (left == null) {
                Value transposedRight = right.transposedView();
                return topGradient.times(transposedRight);
            }
            if (right == null) {
                return topGradient.transposedTimes(left);
            }
            Value times = topGradient.transposedTimes(left);
            Value kronecker = right.transposedView().kroneckerTimes(times);
            return kronecker;
        }
    }
}

