/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.transforming;

import cz.cvut.fel.ida.algebra.functions.transformation.joint.Identity;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralNetwork;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.BaseNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.Neurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.QueryNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.WeightedNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AggregationNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.types.DetailedNetwork;
import cz.cvut.fel.ida.neural.networks.structure.transforming.NetworkReducing;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.generic.Timing;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * Network reducer that bypasses "linear chain" neurons, i.e. neurons that have exactly one
 * input: every parent of such a neuron is rewired to read directly from the neuron's single
 * child, and the network is then rebuilt from the query outputs via
 * {@link NetworkReducing#supervisedNetReconstruction}.
 *
 * <p>Which neurons are eligible is controlled by {@code Settings.pruneEvenWeightedNeurons} and
 * {@code Settings.pruneOnlyIdentities}; neurons whose transformation changes shape are never
 * bypassed.
 */
public class LinearChainReducer
implements NetworkReducing {
    private static final Logger LOG = Logger.getLogger(LinearChainReducer.class.getName());
    private final transient Settings settings;
    /** Sum of the network sizes measured before reduction, accumulated over all reduce calls. */
    int allNeurons = 0;
    /**
     * Sum of the network sizes measured after reduction, accumulated over all reduce calls.
     * NOTE(review): despite the name this is the remaining size, not the number of removed
     * neurons - confirm that readers of this field expect that.
     */
    int prunedNeurons = 0;
    /** Stopwatch wrapped around every {@link #reduce(DetailedNetwork, List)} call. */
    Timing timing;

    /**
     * @param settings configuration providing the pruning flags read by this reducer
     */
    public LinearChainReducer(Settings settings) {
        this.settings = settings;
        this.timing = new Timing();
    }

    /**
     * Walks the topologically ordered neurons from the last index down to the first, bypasses
     * every eligible single-input neuron, and reconstructs the network from the given outputs.
     *
     * @param inet    network to reduce; it is modified in place
     * @param outputs query neurons whose {@code neuron} fields seed the reconstruction
     * @return the same {@code inet} instance after reduction
     */
    @Override
    public NeuralNetwork reduce(DetailedNetwork<State.Structure> inet, List<QueryNeuron> outputs) {
        this.timing.tic();
        int prunings = 0;
        int sizeBefore = inet.allNeuronsTopologic.size();
        for (int i = sizeBefore - 1; i >= 0; --i) {
            // The element type of allNeuronsTopologic is not visible from this file, hence the cast.
            BaseNeuron neuron = (BaseNeuron)inet.allNeuronsTopologic.get(i);
            if (this.isPruningCandidate(neuron) && this.prune(inet, neuron)) {
                ++prunings;
            }
        }
        List<Neurons> collect = outputs.stream().map(s -> s.neuron).collect(Collectors.toList());
        NetworkReducing.supervisedNetReconstruction(inet, collect);
        int sizeAfter = inet.allNeuronsTopologic.size();
        this.allNeurons += sizeBefore;
        this.prunedNeurons += sizeAfter;
        LOG.info("LinearChainPruning reduced neurons from " + sizeBefore + " down to " + sizeAfter + " with prunings: " + prunings);
        this.timing.toc();
        return inet;
    }

    /**
     * Single-output convenience overload delegating to {@link #reduce(DetailedNetwork, List)}.
     *
     * @param inet         network to reduce; it is modified in place
     * @param outputNeuron the only query neuron of the network
     * @return the same {@code inet} instance after reduction
     */
    @Override
    public NeuralNetwork reduce(DetailedNetwork<State.Structure> inet, QueryNeuron outputNeuron) {
        return this.reduce(inet, Arrays.asList(outputNeuron));
    }

    /** Finishes the internal {@link Timing} instance. */
    @Override
    public void finish() {
        this.timing.finish();
    }

    /**
     * Decides whether the settings and the neuron's transformation allow bypassing the neuron.
     * The checks run in the same order, with the same short-circuiting, as the original
     * single-expression condition.
     *
     * @param neuron neuron under consideration
     * @return {@code true} if a pruning attempt may be made for {@code neuron}
     */
    private boolean isPruningCandidate(BaseNeuron<?, ?> neuron) {
        // Weighted neurons are skipped unless the settings explicitly allow pruning them.
        if (!this.settings.pruneEvenWeightedNeurons && neuron instanceof WeightedNeuron) {
            return false;
        }
        // In identity-only mode, a non-aggregation neuron must carry an Identity transformation.
        if (this.settings.pruneOnlyIdentities
                && !(neuron instanceof AggregationNeuron)
                && !(neuron.getTransformation() instanceof Identity)) {
            return false;
        }
        // A shape-changing transformation cannot be bypassed.
        return neuron.getTransformation() == null || !neuron.getTransformation().changesShape();
    }

    /**
     * Bypasses {@code middle} if it has exactly one input: each parent of {@code middle} gets
     * the single child as its input instead, and the output links are updated accordingly.
     *
     * @param inet   network whose mappings are rewired in place
     * @param middle neuron to bypass
     * @return {@code false} if {@code middle} does not have exactly one input or if its output
     *         iterator is {@code null}; {@code true} otherwise (including the case of an empty,
     *         non-null output iterator, where nothing is rewired)
     */
    private boolean prune(DetailedNetwork<State.Structure> inet, BaseNeuron<Neurons, State.Neural> middle) {
        List<Neurons> middleInputNeurons = new ArrayList<>();
        Iterator<Neurons> inputs = inet.getInputs(middle);
        inputs.forEachRemaining(middleInputNeurons::add);
        if (middleInputNeurons.size() != 1) {
            return false;
        }
        BaseNeuron child = (BaseNeuron)middleInputNeurons.get(0);
        Iterator<Neurons> parents = inet.getOutputs(middle);
        if (parents == null) {
            LOG.fine("Neuron has only 1 input but has no output, thus not pruning it (i.e., an output neuron).");
            return false;
        }
        // Snapshot the parents first; the loop below edits the output mapping of middle.
        List<Neurons> middleOutputNeurons = new ArrayList<>();
        parents.forEachRemaining(middleOutputNeurons::add);
        boolean severalParents = middleOutputNeurons.size() > 1;
        for (Neurons parent : middleOutputNeurons) {
            inet.replaceInput((BaseNeuron)parent, middle, child);
            if (severalParents) {
                // Move the parent link from middle to child one parent at a time.
                inet.outputMapping.get(middle).removeLink(parent);
                inet.outputMapping.get(child).addLink(parent);
            } else {
                inet.replaceOutput(child, middle, parent);
            }
        }
        return true;
    }

    /**
     * Overload for weighted neurons that performs no pruning and only logs a warning.
     * It is not invoked from within this class.
     *
     * @param inet   unused
     * @param middle unused
     */
    private void prune(DetailedNetwork<State.Structure> inet, WeightedNeuron<Neurons, State.Neural> middle) {
        LOG.warning("Trying to prune weighted neuron input");
    }
}

