/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.neurons;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.neurons.NeuronVisitor;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.StateVisiting;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.weights.WeightUpdater;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralNetwork;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.BaseNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.Neurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.WeightedNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.utils.generic.Pair;
import java.util.Iterator;

/**
 * Backward-pass ("top-down") neuron visitor.
 *
 * <p>For every visited neuron it:
 * <ol>
 *   <li>reads the gradient stored in the neuron's computation view;</li>
 *   <li>hands it to the activation function state;</li>
 *   <li>pulls one input gradient per input from that state and propagates it.</li>
 * </ol>
 * For weighted neurons, the gradients of the offset and of every input weight are also
 * forwarded to the configured {@link WeightUpdater}.
 *
 * <p>NOTE(review): the order in which {@code nextInputGradient()} is consumed (offset first,
 * then inputs in iteration order) presumably has to match the order used by the forward pass
 * when the activation state was filled — confirm against the forward visitor before reordering.
 */
public class Down
extends NeuronVisitor.Weighted {

    /**
     * Creates a backward-pass visitor.
     *
     * @param network       network whose neuron inputs (and input weights) are traversed
     * @param topDown       state visitor whose {@code stateIndex} selects the computation view used
     * @param weightUpdater receives the gradient computed for every visited weight, including offsets
     */
    public Down(NeuralNetwork<State.Structure> network, StateVisiting.Computation topDown, WeightUpdater weightUpdater) {
        super(network, topDown, weightUpdater);
    }

    /**
     * Propagates the gradient of an unweighted neuron to each of its inputs.
     *
     * @param neuron neuron whose accumulated gradient is pushed down to its inputs
     */
    @Override
    public <T extends Neurons, S extends State.Neural> void visit(BaseNeuron<T, S> neuron) {
        State.Neural.Computation state = neuron.getComputationView(this.stateVisitor.stateIndex);
        ActivationFcn.State fcnState = ingestStateGradient(state);
        Iterator<T> inputs = this.network.getInputs(neuron);
        while (inputs.hasNext()) {
            T input = inputs.next();
            Value inputGradient = fcnState.nextInputGradient();
            input.getComputationView(this.stateVisitor.stateIndex).storeGradient(inputGradient);
        }
    }

    /**
     * Propagates the gradient of a weighted neuron to its inputs and reports weight gradients.
     *
     * <p>For each input {@code x} with weight {@code W} and per-input gradient {@code g}:
     * <ul>
     *   <li>the weight updater receives {@code g.times(x.transposedView())};</li>
     *   <li>the input receives {@code W.transposedTimes(g)}.</li>
     * </ul>
     * When the offset is present, it consumes the first input gradient.
     *
     * @param neuron weighted neuron whose accumulated gradient is pushed down
     */
    @Override
    public <T extends Neurons, S extends State.Neural> void visit(WeightedNeuron<T, S> neuron) {
        State.Neural.Computation state = neuron.getComputationView(this.stateVisitor.stateIndex);
        ActivationFcn.State fcnState = ingestStateGradient(state);
        Pair<Iterator<T>, Iterator<Weight>> inputs = this.network.getInputs(neuron);
        // NOTE(review): identity comparison on purpose? Presumably Value.ZERO is a shared sentinel
        // meaning "no offset"; confirm before replacing with equals().
        if (neuron.offset.value != Value.ZERO) {
            Value offsetGradient = fcnState.nextInputGradient();
            this.weightUpdater.visit(neuron.offset, offsetGradient);
        }
        Iterator<T> inputNeurons = inputs.r;
        Iterator<Weight> inputWeights = inputs.s;
        // Assumes the weight iterator yields exactly one weight per input neuron.
        while (inputNeurons.hasNext()) {
            T input = inputNeurons.next();
            Weight weight = inputWeights.next();
            State.Neural.Computation inputComputationView = input.getComputationView(this.stateVisitor.stateIndex);
            Value transpInputValue = inputComputationView.getValue().transposedView();
            Value inputGradient = fcnState.nextInputGradient();
            this.weightUpdater.visit(weight, inputGradient.times(transpInputValue));
            inputComputationView.storeGradient(weight.value.transposedTimes(inputGradient));
        }
    }

    /**
     * Feeds the gradient stored in {@code state} into its activation-function state.
     *
     * @param state computation view of the neuron being visited
     * @return the activation-function state, ready to hand out per-input gradients
     */
    private static ActivationFcn.State ingestStateGradient(State.Neural.Computation state) {
        // Same call order as before: the gradient is read before the fcn state is fetched.
        Value topGradient = state.getGradient();
        ActivationFcn.State fcnState = state.getFcnState();
        fcnState.ingestTopGradient(topGradient);
        return fcnState;
    }
}

