/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions.combination;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Combination;
import cz.cvut.fel.ida.algebra.values.Value;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;

/**
 * A {@link Combination} whose computation is described by a tree of {@link FunctionGraphNode}s.
 *
 * <p>Each node applies its {@link ActivationFcn} to an ordered list of arguments. Argument
 * {@code i} of a node is the output of child {@code nodes[i]} when that child is non-null, and
 * otherwise the combination input at position {@code indices[i]}. Evaluation and
 * backpropagation are carried out by the {@link State} returned from {@link #getState(boolean)}.
 */
public class FunctionGraph
implements Combination {
    private static final Logger LOG = Logger.getLogger(FunctionGraph.class.getName());
    /** Root of the function tree; each {@link State} builds its own copy of this tree. */
    public FunctionGraphNode root = null;
    private final String name;

    /**
     * @param name identifier returned by {@link #getName()}
     * @param root root node of the function tree
     */
    public FunctionGraph(String name, FunctionGraphNode root) {
        this.root = root;
        this.name = name;
    }

    @Override
    public String getName() {
        return this.name;
    }

    /**
     * Always returns {@code null}.
     * NOTE(review): presumably signals that no singleton replacement exists for a graph
     * combination — confirm against the {@link Combination} contract.
     */
    @Override
    public Combination replaceWithSingleton() {
        return null;
    }

    /**
     * Stateless evaluation is not implemented and always returns {@code null}.
     * The graph is evaluated through the {@link State} obtained from {@link #getState(boolean)}.
     */
    @Override
    public Value evaluate(List<Value> inputs) {
        return null;
    }

    /**
     * Returns {@code false}: graph nodes address inputs by position (see
     * {@link FunctionGraphNode#indices}), so the order of inputs matters.
     */
    @Override
    public boolean isPermutationInvariant() {
        return false;
    }

    @Override
    public boolean isComplex() {
        return true;
    }

    /**
     * Creates a fresh {@link State} holding its own copy of the node tree.
     *
     * @param singleInput ignored by this implementation
     * @return a new evaluation state for this graph
     */
    @Override
    public ActivationFcn.State getState(boolean singleInput) {
        return new State(this);
    }

    /**
     * One function application in the graph.
     */
    public static class FunctionGraphNode {
        /**
         * For each argument slot {@code i} whose {@code nodes[i]} is {@code null}, the position
         * in the combination's input list that supplies that argument.
         */
        public int[] indices;
        /** Child node per argument slot; a {@code null} entry means the slot reads a raw input. */
        public FunctionGraphNode[] nodes;
        /** Function applied to this node's arguments. */
        public ActivationFcn function = null;
        /** Evaluation state of {@link #function}; set on the per-state copies built by {@link State}. */
        public ActivationFcn.State state = null;

        /**
         * @param function function applied at this node
         * @param nodes    child node per argument slot ({@code null} for raw-input slots)
         * @param indices  input position per raw-input slot
         */
        public FunctionGraphNode(ActivationFcn function, FunctionGraphNode[] nodes, int[] indices) {
            this.function = function;
            this.nodes = nodes;
            this.indices = indices;
        }
    }

    /**
     * Evaluation/backpropagation state for a {@link FunctionGraph}.
     *
     * <p>On construction the graph's node tree is copied so that every node carries its own
     * {@link ActivationFcn.State}. The tree held by the {@link FunctionGraph} itself is left
     * untouched.
     */
    public static class State
    extends Combination.InputArrayState {
        /** Gradient per combination input position; allocated by {@link #initEval(List)}. */
        private List<Value> processedGradients = null;
        /** Root of this state's private copy of the node tree. */
        private final FunctionGraphNode root;

        public State(FunctionGraph combination) {
            super(combination);
            this.root = this.initStates(combination.root);
        }

        /**
         * Recursively copies {@code node} and its subtree, attaching a fresh function state
         * to every copy. Null children (raw-input slots) stay null.
         */
        private FunctionGraphNode initStates(FunctionGraphNode node) {
            FunctionGraphNode[] children = new FunctionGraphNode[node.nodes.length];
            for (int k = 0; k < node.nodes.length; ++k) {
                if (node.nodes[k] == null) continue;
                children[k] = this.initStates(node.nodes[k]);
            }
            FunctionGraphNode copy = new FunctionGraphNode(node.function, children, node.indices);
            copy.state = copy.function.getState(node.nodes.length == 1);
            return copy;
        }

        @Override
        public Value evaluate() {
            return this.evaluate(this.root);
        }

        /**
         * Feeds every argument of {@code node} (a child's output or an accumulated input)
         * into the node's state, in slot order, and returns the node's evaluated value.
         */
        private Value evaluate(FunctionGraphNode node) {
            for (int k = 0; k < node.nodes.length; ++k) {
                if (node.nodes[k] == null) {
                    node.state.cumulate((Value)this.accumulatedInputs.get(node.indices[k]));
                    continue;
                }
                node.state.cumulate(this.evaluate(node.nodes[k]));
            }
            return node.state.evaluate();
        }

        @Override
        public void cumulate(Value value) {
            super.cumulate(value);
        }

        /** Stores the gradient arriving at the graph's output; consumed by {@link #nextInputGradient()}. */
        @Override
        public void ingestTopGradient(Value topGradient) {
            this.processedGradient = topGradient;
        }

        /**
         * Returns the gradient for the next combination input.
         *
         * <p>On the first call (counter at 0) the stored top gradient is propagated through the
         * whole tree. Gradients reaching the same input position from several slots are summed.
         * An input position that no slot reads yields {@code null}.
         */
        @Override
        public Value nextInputGradient() {
            if (this.i == 0) {
                this.nextInputGradient(this.root, this.processedGradient);
            }
            return this.processedGradients.get(this.i++);
        }

        /**
         * Pushes {@code grad} into {@code node}, then distributes the node's per-slot input
         * gradients either to child subtrees or into {@link #processedGradients}.
         */
        private void nextInputGradient(FunctionGraphNode node, Value grad) {
            node.state.ingestTopGradient(grad);
            for (int k = 0; k < node.nodes.length; ++k) {
                if (node.nodes[k] != null) {
                    this.nextInputGradient(node.nodes[k], node.state.nextInputGradient());
                    continue;
                }
                int position = node.indices[k];
                Value gradAtIndex = this.processedGradients.get(position);
                if (gradAtIndex == null) {
                    this.processedGradients.set(position, node.state.nextInputGradient());
                    continue;
                }
                this.processedGradients.set(position, gradAtIndex.plus(node.state.nextInputGradient()));
            }
        }

        /**
         * Resets this state and every node state in the tree, and clears collected input
         * gradients. This is safe to call before the first {@link #initEval(List)}.
         */
        @Override
        public void invalidate() {
            super.invalidate();
            // processedGradients is only allocated in initEval; without this guard an
            // invalidate() issued before the first initEval would throw NPE.
            if (this.processedGradients != null) {
                for (int idx = 0; idx < this.processedGradients.size(); ++idx) {
                    this.processedGradients.set(idx, null);
                }
            }
            this.invalidate(this.root);
        }

        private void invalidate(FunctionGraphNode node) {
            node.state.invalidate();
            for (int k = 0; k < node.nodes.length; ++k) {
                if (node.nodes[k] == null) continue;
                this.invalidate(node.nodes[k]);
            }
        }

        /**
         * Initializes the state from {@code values} and evaluates the graph once.
         *
         * @param values combination inputs, addressed by position via node indices
         * @return the value produced by the root node
         */
        @Override
        public Value initEval(List<Value> values) {
            this.processedGradients = new ArrayList<Value>(values.size());
            for (int idx = 0; idx < values.size(); ++idx) {
                this.processedGradients.add(null);
            }
            // An ArrayList argument is reused directly, as before. Any other List is copied,
            // because a blind cast throws ClassCastException for e.g. List.of/Arrays.asList.
            ArrayList<Value> inputs = values instanceof ArrayList
                    ? (ArrayList<Value>) values
                    : new ArrayList<Value>(values);
            inputs.trimToSize();
            this.accumulatedInputs = inputs;
            this.i = 0;
            return this.initEval(this.root, values);
        }

        /**
         * Recursively initializes {@code node}'s state with its arguments, which are child
         * results or entries of {@code values}. Iterates over {@code nodes.length}, like every
         * other traversal in this class.
         */
        private Value initEval(FunctionGraphNode node, List<Value> values) {
            ArrayList<Value> nextValues = new ArrayList<Value>(node.nodes.length);
            for (int k = 0; k < node.nodes.length; ++k) {
                if (node.nodes[k] == null) {
                    nextValues.add(values.get(node.indices[k]));
                    continue;
                }
                nextValues.add(this.initEval(node.nodes[k], values));
            }
            return node.state.initEval(nextValues);
        }
    }
}

