/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.components.neurons.states;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Combination;
import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.StateVisiting;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.networks.ParentsTransfer;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Backproper;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Dropouter;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Evaluator;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Invalidator;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.Neurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.NeuronMapping;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.WeightedNeuronMapping;
import cz.cvut.fel.ida.setup.Settings;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;

/**
 * Groups the concrete neuron {@link State} implementations.
 *
 * <p>Structural states hold input/output neuron mappings and parent bookkeeping. Computational
 * states hold output values, accumulated gradients, dropout flags and parent counting. The class
 * holds no instance fields of its own; it only serves as a container for the nested
 * implementations.
 */
public abstract class States
implements State {
    private static final Logger LOG = Logger.getLogger(States.class.getName());

    /**
     * Structural state that records how many parents a neuron has. It also keeps the neural state
     * ({@code parentCounter}) to which that count is forwarded in {@link #accept(ParentsTransfer)}.
     */
    public static class NetworkParents
    implements State.Structure<Value>,
    State.Structure.Parents {
        int parentCount;
        State.Neural<Value> parentCounter;

        /**
         * Writes the stored parent count into the visitor and then lets the parent-counter state
         * accept the same visitor.
         *
         * @param visitor the transfer visitor; its {@code parentsCount} field is overwritten
         * @return always {@code null}
         */
        public Value accept(ParentsTransfer visitor) {
            visitor.parentsCount = this.parentCount;
            this.parentCounter.accept(visitor);
            return null;
        }

        /**
         * @param parentCounter neural state the parent count is forwarded to
         * @param parentCount   number of parents to record
         */
        public NetworkParents(State.Neural<Value> parentCounter, int parentCount) {
            this.parentCounter = parentCounter;
            this.parentCount = parentCount;
        }

        @Override
        public int getParentCount() {
            return this.parentCount;
        }

        @Override
        public void setParentCount(int parentCount) {
            this.parentCount = parentCount;
        }

        @Override
        public State.Neural<Value> getParentCounter() {
            return this.parentCounter;
        }

        /** Structural state; nothing to reset. */
        @Override
        public void invalidate() {
        }

        /**
         * Weighted-input structural state that also exposes parent bookkeeping.
         *
         * <p>This is a non-static inner class: all parent getters/setters read and write the fields
         * of the enclosing {@link NetworkParents} instance, not fields of its own.
         */
        public class WeightedInputsParents
        extends WeightedInputs
        implements State.Structure.Parents {
            public WeightedInputsParents(WeightedNeuronMapping<Neurons> inputs) {
                super(inputs);
            }

            @Override
            public WeightedNeuronMapping<Neurons> getWeightedMapping() {
                return this.inputs;
            }

            /** Structural state; nothing to reset. */
            @Override
            public void invalidate() {
            }

            @Override
            public int getParentCount() {
                return NetworkParents.this.parentCount;
            }

            @Override
            public void setParentCount(int parentCount) {
                NetworkParents.this.parentCount = parentCount;
            }

            @Override
            public State.Neural<Value> getParentCounter() {
                return NetworkParents.this.parentCounter;
            }
        }

        /**
         * Unweighted-input structural state that also exposes parent bookkeeping.
         *
         * <p>This is a non-static inner class: all parent getters/setters read and write the fields
         * of the enclosing {@link NetworkParents} instance, not fields of its own.
         */
        public class InputsParents
        extends Inputs
        implements State.Structure.Parents {
            public InputsParents(NeuronMapping<Neurons> inputs) {
                super(inputs);
            }

            @Override
            public NeuronMapping<Neurons> getInputMapping() {
                return this.inputs;
            }

            /** Structural state; nothing to reset. */
            @Override
            public void invalidate() {
            }

            @Override
            public int getParentCount() {
                return NetworkParents.this.parentCount;
            }

            @Override
            public void setParentCount(int parentCount) {
                NetworkParents.this.parentCount = parentCount;
            }

            @Override
            public State.Neural<Value> getParentCounter() {
                return NetworkParents.this.parentCounter;
            }
        }
    }

    /**
     * Structural state holding a mapping to output neurons.
     *
     * <p>NOTE(review): no constructor or setter for {@code outputs} is visible here; it is
     * presumably assigned directly from elsewhere in the package. Verify.
     */
    public static class Outputs
    implements State.Structure.OutputNeuronMap {
        NeuronMapping<Neurons> outputs;

        @Override
        public NeuronMapping<Neurons> getOutputMapping() {
            return this.outputs;
        }

        /** Structural state; nothing to reset. */
        @Override
        public void invalidate() {
        }
    }

    /** Structural state holding a weighted mapping to input neurons. */
    public static class WeightedInputs
    implements State.Structure.WeightedInputsMap {
        WeightedNeuronMapping<Neurons> inputs;

        public WeightedInputs(WeightedNeuronMapping<Neurons> inputs) {
            this.inputs = inputs;
        }

        @Override
        public WeightedNeuronMapping<Neurons> getWeightedMapping() {
            return this.inputs;
        }

        /** Structural state; nothing to reset. */
        @Override
        public void invalidate() {
        }
    }

    /** Structural state holding an (unweighted) mapping to input neurons. */
    public static class Inputs
    implements State.Structure.InputNeuronMap {
        NeuronMapping<Neurons> inputs;

        public Inputs(NeuronMapping<Neurons> inputs) {
            this.inputs = inputs;
        }

        @Override
        public NeuronMapping<Neurons> getInputMapping() {
            return this.inputs;
        }

        /** Structural state; nothing to reset. */
        @Override
        public void invalidate() {
        }
    }

    /**
     * Standard computation state extended with a per-neuron dropout decision.
     *
     * <p>{@link #setDropout(StateVisiting)} draws a random number from {@code settings.random}. It
     * marks the neuron as dropped when that number is below this store's {@code dropoutRate}, then
     * records that the decision has been made. {@link #invalidate()} clears both flags.
     */
    public static final class DropoutStore
    extends ComputationStateStandard
    implements State.Neural.Computation.HasDropout {
        public double dropoutRate;
        public boolean isDropped;
        private boolean dropoutProcessed;
        private Settings settings;

        /**
         * @param settings    source of the random generator used for the dropout decision
         * @param dropoutRate per-neuron dropout probability
         */
        public DropoutStore(Settings settings, double dropoutRate, Combination combination, Transformation transformation) {
            super(combination, transformation);
            this.settings = settings;
            this.dropoutRate = dropoutRate;
        }

        /** Creates a store whose dropout rate is taken from {@code settings.dropoutRate}. */
        public DropoutStore(Settings settings, Combination combination, Transformation transformation) {
            super(combination, transformation);
            this.settings = settings;
            this.dropoutRate = settings.dropoutRate;
        }

        @Override
        public void invalidate() {
            super.invalidate();
            this.isDropped = false;
            this.dropoutProcessed = false;
        }

        @Override
        public DropoutStore clone() {
            // Works because ComputationStateStandard.clone() preserves the runtime type.
            DropoutStore clone = (DropoutStore)super.clone();
            clone.dropoutRate = this.dropoutRate;
            clone.isDropped = this.isDropped;
            clone.dropoutProcessed = this.dropoutProcessed;
            clone.settings = this.settings;
            return clone;
        }

        /** Returns {@code true} until the dropout decision has been made for this state. */
        public boolean ready4expansion(Dropouter visitor) {
            return !this.dropoutProcessed;
        }

        @Override
        public double getDropoutRate(StateVisiting visitor) {
            return this.dropoutRate;
        }

        @Override
        public void setDropoutRate(double rate) {
            this.dropoutRate = rate;
        }

        /**
         * Decides whether this neuron is dropped and marks the decision as processed.
         *
         * <p>Samples against this store's own {@code dropoutRate}, the same value reported by
         * {@link #getDropoutRate(StateVisiting)}. It therefore honours rates given in the
         * constructor or via {@link #setDropoutRate(double)}, not only the global setting.
         */
        @Override
        public void setDropout(StateVisiting visitor) {
            this.isDropped = this.settings.random.nextDouble() < this.dropoutRate;
            this.dropoutProcessed = true;
        }

        /**
         * Parent-counting state whose dropout behaviour is delegated to the enclosing
         * {@link DropoutStore}.
         *
         * <p>This is a non-static inner class. The constructors that take {@code settings} or
         * {@code dropoutRate} write them into the enclosing {@code DropoutStore}, and the dropout
         * getters/setters also operate on that enclosing instance.
         */
        public final class ParentsDropoutStore
        extends ParentCounter
        implements State.Neural.Computation.HasDropout {
            public ParentsDropoutStore(Settings settings, Combination combination, Transformation transformation) {
                super(combination, transformation);
                DropoutStore.this.settings = settings;
            }

            public ParentsDropoutStore(Combination combination, Transformation transformation) {
                super(combination, transformation);
            }

            public ParentsDropoutStore(Settings settings, double dropoutRate, Combination combination, Transformation transformation) {
                super(combination, transformation);
                DropoutStore.this.settings = settings;
                DropoutStore.this.dropoutRate = dropoutRate;
            }

            /**
             * Creates a new instance bound to the same enclosing store and copies the parent
             * counters. Output value and gradient are not copied.
             */
            @Override
            public ParentsDropoutStore clone() {
                ParentsDropoutStore clone = new ParentsDropoutStore(DropoutStore.this.settings, DropoutStore.this.dropoutRate, this.fcnState.getCombination(), this.fcnState.getTransformation());
                clone.parentCount = this.parentCount;
                clone.checked = this.checked;
                clone.calculated = this.calculated;
                return clone;
            }

            @Override
            public double getDropoutRate(StateVisiting visitor) {
                return DropoutStore.this.getDropoutRate(visitor);
            }

            @Override
            public void setDropoutRate(double rate) {
                DropoutStore.this.dropoutRate = rate;
            }

            @Override
            public void setDropout(StateVisiting visitor) {
                DropoutStore.this.setDropout(visitor);
            }

            @Override
            public void setParents(StateVisiting visitor, int parentCount) {
                this.parentCount = parentCount;
            }
        }
    }

    /**
     * Standard computation state that also tracks parent bookkeeping.
     *
     * <ul>
     *   <li>{@code checked} is incremented by every {@link #storeGradient(Value)}. The state is
     *       ready for a {@link Backproper} once {@code checked == parentCount}.</li>
     *   <li>{@code calculated} is set by {@link #setValue(Value)}. The state is ready for an
     *       {@link Evaluator} once that has happened.</li>
     *   <li>It is always ready for an {@link Invalidator} or any other visitor.</li>
     * </ul>
     */
    public static class ParentCounter
    extends ComputationStateStandard
    implements State.Neural.Computation.HasParents {
        public int parentCount;
        public int checked = 0;
        boolean calculated;

        public ParentCounter(Combination combination, Transformation transformation, int count) {
            super(combination, transformation);
            this.parentCount = count;
        }

        public ParentCounter(Combination combination, Transformation transformation) {
            super(combination, transformation);
        }

        @Override
        public void invalidate() {
            super.invalidate();
            this.checked = 0;
            this.calculated = false;
        }

        @Override
        public ParentCounter clone() {
            // Works because ComputationStateStandard.clone() preserves the runtime type.
            ParentCounter clone = (ParentCounter)super.clone();
            clone.parentCount = this.parentCount;
            clone.checked = this.checked;
            clone.calculated = this.calculated;
            return clone;
        }

        /** Accumulates the gradient and counts one more parent as having contributed. */
        @Override
        public void storeGradient(Value gradient) {
            super.storeGradient(gradient);
            ++this.checked;
        }

        /** Dispatches to the visitor-specific readiness check; unknown visitors are always ready. */
        @Override
        public boolean ready4expansion(StateVisiting visitor) {
            if (visitor instanceof Backproper) {
                return this.ready4expansion((Backproper)visitor);
            }
            if (visitor instanceof Evaluator) {
                return this.ready4expansion((Evaluator)visitor);
            }
            if (visitor instanceof Invalidator) {
                return this.ready4expansion((Invalidator)visitor);
            }
            return true;
        }

        public boolean ready4expansion(Backproper visitor) {
            return this.checked == this.parentCount;
        }

        public boolean ready4expansion(Evaluator visitor) {
            return this.calculated;
        }

        public boolean ready4expansion(Invalidator visitor) {
            return true;
        }

        @Override
        public int getParents(StateVisiting visitor) {
            return this.parentCount;
        }

        @Override
        public int getChecked(StateVisiting visitor) {
            return this.checked;
        }

        @Override
        public void setChecked(StateVisiting visitor, int checked) {
            this.checked = checked;
        }

        @Override
        public void setParents(StateVisiting visitor, int parentCount) {
            this.parentCount = parentCount;
        }

        /** Stores the value and marks this state as calculated. */
        @Override
        public void setValue(Value value) {
            super.setValue(value);
            this.calculated = true;
        }
    }

    /**
     * Computation state wrapping a fixed value (e.g. a fact value) rather than an aggregation.
     *
     * <p>{@link #cumulateValue(Value)} simply replaces the value, and {@link #evaluate()} returns
     * it. Gradients are accumulated only when {@code isLearnable} is {@code true}. It has no
     * combination or transformation; both getters return {@code null}.
     */
    public static class SimpleValue
    implements State.Neural.Computation {
        Value outputValue;
        Value acumGradient;
        ActivationFcn.SimpleValueState fcnState;
        public boolean isLearnable = false;

        /**
         * @param factValue the value held by this state; the gradient starts as
         *                  {@code factValue.getForm()}
         */
        public SimpleValue(Value factValue) {
            this.outputValue = factValue;
            this.acumGradient = factValue.getForm();
            this.fcnState = new ActivationFcn.SimpleValueState(factValue);
        }

        /** Zeroes the accumulated gradient; the held value is kept. */
        @Override
        public void invalidate() {
            this.acumGradient.zero();
        }

        /**
         * Returns a new state over a clone of the held value. The {@code isLearnable} flag is
         * carried over so the copy keeps accumulating gradients like the original. The gradient is
         * reset to the value's form.
         */
        @Override
        public State.Neural.Computation clone() {
            SimpleValue clone = new SimpleValue(this.outputValue.clone());
            clone.isLearnable = this.isLearnable;
            return clone;
        }

        @Override
        public ActivationFcn.SimpleValueState getFcnState() {
            return this.fcnState;
        }

        /** @throws ClassCastException if {@code fcnState} is not an {@code ActivationFcn.SimpleValueState} */
        @Override
        public void setFcnState(ActivationFcn.State fcnState) {
            this.fcnState = (ActivationFcn.SimpleValueState)fcnState;
        }

        @Override
        public Value getValue() {
            return this.outputValue;
        }

        @Override
        public Value getGradient() {
            return this.acumGradient;
        }

        @Override
        public void setValue(Value value) {
            this.outputValue = value;
        }

        @Override
        public void setGradient(Value gradient) {
            this.acumGradient = gradient;
        }

        /** Replaces (does not add to) the held value. */
        @Override
        public void cumulateValue(Value value) {
            this.outputValue = value;
        }

        /** Adds the gradient to the accumulator only if this value is learnable. */
        @Override
        public void storeGradient(Value gradient) {
            if (this.isLearnable) {
                this.acumGradient.incrementBy(gradient);
            }
        }

        @Override
        public Value evaluate() {
            return this.outputValue;
        }

        /** Initializes the value from the function state and resets the gradient to its form. */
        @Override
        public Value initEval(List<Value> inputValues) {
            Value value = this.fcnState.initEval(inputValues);
            this.outputValue = value;
            this.acumGradient = value.getForm();
            return this.outputValue;
        }

        @Override
        public Combination getCombination() {
            return null;
        }

        @Override
        public Transformation getTransformation() {
            return null;
        }
    }

    /**
     * Default computation state. Incoming values are fed into an {@link ActivationFcn.State}
     * obtained from the given combination/transformation, and gradients are summed into
     * {@code acumGradient}.
     *
     * <p>{@code outputValue} and {@code acumGradient} are {@code null} until
     * {@link #initEval(List)} is called.
     */
    public static class ComputationStateStandard
    implements State.Neural.Computation, Cloneable {
        ActivationFcn.State fcnState;
        Value outputValue;
        Value acumGradient;

        public ComputationStateStandard(Combination combination, Transformation transformation) {
            this.fcnState = ActivationFcn.State.getState(combination, transformation);
        }

        /**
         * Clears the output value, zeroes the gradient if one has been allocated, and invalidates
         * the function state.
         */
        @Override
        public void invalidate() {
            this.outputValue = null;
            // acumGradient is only allocated by initEval; guard against invalidation before that.
            if (this.acumGradient != null) {
                this.acumGradient.zero();
            }
            this.fcnState.invalidate();
        }

        @Override
        public Combination getCombination() {
            return this.fcnState.getCombination();
        }

        @Override
        public Transformation getTransformation() {
            return this.fcnState.getTransformation();
        }

        /**
         * Returns a copy of this state with the same runtime class.
         *
         * <p>The copy gets a freshly created function state for the same combination and
         * transformation, plus clones of the current value and gradient (null stays null).
         * Subclasses may cast the result of {@code super.clone()} to their own type.
         */
        @Override
        public ComputationStateStandard clone() {
            ComputationStateStandard clone;
            try {
                // Object.clone() preserves the runtime subclass, unlike "new ComputationStateStandard(..)".
                clone = (ComputationStateStandard)super.clone();
            }
            catch (CloneNotSupportedException e) {
                // Cannot happen: this class implements Cloneable.
                throw new AssertionError(e);
            }
            // Do not share the function state: cumulateValue feeds values into it.
            clone.fcnState = ActivationFcn.State.getState(this.fcnState.getCombination(), this.fcnState.getTransformation());
            clone.outputValue = this.outputValue == null ? null : this.outputValue.clone();
            clone.acumGradient = this.acumGradient == null ? null : this.acumGradient.clone();
            return clone;
        }

        @Override
        public ActivationFcn.State getFcnState() {
            return this.fcnState;
        }

        @Override
        public void setFcnState(ActivationFcn.State fcnState) {
            this.fcnState = fcnState;
        }

        @Override
        public void setValue(Value value) {
            this.outputValue = value;
        }

        @Override
        public void setGradient(Value gradient) {
            this.acumGradient = gradient;
        }

        @Override
        public Value getValue() {
            return this.outputValue;
        }

        @Override
        public Value getGradient() {
            return this.acumGradient;
        }

        /**
         * Feeds an input value into the function state.
         *
         * @throws RuntimeException if {@code value} is NaN
         */
        @Override
        public void cumulateValue(Value value) {
            if (value.isNaN()) {
                throw new RuntimeException("NaN value " + value.toDetailedString() + " obtained during forward pass at neuron state " + String.valueOf(this));
            }
            this.fcnState.cumulate(value);
        }

        /**
         * Adds a gradient contribution to the accumulated gradient.
         *
         * @throws RuntimeException if {@code value} is NaN
         */
        @Override
        public void storeGradient(Value value) {
            if (value.isNaN()) {
                throw new RuntimeException("NaN gradient " + value.toDetailedString() + " obtained during backward pass at neuron state " + String.valueOf(this));
            }
            this.acumGradient.incrementBy(value);
        }

        @Override
        public Value evaluate() {
            return this.fcnState.evaluate();
        }

        /**
         * Initializes the output value from the function state and allocates a gradient of the
         * same form.
         *
         * <p>A repeated initialization is logged as a warning. A dimension mismatch with the
         * previously inferred value is logged as severe.
         *
         * @throws IllegalStateException if no function state was created
         */
        @Override
        public Value initEval(List<Value> inputValues) {
            if (this.fcnState == null) {
                LOG.severe("No fcnState created (initialized) in Neuron State.");
                throw new IllegalStateException("No fcnState created (initialized) in Neuron State.");
            }
            Value value = this.fcnState.initEval(inputValues);
            if (this.outputValue != null) {
                LOG.warning("Repeated initEval initialization of computation State.");
                if (!Arrays.equals(value.size(), this.outputValue.size())) {
                    LOG.severe("Collision with previously inferred Value dimensions!");
                }
            }
            this.outputValue = value;
            this.acumGradient = value.getForm();
            return this.outputValue;
        }
    }

    /**
     * Neural state composed of several computation states, addressed by index.
     *
     * <p>{@link #accept(StateVisiting.Computation)} delegates to {@code states[visitor.stateIndex]}.
     * NOTE(review): the {@code combination} and {@code transformation} fields are never assigned in
     * this class, so their getters return {@code null} unless they are set from elsewhere in the
     * package.
     *
     * @param <T> type of the component computation states
     */
    public static final class ComputationStateComposite<T extends State.Neural.Computation>
    implements State.Neural<Value> {
        public final T[] states;
        Combination combination;
        Transformation transformation;

        /** @param states component states; the array is stored as-is, not copied */
        public ComputationStateComposite(T[] states) {
            this.states = states;
        }

        @Override
        public State.Neural.Computation getComputationView(int index) {
            return this.states[index];
        }

        @Override
        public Combination getCombination() {
            return this.combination;
        }

        @Override
        public Transformation getTransformation() {
            return this.transformation;
        }

        /** Delegates to the component state selected by {@code visitor.stateIndex}. */
        public Value accept(StateVisiting.Computation visitor) {
            return this.states[visitor.stateIndex].accept(visitor);
        }

        /** Invalidates every component state. */
        @Override
        public void invalidate() {
            for (int i = 0; i < this.states.length; ++i) {
                this.states[i].invalidate();
            }
        }
    }
}

