/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.building.builders;

import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.neurons.StateInitializer;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Evaluator;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.Hyperparameters.DropoutRateStrategy;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.BaseNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.Neurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.States;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.StatesCache;
import cz.cvut.fel.ida.neural.networks.structure.components.types.DetailedNetwork;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.NeuronMapping;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.WeightedNeuronMapping;
import cz.cvut.fel.ida.setup.Settings;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.logging.Logger;

/**
 * Builds and finalises the per-neuron states of a {@link DetailedNetwork}.
 *
 * <p>Each public method is a separate pass over the network:
 * <ul>
 *   <li>initialising computation values;</li>
 *   <li>assigning layers and dropout rates;</li>
 *   <li>propagating the "shared" flag (optionally turning states into composite parallel states);</li>
 *   <li>attaching extra input / parent-count structures;</li>
 *   <li>merging the accumulated structure states into a {@link StatesCache}.</li>
 * </ul>
 */
public class StatesBuilder {
    private static final Logger LOG = Logger.getLogger(StatesBuilder.class.getName());
    private final Settings settings;

    /**
     * @param settings configuration read by the passes:
     *                 {@code parallelTraining}, {@code minibatchSize} and {@code iterationMode}.
     *                 It is also handed to {@link DropoutRateStrategy} and {@link StatesCache#getCache}.
     */
    public StatesBuilder(Settings settings) {
        this.settings = settings;
    }

    /**
     * Initialises, in topological order, every neuron whose computation view 0 has no value yet.
     * Each such neuron is visited with a {@link StateInitializer} backed by an
     * {@link Evaluator} for state index 0.
     *
     * <p>Exceptions thrown while visiting a neuron are logged and rethrown unchanged.
     */
    public void initializeStates(DetailedNetwork<State.Structure> detailedNetwork) {
        final int stateIndex = 0;
        StateInitializer stateInitializer = new StateInitializer(detailedNetwork, new Evaluator(stateIndex));
        for (int i = 0; i < detailedNetwork.allNeuronsTopologic.size(); i++) {
            BaseNeuron neuron = (BaseNeuron) detailedNetwork.allNeuronsTopologic.get(i);
            if (neuron.getComputationView(stateIndex).getValue() != null) {
                continue; // already has a value - nothing to initialise
            }
            try {
                neuron.visit(stateInitializer);
            } catch (ArithmeticException ex) {
                LOG.severe("Arithmetic exception at neuron: " + neuron.toString());
                throw ex;
            } catch (Exception ex) {
                LOG.severe("Exception at neuron state building (StatesBuilder): " + ex.toString());
                throw ex;
            }
        }
    }

    /**
     * Walks all neurons from the last to the first in topological order.
     * Each neuron still on layer 0 is moved to layer 1 and gets its dropout rate from
     * {@link DropoutRateStrategy}. Every input of the neuron is then raised to at least
     * the neuron's layer + 1.
     */
    public void setupDropoutStates(DetailedNetwork<State.Structure> detailedNetwork) {
        DropoutRateStrategy dropoutRateStrategy = new DropoutRateStrategy(this.settings);
        // i >= 0 so that the first neuron in topological order is processed as well.
        for (int i = detailedNetwork.allNeuronsTopologic.size() - 1; i >= 0; i--) {
            BaseNeuron neuron = (BaseNeuron) detailedNetwork.allNeuronsTopologic.get(i);
            if (neuron.layer == 0) {
                neuron.layer = 1;
            }
            dropoutRateStrategy.setDropout(neuron);
            int inputLayer = neuron.layer + 1;
            Iterator<?> inputs = detailedNetwork.getInputs(neuron);
            while (inputs.hasNext()) {
                Neurons input = (Neurons) inputs.next();
                if (input.getLayer() < inputLayer) {
                    input.setLayer(inputLayer);
                }
            }
        }
    }

    /**
     * Replaces the neuron's state with a composite state built by
     * {@link State#createCompositeState} from computation view 0 and
     * {@code settings.minibatchSize}.
     * This happens only when parallel training is enabled and the raw state is not already a
     * {@link States.ComputationStateComposite}.
     *
     * @return {@code true} iff the neuron's state was replaced
     */
    public boolean makeParallel(BaseNeuron neuron) {
        if (!this.settings.parallelTraining || neuron.getRawState() instanceof States.ComputationStateComposite) {
            return false;
        }
        State.Neural.Computation state = neuron.getComputationView(0);
        States.ComputationStateComposite<State.Neural.Computation> compositeState =
                State.createCompositeState(state, this.settings.minibatchSize);
        neuron.setState(compositeState);
        return true;
    }

