/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions.combination;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Combination;
import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.functions.combination.Product;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import java.util.List;
import java.util.logging.Logger;

/**
 * Element-wise product combination.
 *
 * <p>As a {@link Combination} over several inputs it multiplies them element-wise
 * ({@link #evaluate(List)}). As a {@link Transformation} of a single input it multiplies all
 * elements of that input into one scalar ({@link #evaluate(Value)}).
 */
public class ElementProduct extends Product implements Transformation {
    private static final Logger LOG = Logger.getLogger(ElementProduct.class.getName());

    @Override
    public boolean isComplex() {
        return true;
    }

    @Override
    public Combination replaceWithSingleton() {
        return Combination.Singletons.elementProduct;
    }

    /**
     * Multiplies all inputs element-wise, left to right.
     *
     * <p>The first input is cloned, so none of the inputs is modified.
     *
     * @param inputs the values to multiply; must be non-empty
     * @return the element-wise product of all inputs
     * @throws IndexOutOfBoundsException if {@code inputs} is empty (from {@code inputs.get(0)})
     */
    @Override
    public Value evaluate(List<Value> inputs) {
        Value mult = inputs.get(0).clone();
        for (int i = 1; i < inputs.size(); ++i) {
            Value next = inputs.get(i);
            try {
                mult.elementMultiplyBy(next);
            } catch (ArithmeticException e) {
                // The in-place multiply was rejected (presumably the operands' shapes do not
                // allow it - TODO confirm), so fall back to elementTimes, whose result replaces mult.
                mult = mult.elementTimes(next);
            }
        }
        return mult;
    }

    /**
     * Multiplies all elements of a single input into one scalar.
     *
     * @param combinedInputs the value whose elements are multiplied
     * @return a {@link ScalarValue} holding the product of all elements
     */
    @Override
    public Value evaluate(Value combinedInputs) {
        double product = ElementProduct.getProduct(combinedInputs);
        return new ScalarValue(product);
    }

    /**
     * Gradient of {@link #evaluate(Value)} with respect to each element of the input.
     *
     * @param combinedInputs the value whose elements were multiplied
     * @return the partial derivatives, one per element; see {@link #getGradient(Value, double)}
     */
    @Override
    public Value differentiate(Value combinedInputs) {
        double product = ElementProduct.getProduct(combinedInputs);
        return ElementProduct.getGradient(combinedInputs, product);
    }

    /**
     * Multiplies together all elements produced by iterating {@code combinedInputs}.
     *
     * @param combinedInputs the value to iterate
     * @return the product of all elements, or {@code 1.0} if there are none
     */
    public static double getProduct(Value combinedInputs) {
        double product = 1.0;
        for (Double element : combinedInputs) {
            product *= element.doubleValue();
        }
        return product;
    }

    /**
     * Computes, for every element position {@code i}, the partial derivative of the product of
     * all elements, i.e. the product of all elements except the {@code i}-th.
     *
     * <p>If no element is zero, the result is {@code product / element}, exactly as before.
     * Naively dividing by a zero element would produce {@code 0.0 / 0.0 = NaN}, so zeros are
     * handled explicitly:
     * <ul>
     *   <li>exactly one zero: that position gets the product of the remaining elements and
     *       every other position gets {@code 0.0};</li>
     *   <li>two or more zeros: every partial derivative is {@code 0.0}.</li>
     * </ul>
     *
     * @param combinedInputs the elements whose product was taken
     * @param product the product of all elements, as returned by {@link #getProduct(Value)}.
     *     It is only used when no element is zero.
     * @return {@code combinedInputs.getForm()} with position {@code i} set to the {@code i}-th
     *     partial derivative
     */
    public static Value getGradient(Value combinedInputs, double product) {
        // First pass: locate zeros and multiply the non-zero elements.
        int zeroCount = 0;
        int zeroIndex = -1;
        double nonZeroProduct = 1.0;
        int index = 0;
        for (Double element : combinedInputs) {
            double v = element;
            if (v == 0.0) {
                zeroCount++;
                zeroIndex = index;
            } else {
                nonZeroProduct *= v;
            }
            index++;
        }

        // Second pass: write the partial derivative for each position.
        Value form = combinedInputs.getForm();
        int i = 0;
        for (Double element : combinedInputs) {
            double partial;
            if (zeroCount == 0) {
                partial = product / element;
            } else if (zeroCount == 1 && i == zeroIndex) {
                partial = nonZeroProduct;
            } else {
                partial = 0.0;
            }
            form.set(i++, partial);
        }
        return form;
    }

