/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions.aggregation;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Aggregation;
import cz.cvut.fel.ida.algebra.functions.Combination;
import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import java.util.List;
import java.util.logging.Logger;

public class Average
implements Aggregation {
    private static final Logger LOG = Logger.getLogger(Average.class.getName());

    @Override
    public Average replaceWithSingleton() {
        return Aggregation.Singletons.average;
    }

    @Override
    public Value evaluate(List<Value> inputs) {
        Value sum = inputs.get(0).clone();
        int len = inputs.size();
        for (int i = 1; i < len; ++i) {
            try {
                sum.incrementBy(inputs.get(i));
                continue;
            }
            catch (ArithmeticException e) {
                sum = sum.plus(inputs.get(i));
            }
        }
        sum.elementMultiplyBy((Value)new ScalarValue(1.0 / (double)inputs.size()));
        return sum;
    }

    @Override
    public Value differentiate(List<Value> inputs) {
        return new ScalarValue(1.0 / (double)inputs.size());
    }

    @Override
    public Value evaluate(Value combinedInputs) {
        double sum = 0.0;
        int count = 0;
        for (Double summedInput : combinedInputs) {
            sum += summedInput.doubleValue();
            ++count;
        }
        return new ScalarValue(sum / (double)count);
    }

    @Override
    public Value differentiate(Value combinedInputs) {
        int len = 1;
        for (int i : combinedInputs.size()) {
            len *= i;
        }
        return new ScalarValue(1 / len);
    }

    @Override
    public ActivationFcn.State getState(boolean singleInput) {
        if (singleInput) {
            return new TransformationState(Aggregation.Singletons.average);
        }
        return new AggregationState(Aggregation.Singletons.average);
    }

    public static class TransformationState
    extends Transformation.State {
        public TransformationState(Transformation transformation) {
            super(transformation);
        }
    }

    public static class AggregationState
    extends Aggregation.State {
        int count = 0;
        ScalarValue inverseCount;

        public AggregationState(Combination combination) {
            super(combination);
        }

        @Override
        public void cumulate(Value value) {
            this.combinedInputs.incrementBy(value);
        }

        @Override
        public void ingestTopGradient(Value topGradient) {
            this.processedGradient = topGradient.elementTimes((Value)this.inverseCount);
        }

        @Override
        public Value evaluate() {
            return this.combinedInputs.times((Value)this.inverseCount);
        }

        @Override
        public Value initEval(List<Value> values) {
            this.count = values.size();
            this.inverseCount = new ScalarValue(1.0 / (double)this.count);
            return super.initEval(values);
        }

        @Override
        public void invalidate() {
            this.combinedInputs.zero();
        }

        @Override
        public Value nextInputGradient() {
            return this.processedGradient;
        }
    }
}