    /**
     * Walks all neurons from the last to the first in topological order. For every neuron with
     * {@code isShared} set, it counts the neuron, calls {@link #makeParallel} on it and flags
     * all of its inputs as shared.
     * Presumably inputs precede their consumers in {@code allNeuronsTopologic}; if so, inputs
     * flagged here are reached later in the same pass and the flag spreads to all transitive
     * inputs. TODO confirm the ordering guarantee.
     *
     * @return the number of shared neurons encountered
     */
    public int makeSharedStatesRecursively(DetailedNetwork<State.Structure> detailedNetwork) {
        int sharedCount = 0;
        // i >= 0 so that the first neuron in topological order is processed as well.
        for (int i = detailedNetwork.allNeuronsTopologic.size() - 1; i >= 0; i--) {
            BaseNeuron neuron = (BaseNeuron) detailedNetwork.allNeuronsTopologic.get(i);
            if (!neuron.isShared) {
                continue;
            }
            sharedCount++;
            this.makeParallel(neuron);
            Iterator<?> inputs = detailedNetwork.getInputs(neuron);
            while (inputs.hasNext()) {
                ((Neurons) inputs.next()).setShared(true);
            }
        }
        return sharedCount;
    }

    /**
     * For each entry of {@code extraInputMapping}, registers an input structure state on the
     * neuron: {@link States.Inputs} for a {@link NeuronMapping}, or
     * {@link States.WeightedInputs} for a {@link WeightedNeuronMapping}.
     * Mappings of any other type are ignored.
     */
    public void addLinkedInputsToNetworkStates(DetailedNetwork<State.Structure> neuralNetwork) {
        // NOTE(review): if WeightedNeuronMapping were ever a subtype of NeuronMapping, the
        // first branch would shadow the second - confirm the two types are unrelated.
        neuralNetwork.extraInputMapping.forEach((neuron, inputs) -> {
            if (inputs instanceof NeuronMapping) {
                States.Inputs inputsState = new States.Inputs((NeuronMapping<Neurons>) inputs);
                neuralNetwork.addState((Neurons) neuron, inputsState);
            } else if (inputs instanceof WeightedNeuronMapping) {
                States.WeightedInputs weightedInputsState = new States.WeightedInputs((WeightedNeuronMapping) inputs);
                neuralNetwork.addState((Neurons) neuron, weightedInputsState);
            }
        });
    }

    /**
     * Reconciles each neuron's parent count with the size of the last list of its output mapping.
     * This applies only to neurons whose computation view 0 implements {@code HasParents}:
     * <ul>
     *   <li>if no parent count is set (0), the output count is stored;</li>
     *   <li>if a different non-zero count is already set, the neuron is marked shared (and
     *       {@code sharedAfterCreation}), made parallel when enabled, and given an extra
     *       {@link States.NetworkParents} structure state carrying the output count.</li>
     * </ul>
     */
    public void setupParentStateNumbers(DetailedNetwork<State.Structure> network) {
        Map<BaseNeuron, NeuronMapping<Neurons>> neuronOutputs = network.outputMapping;
        neuronOutputs.forEach((neuron, outputs) -> {
            State.Neural.Computation state = neuron.getComputationView(0);
            if (!(state instanceof State.Neural.Computation.HasParents)) {
                return;
            }
            State.Neural.Computation.HasParents parentsState = (State.Neural.Computation.HasParents) ((Object) state);
            int parents = parentsState.getParents(null);
            int outputCount = outputs.getLastList().size();
            if (parents == 0) {
                parentsState.setParents(null, outputCount);
            } else if (parents != outputCount) {
                neuron.setShared(true);
                neuron.sharedAfterCreation = true;
                this.makeParallel(neuron); // no-op unless parallel training is enabled
                // Read after makeParallel so that a freshly created composite state is wrapped.
                Object rawState = neuron.getRawState();
                States.NetworkParents networkParents = new States.NetworkParents((State.Neural<Value>) rawState, outputCount);
                network.addState((Neurons) neuron, networkParents);
            }
        });
    }

