/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.pipelines.pipes.specific;

import cz.cvut.fel.ida.neural.networks.structure.building.NeuralProcessingSample;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralNetwork;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.QueryNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.components.types.DetailedNetwork;
import cz.cvut.fel.ida.neural.networks.structure.transforming.NetworkReducing;
import cz.cvut.fel.ida.pipelines.Pipe;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.generic.Utilities;
import cz.cvut.fel.ida.utils.math.collections.MultiList;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Pipe that prunes (reduces) the grounded networks of the incoming samples using the
 * {@link NetworkReducing} strategy obtained from the {@link Settings}.
 *
 * <p>The pruning mode is chosen from the settings:
 * <ul>
 *   <li>{@code groundingMode == GLOBAL}: all samples are collected and the detailed network of the
 *       first sample is reduced once with respect to the queries of all samples.</li>
 *   <li>{@code !oneQueryPerExample}: all samples are collected, grouped by their detailed network,
 *       and each network is reduced once with respect to the queries of the samples sharing it.</li>
 *   <li>otherwise: each sample is pruned lazily, as the returned stream is consumed.</li>
 * </ul>
 * In every mode the reduced network is stored as the {@code evidence} of the sample's
 * {@link QueryNeuron}.
 */
public class PruningPipe
extends Pipe<Stream<NeuralProcessingSample>, Stream<NeuralProcessingSample>> {
    private static final Logger LOG = Logger.getLogger(PruningPipe.class.getName());
    NetworkReducing reducer;

    /**
     * Creates the pipe and the reducer configured by the given settings.
     *
     * @param settings global settings; also used to pick the {@link NetworkReducing} implementation
     */
    public PruningPipe(Settings settings) {
        super("NetworkPruningPipe", settings);
        this.reducer = NetworkReducing.getReducer(settings);
    }

    /**
     * Prunes the networks of the given samples according to the configured mode.
     *
     * @param neuralProcessingSampleStream samples whose networks should be reduced
     * @return the same samples, with their query evidence set to the reduced network (in the lazy
     *     mode this happens only as the returned stream is consumed)
     */
    @Override
    public Stream<NeuralProcessingSample> apply(Stream<NeuralProcessingSample> neuralProcessingSampleStream) {
        if (this.settings.groundingMode == Settings.GroundingMode.GLOBAL) {
            return pruneGlobally(neuralProcessingSampleStream);
        }
        if (!this.settings.oneQueryPerExample) {
            return prunePerSharedNetwork(neuralProcessingSampleStream);
        }
        return pruneLazily(neuralProcessingSampleStream);
    }

    /** GLOBAL grounding: reduces the first sample's network once, w.r.t. all samples' queries. */
    @SuppressWarnings("unchecked")
    private Stream<NeuralProcessingSample> pruneGlobally(Stream<NeuralProcessingSample> sampleStream) {
        List<NeuralProcessingSample> samples = Utilities.terminateSampleStream(sampleStream);
        if (samples.isEmpty()) {
            // Nothing to reduce: without this guard get(0) below would throw on empty input.
            LOG.warning("No samples to prune in GLOBAL grounding mode.");
            return samples.stream();
        }
        // NOTE(review): assumes that in GLOBAL mode all samples share one detailed network, so the
        // first sample's network stands for all of them — confirm against the grounder.
        DetailedNetwork detailedNetwork = samples.get(0).detailedNetwork;
        NeuralNetwork reducedNetwork =
                this.reducer.reduce((DetailedNetwork<State.Structure>) detailedNetwork, queryNeuronsOf(samples));
        // Assign eagerly (not inside a lazy map) so evidence is set even if the stream is only partially consumed.
        setEvidence(samples, reducedNetwork);
        this.trueExport();
        return samples.stream();
    }

    /** Groups samples by their detailed network and reduces each network once for all its queries. */
    @SuppressWarnings("unchecked")
    private Stream<NeuralProcessingSample> prunePerSharedNetwork(Stream<NeuralProcessingSample> sampleStream) {
        List<NeuralProcessingSample> samples = Utilities.terminateSampleStream(sampleStream);
        // LinkedHashMap keeps networks in first-seen order, so the output order is deterministic.
        Map<DetailedNetwork, List<NeuralProcessingSample>> samplesByNetwork = new LinkedHashMap<>();
        for (NeuralProcessingSample sample : samples) {
            samplesByNetwork.computeIfAbsent(sample.detailedNetwork, k -> new ArrayList<>()).add(sample);
        }
        List<NeuralProcessingSample> allProcessingSamples = new ArrayList<>(samples.size());
        for (Map.Entry<DetailedNetwork, List<NeuralProcessingSample>> entry : samplesByNetwork.entrySet()) {
            List<NeuralProcessingSample> group = entry.getValue();
            NeuralNetwork reducedNetwork =
                    this.reducer.reduce((DetailedNetwork<State.Structure>) entry.getKey(), queryNeuronsOf(group));
            setEvidence(group, reducedNetwork);
            allProcessingSamples.addAll(group);
        }
        this.trueExport();
        return allProcessingSamples.stream();
    }

    /** One query per example: prunes each sample as the returned stream is consumed. */
    private Stream<NeuralProcessingSample> pruneLazily(Stream<NeuralProcessingSample> sampleStream) {
        Stream<NeuralProcessingSample> stream = sampleStream;
        if (this.exporter != null) {
            // Stats are only complete after the stream has been consumed, so defer the export to close.
            // NOTE(review): the handler runs only if the consumer closes the stream — confirm callers do.
            stream = stream.onClose(this::trueExport);
        }
        return stream.map(this::pruneOnce);
    }

    /**
     * Reduces the sample's network w.r.t. its own query unless the network is already marked
     * {@code pruned}. In that case the sample, including its query evidence, is left untouched.
     */
    @SuppressWarnings("unchecked")
    private NeuralProcessingSample pruneOnce(NeuralProcessingSample sample) {
        if (!sample.detailedNetwork.pruned) {
            QueryNeuron query = (QueryNeuron) sample.query;
            query.evidence = this.reducer.reduce((DetailedNetwork<State.Structure>) sample.detailedNetwork, query);
            sample.detailedNetwork.pruned = true;
        }
        return sample;
    }

    /** Collects the query neurons of the given samples, in sample order. */
    private static List<QueryNeuron> queryNeuronsOf(List<NeuralProcessingSample> samples) {
        return samples.stream().map(s -> (QueryNeuron) s.query).collect(Collectors.toList());
    }

    /** Stores {@code reducedNetwork} as the evidence of every sample's query neuron. */
    private static void setEvidence(List<NeuralProcessingSample> samples, NeuralNetwork reducedNetwork) {
        for (NeuralProcessingSample sample : samples) {
            ((QueryNeuron) sample.query).evidence = reducedNetwork;
        }
    }

    /**
     * Intentionally a no-op: the reducer statistics are exported by {@link #trueExport()} once
     * pruning has finished, not from the generic per-output export hook.
     */
    @Override
    protected <T> void export(T outputReady) {
    }

    /**
     * Finishes the reducer and exports it through the configured exporter, if there is one.
     * The log line is emitted regardless of whether an exporter is set.
     */
    protected void trueExport() {
        LOG.info("Pruning stats export");
        if (this.exporter != null) {
            this.reducer.finish();
            this.exporter.export(this.reducer);
        }
    }
}