    /**
     * Creates the evaluation state for this function.
     *
     * @param singleInput {@code true} to use it as a transformation of one input,
     *     {@code false} to use it as an aggregation over several inputs
     * @return a {@link TransformationState} or an {@link AggregationState} accordingly
     */
    @Override
    public ActivationFcn.State getState(boolean singleInput) {
        if (singleInput) {
            return new TransformationState(Combination.Singletons.elementProduct);
        }
        return new AggregationState(Combination.Singletons.elementProduct);
    }

    /**
     * State used when the element product acts as a transformation of a single input.
     *
     * <p>It caches the scalar product from the forward pass so that it can be reused when the
     * gradient is computed.
     */
    public static class TransformationState extends Transformation.State {
        /** Product computed by the last {@link #evaluate()}; reset to 1.0 on invalidate. */
        double product;

        public TransformationState(Transformation transformation) {
            super(transformation);
        }

        @Override
        public void invalidate() {
            super.invalidate();
            this.product = 1.0;
        }

        /** Multiplies all elements of the stored input and caches the result. */
        @Override
        public Value evaluate() {
            this.product = ElementProduct.getProduct(this.input);
            return new ScalarValue(this.product);
        }

        /**
         * Computes the per-element gradient using the cached product, scales it element-wise
         * by {@code topGradient}, and stores the result as the processed gradient.
         */
        @Override
        public void ingestTopGradient(Value topGradient) {
            Value gradient = ElementProduct.getGradient(this.input, this.product);
            gradient.elementMultiplyBy(topGradient);
            this.processedGradient = gradient;
        }
    }

    /**
     * State used when the element product aggregates several inputs element-wise.
     *
     * <p>It keeps the combined forward result so that each input's gradient can be derived
     * from it.
     */
    public static class AggregationState extends Combination.InputArrayState {
        /** Element-wise product of all accumulated inputs from the last {@link #evaluate()}. */
        Value combinedInputs;

        public AggregationState(Combination combination) {
            super(combination);
        }

        @Override
        public Value evaluate() {
            this.combinedInputs = Combination.Singletons.elementProduct.evaluate(this.accumulatedInputs);
            return this.combinedInputs;
        }

        @Override
        public void invalidate() {
            super.invalidate();
            this.combinedInputs = null;
        }

        /** Stores the incoming gradient unchanged; each input's share is computed lazily. */
        @Override
        public void ingestTopGradient(Value topGradient) {
            this.processedGradient = topGradient;
        }

        /**
         * Returns the gradient for the next accumulated input, using the inherited cursor
         * {@code i}: the combined product divided element-wise by that input, then multiplied
         * element-wise by the stored top gradient.
         *
         * <p>NOTE(review): like the scalar case, dividing by an input element that is
         * {@code 0.0} yields NaN instead of the product of the other inputs. This is not fixed
         * here because the exclusive-product computation needs {@code Value} API not visible in
         * this file.
         *
         * <p>NOTE(review): this assumes {@code elementDivideBy} returns a fresh value rather
         * than mutating {@code combinedInputs} in place. Otherwise the cached product would be
         * corrupted for the inputs that follow. Confirm.
         */
        @Override
        public Value nextInputGradient() {
            Value gradient = this.combinedInputs.elementDivideBy((Value)this.accumulatedInputs.get(this.i++));
            gradient.elementMultiplyBy(this.processedGradient);
            return gradient;
        }
    }
}