    /**
     * Merges the structure states accumulated for one neuron into a single structure state.
     *
     * @return {@code null} for an empty list; the sole element for a singleton list.
     *         For a parents structure combined with exactly one of input map / weighted-inputs
     *         map (and no output map), the corresponding combined {@code NetworkParents} state.
     *         For any other combination, {@code null} (logged as a warning, since the
     *         neuron's accumulated states are then dropped).
     */
    private State.Structure createFinalState(List<State.Structure> structures) {
        if (structures.isEmpty()) {
            return null;
        }
        if (structures.size() == 1) {
            return structures.get(0);
        }
        State.Structure.Parents hasParents = null;
        State.Structure.InputNeuronMap inputNeuronMap = null;
        State.Structure.WeightedInputsMap weightedInputsMap = null;
        boolean outputs = false;
        for (State.Structure structure : structures) {
            if (structure instanceof State.Structure.Parents) {
                hasParents = (State.Structure.Parents) ((Object) structure);
            } else if (structure instanceof State.Structure.InputNeuronMap) {
                inputNeuronMap = (State.Structure.InputNeuronMap) structure;
            } else if (structure instanceof State.Structure.WeightedInputsMap) {
                weightedInputsMap = (State.Structure.WeightedInputsMap) structure;
            } else if (structure instanceof State.Structure.OutputNeuronMap) {
                outputs = true;
            }
        }
        if (hasParents != null && !outputs) {
            if (inputNeuronMap != null && weightedInputsMap == null) {
                return new States.NetworkParents.InputsParents(
                        new States.NetworkParents(hasParents.getParentCounter(), hasParents.getParentCount()),
                        inputNeuronMap.getInputMapping());
            }
            if (inputNeuronMap == null && weightedInputsMap != null) {
                return new States.NetworkParents.WeightedInputsParents(
                        new States.NetworkParents(hasParents.getParentCounter(), hasParents.getParentCount()),
                        weightedInputsMap.getWeightedMapping());
            }
        }
        LOG.warning("Unsupported combination of " + structures.size()
                + " structure states, no final state created: " + structures);
        return null;
    }

    /**
     * Merges each neuron's accumulated structure states via {@link #createFinalState} and
     * stores the result as {@code neuralNetwork.neuronStates}. The slots are built per
     * iteration mode:
     * <ul>
     *   <li>TOPOLOGIC: slots follow the position in {@code allNeuronsTopologic};</li>
     *   <li>otherwise: slots follow {@code neuron.getIndex()}.</li>
     * </ul>
     * The network is returned unchanged if there are no accumulated states.
     *
     * @return the same network instance
     */
    public DetailedNetwork<State.Structure> setupFinalStatesCache(DetailedNetwork<State.Structure> neuralNetwork) {
        Map<Neurons, List<State.Structure>> cumulativeStates = neuralNetwork.cumulativeStates;
        if (cumulativeStates.isEmpty()) {
            return neuralNetwork;
        }
        State.Structure[] structureStates = this.settings.iterationMode == Settings.IterationMode.TOPOLOGIC
                ? this.finalStatesByTopologicPosition(neuralNetwork, cumulativeStates)
                : this.finalStatesByNeuronIndex(cumulativeStates);
        neuralNetwork.neuronStates = StatesCache.getCache(this.settings, structureStates);
        return neuralNetwork;
    }

    /** One slot per neuron in {@code allNeuronsTopologic}; slots of neurons without accumulated states stay null. */
    private State.Structure[] finalStatesByTopologicPosition(DetailedNetwork<State.Structure> neuralNetwork,
                                                             Map<Neurons, List<State.Structure>> cumulativeStates) {
        int neuronCount = neuralNetwork.allNeuronsTopologic.size();
        State.Structure[] structureStates = new State.Structure[neuronCount];
        for (int i = 0; i < neuronCount; i++) {
            BaseNeuron neuron = (BaseNeuron) neuralNetwork.allNeuronsTopologic.get(i);
            List<State.Structure> structures = cumulativeStates.get(neuron);
            if (structures != null) {
                structureStates[i] = this.createFinalState(structures);
            }
        }
        return structureStates;
    }

    /** Final states placed at each neuron's {@code getIndex()}; unused slots stay null. */
    private State.Structure[] finalStatesByNeuronIndex(Map<Neurons, List<State.Structure>> cumulativeStates) {
        TreeMap<Integer, State.Structure> finalStates = new TreeMap<>();
        cumulativeStates.forEach((neuron, structures) ->
                finalStates.put(neuron.getIndex(), this.createFinalState(structures)));
        // Neuron indices need not be dense among the neurons with accumulated states, so the
        // array must reach the largest index (it was previously sized by the entry count only,
        // which overflowed for sparse indices).
        int length = Math.max(cumulativeStates.size(), finalStates.lastKey() + 1);
        State.Structure[] structureStates = new State.Structure[length];
        finalStates.forEach((index, finalState) -> structureStates[index] = finalState);
        return structureStates;
    }
}

