/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.components.types;

import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralSets;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.BaseNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.Neurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.WeightedNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomNeurons;
import cz.cvut.fel.ida.neural.networks.structure.components.types.TopologicNetwork;
import cz.cvut.fel.ida.neural.networks.structure.metadata.NetworkMetadata;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.LinkedMapping;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.NeuronMapping;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.WeightedNeuronMapping;
import cz.cvut.fel.ida.utils.generic.Pair;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import org.jetbrains.annotations.Nullable;

/**
 * A {@link TopologicNetwork} extended with optional per-neuron input/output overrides, states
 * accumulated per neuron, and a few bookkeeping flags plus metadata.
 *
 * <p>When {@link #extraInputMapping} holds an entry for a neuron, the input accessors and the
 * input-replacement methods of this class operate on that mapping instead of on the neuron's own
 * input list. Output accessors rely exclusively on {@link #outputMapping}.
 *
 * @param <N> structure-state type parameter forwarded to {@link TopologicNetwork}
 */
public class DetailedNetwork<N extends State.Structure>
extends TopologicNetwork<N> {
    private static final Logger LOG = Logger.getLogger(DetailedNetwork.class.getName());

    /**
     * Optional per-neuron input overrides. If a neuron has an entry here, it takes precedence over
     * the neuron's own inputs in {@code getInputs}, {@code replaceInput} and {@code replaceInputWeight}.
     */
    @Nullable
    public Map<BaseNeuron, NeuronMapping<Neurons>> extraInputMapping;

    /** Optional per-neuron output mappings, read by {@code getOutputs} and {@code replaceOutput}. */
    @Nullable
    public Map<BaseNeuron, NeuronMapping<Neurons>> outputMapping;

    @Nullable
    public NeuralSets neuralSets;

    /** States collected per neuron through {@link #addState}; iteration follows insertion order. */
    public Map<Neurons, List<State.Structure>> cumulativeStates;

    /** Shared-neuron count as last passed to {@link #setSharedNeuronsCount(int)}. */
    public int sharedNeuronsCount;

    // NOTE(review): the exact meaning of these two flags is defined by whoever sets them
    // (not visible here); this class only copies them in emptyCopy.
    public boolean compressed;
    public boolean pruned;

    /** Recursion flag; may remain {@code null} if never assigned (then {@link #isRecursive()} is false). */
    Boolean recursive;

    @Nullable
    NetworkMetadata metadata;

    /**
     * Creates a network over the given neurons with an empty state accumulator.
     *
     * @param id         network identifier
     * @param allNeurons neurons passed to the {@link TopologicNetwork} constructor
     */
    public DetailedNetwork(String id, List<BaseNeuron<Neurons, State.Neural>> allNeurons) {
        super(id, allNeurons);
        this.cumulativeStates = new LinkedHashMap<Neurons, List<State.Structure>>();
    }

    /**
     * Creates a network with the given size hint and an empty state accumulator.
     *
     * @param id   network identifier
     * @param size size argument passed to the {@link TopologicNetwork} constructor
     */
    public DetailedNetwork(String id, int size) {
        super(id, size);
        this.cumulativeStates = new LinkedHashMap<Neurons, List<State.Structure>>();
    }

    /**
     * Creates a network over the given neurons and attaches the given neural sets.
     *
     * @param id         network identifier
     * @param allNeurons neurons passed to the {@link TopologicNetwork} constructor
     * @param neuralSets neural sets to store in {@link #neuralSets}
     */
    public DetailedNetwork(String id, List<BaseNeuron<Neurons, State.Neural>> allNeurons, NeuralSets neuralSets) {
        this(id, allNeurons);
        this.neuralSets = neuralSets;
    }

    /**
     * Creates a network from query neurons and attaches the given neural sets.
     *
     * @param id           network identifier
     * @param neuralSets   neural sets to store in {@link #neuralSets}
     * @param queryNeurons query neurons passed to the {@link TopologicNetwork} constructor
     */
    public DetailedNetwork(String id, NeuralSets neuralSets, List<AtomNeurons> queryNeurons) {
        super(queryNeurons, id);
        this.neuralSets = neuralSets;
        this.cumulativeStates = new LinkedHashMap<Neurons, List<State.Structure>>();
    }

    /**
     * Collects the distinct weights (including offsets) of all weighted neurons in topological order.
     *
     * @return a new list of unique weights; order is unspecified (deduplicated through a hash set)
     */
    public List<Weight> getAllWeights() {
        HashSet<Weight> allWeights = new HashSet<Weight>();
        for (BaseNeuron neuron : this.allNeuronsTopologic) {
            if (!(neuron instanceof WeightedNeuron)) continue;
            WeightedNeuron weightedNeuron = (WeightedNeuron)neuron;
            allWeights.addAll(weightedNeuron.getWeights());
            allWeights.add(weightedNeuron.getOffset());
        }
        return new ArrayList<Weight>(allWeights);
    }

    /**
     * Returns parallel input/weight iterators for a weighted neuron.
     *
     * <p>Uses the neuron's {@link #extraInputMapping} entry when present; otherwise delegates to
     * the superclass.
     */
    @Override
    public <T extends Neurons, S extends State.Neural> Pair<Iterator<T>, Iterator<Weight>> getInputs(WeightedNeuron<T, S> neuron) {
        // NOTE(review): assumes entries stored for weighted neurons are WeightedNeuronMappings;
        // any other mapping type would fail this cast.
        WeightedNeuronMapping inputMapping = this.extraInputMapping != null ? (WeightedNeuronMapping)this.extraInputMapping.get(neuron) : null;
        if (inputMapping != null) {
            Iterator iterator = inputMapping.iterator();
            Iterator<Weight> weightIterator = inputMapping.weightIterator();
            return new Pair(iterator, weightIterator);
        }
        return super.getInputs(neuron);
    }

    /**
     * Returns an iterator over a neuron's inputs, preferring its {@link #extraInputMapping} entry
     * and falling back to the neuron's own input list.
     */
    @Override
    public <T extends Neurons, S extends State.Neural> Iterator<T> getInputs(BaseNeuron<T, S> neuron) {
        NeuronMapping<Neurons> inputMapping = this.extraInputMapping != null ? this.extraInputMapping.get(neuron) : null;
        if (inputMapping != null) {
            return inputMapping.iterator();
        }
        return neuron.getInputs().iterator();
    }

    /**
     * Returns an iterator over a neuron's outputs taken from {@link #outputMapping}.
     *
     * @return the output iterator, or {@code null} when no output mapping exists for the neuron
     */
    @Override
    public <T extends Neurons, S extends State.Neural> Iterator<Neurons> getOutputs(BaseNeuron<T, S> neuron) {
        LinkedMapping mapping = this.outputMapping != null ? (LinkedMapping)this.outputMapping.get(neuron) : null;
        if (mapping != null) {
            return mapping.iterator();
        }
        // Callers may test for null here, so the original contract is kept.
        return null;
    }

    /**
     * Appends a state to the list accumulated for {@code neuron}, creating the list on first use.
     *
     * @param neuron key neuron
     * @param state  state to append
     */
    public void addState(Neurons neuron, State.Structure state) {
        // putIfAbsent returns the *previous* value (null on first insert), which caused an NPE here;
        // computeIfAbsent returns the list actually stored in the map.
        this.cumulativeStates.computeIfAbsent(neuron, key -> new LinkedList<>()).add(state);
    }

    /**
     * Reports whether this network is marked as recursive.
     *
     * @return {@code true} only if the recursion flag has been set to {@code true}; {@code false}
     *         if it is {@code false} or was never assigned
     */
    public boolean isRecursive() {
        // Avoids the NullPointerException that unboxing an unassigned Boolean would raise.
        return this.recursive != null && this.recursive;
    }

    /**
     * Stores the shared-neuron count and updates {@code hasSharedNeurons} accordingly.
     *
     * @param sharedNeuronsCount number of shared neurons; a positive value marks the network as
     *                           having shared neurons
     */
    public void setSharedNeuronsCount(int sharedNeuronsCount) {
        this.hasSharedNeurons = sharedNeuronsCount > 0;
        this.sharedNeuronsCount = sharedNeuronsCount;
    }

    /**
     * Replaces every occurrence of {@code toReplace} among the inputs of {@code parentNeuron}.
     *
     * <p>Operates on the neuron's {@link #extraInputMapping} entry when present; otherwise edits
     * the neuron's own input list in place.
     */
    public void replaceInput(BaseNeuron<Neurons, State.Neural> parentNeuron, Neurons toReplace, Neurons replaceWith) {
        NeuronMapping<Neurons> inputMapping = this.extraInputMapping != null ? this.extraInputMapping.get(parentNeuron) : null;
        if (inputMapping != null) {
            inputMapping.replace(toReplace, replaceWith);
            return;
        }
        for (int i = 0; i < parentNeuron.getInputs().size(); ++i) {
            if (parentNeuron.getInputs().get(i).equals(toReplace)) {
                parentNeuron.getInputs().set(i, replaceWith);
            }
        }
    }

    /**
     * Replaces {@code middle} with {@code parent} in the output mapping of {@code child}.
     *
     * <p>If there is no output mapping for {@code child}, logs a severe message and does nothing.
     */
    public void replaceOutput(BaseNeuron<Neurons, State.Neural> child, Neurons middle, Neurons parent) {
        NeuronMapping<Neurons> outputs;
        if (this.outputMapping == null || (outputs = this.outputMapping.get(child)) == null) {
            LOG.severe("OutputMapping requested but missing!");
            return;
        }
        outputs.replace(middle, parent);
    }

    /**
     * Sets {@code finalWeight} as the weight of every input edge of {@code parentNeuron} whose
     * input equals {@code toReplace}.
     *
     * <p>Operates on the neuron's {@link #extraInputMapping} entry when present; otherwise edits
     * the neuron's own weight list by the index of the matching input.
     */
    public <S extends State.Neural, T extends Neurons> void replaceInputWeight(WeightedNeuron<T, S> parentNeuron, T toReplace, Weight finalWeight) {
        WeightedNeuronMapping inputMapping = this.extraInputMapping != null ? (WeightedNeuronMapping)this.extraInputMapping.get(parentNeuron) : null;
        if (inputMapping != null) {
            Iterator iterator = inputMapping.iterator();
            WeightedNeuronMapping.WeightIterator weightIterator = (WeightedNeuronMapping.WeightIterator)inputMapping.weightIterator();
            while (iterator.hasNext()) {
                Neurons next = (Neurons)iterator.next();
                // Advance the weight iterator in lockstep so replace() targets this input's weight.
                weightIterator.next();
                if (next.equals(toReplace)) {
                    weightIterator.replace(finalWeight);
                }
            }
        } else {
            for (int i = 0; i < parentNeuron.getInputs().size(); ++i) {
                if (((Neurons)parentNeuron.getInputs().get(i)).equals(toReplace)) {
                    parentNeuron.getWeights().set(i, finalWeight);
                }
            }
        }
    }

    /**
     * Creates a copy of this network without neurons.
     *
     * <p>It shares (by reference, not deep copy) the neural sets, input/output mappings, neuron
     * states and metadata, and copies the {@code pruned}, {@code compressed} and
     * {@code containsInputMasking} flags. {@code recursive} and {@code sharedNeuronsCount} are not
     * copied, and {@code cumulativeStates} starts empty.
     *
     * @param id identifier of the copy
     * @return the new, empty network
     */
    public DetailedNetwork emptyCopy(String id) {
        DetailedNetwork<N> copy = new DetailedNetwork<N>(id, 0);
        copy.pruned = this.pruned;
        copy.compressed = this.compressed;
        copy.neuralSets = this.neuralSets;
        copy.outputMapping = this.outputMapping;
        copy.extraInputMapping = this.extraInputMapping;
        copy.containsInputMasking = this.containsInputMasking;
        copy.neuronStates = this.neuronStates;
        copy.metadata = this.metadata;
        return copy;
    }
}

