/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.transforming;

import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.inits.ValueInitializer;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.neural.networks.computation.iteration.actions.Evaluation;
import cz.cvut.fel.ida.neural.networks.computation.iteration.actions.IndependentNeuronProcessing;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Invalidator;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralNetwork;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.BaseNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.Neurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.QueryNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.WeightedNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.States;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.types.DetailedNetwork;
import cz.cvut.fel.ida.neural.networks.structure.transforming.NetworkMerging;
import cz.cvut.fel.ida.neural.networks.structure.transforming.NetworkReducing;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.generic.Pair;
import cz.cvut.fel.ida.utils.generic.Timing;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * Compresses a neural network by merging neurons that produce the same (rounded) values.
 *
 * <p>The network is evaluated {@link #repetitions} times. Before each evaluation every weight is
 * re-initialised with the configured {@link ValueInitializer}. Neurons whose rounded values agree
 * across all repetitions form one "iso-value" group. Every group member is rewired to the group's
 * first neuron (the etalon), and the network is then rebuilt from its output neurons. The
 * original weight values are restored afterwards.
 *
 * <p>When {@code settings.structuralIsoCompression} is set, a group member is replaced only if
 * {@link #equivalent} confirms it is structurally identical to its etalon.
 */
public class IsoValueNetworkCompressor implements NetworkReducing, NetworkMerging {
    private static final Logger LOG = Logger.getLogger(IsoValueNetworkCompressor.class.getName());
    private final transient IndependentNeuronProcessing invalidation;
    private final transient Evaluation evaluation;
    private transient Settings settings;
    private transient ValueInitializer valueInitializer;
    /** Number of random weight re-initialisations (taken from {@code settings.isoValueInits}). */
    public int repetitions;
    /** Number of decimal places neuron values are rounded to (from {@code settings.isoDecimals}). */
    public int decimals;
    Timing timing;
    /** Accumulated neuron count of all networks before compression. */
    public int allNeuronCount = 0;
    /** Accumulated neuron count of all networks after compression. */
    public int compressedNeuronCount = 0;
    /** Number of replacements rejected by the structural equivalence check. */
    public int preventedByIsoCheck = 0;

    public IsoValueNetworkCompressor(Settings settings) {
        this.settings = settings;
        this.valueInitializer = ValueInitializer.getInitializer(settings);
        this.invalidation = new IndependentNeuronProcessing(settings, new Invalidator(-1));
        this.evaluation = new Evaluation(settings, -1);
        this.repetitions = settings.isoValueInits;
        this.decimals = settings.isoDecimals;
        this.timing = new Timing();
    }

    /**
     * Merging of two networks is not supported by this compressor.
     *
     * @return always {@code null}
     */
    @Override
    public NeuralNetwork merge(NeuralNetwork a, NeuralNetwork b) {
        return null;
    }

    /**
     * Compresses {@code inet} in place and returns it.
     *
     * <p>An empty network, or an empty {@code outputs} list, is returned unchanged. The weights'
     * values are restored to their pre-call state even if compression fails part-way.
     *
     * @param inet    the network to compress (modified in place)
     * @param outputs the query neurons whose supporting sub-network must be preserved
     * @return {@code inet}
     */
    @Override
    public NeuralNetwork reduce(DetailedNetwork<State.Structure> inet, List<QueryNeuron> outputs) {
        if (inet.allNeuronsTopologic.isEmpty()) {
            return inet;
        }
        if (outputs.isEmpty()) {
            LOG.warning("No output neurons given, skipping IsoValue compression.");
            return inet;
        }
        this.timing.tic();
        List<Weight> allWeights = inet.getAllWeights();
        // Positional snapshot of the weight values: the iso-iteration overwrites them with random
        // values. A positional list (unlike a map keyed by Weight) tolerates duplicate entries.
        List<Value> originalValues = new ArrayList<Value>(allWeights.size());
        for (Weight weight : allWeights) {
            originalValues.add(weight.value.clone());
        }
        try {
            QueryNeuron queryNeuron = this.evaluationQuery(inet, outputs);
            int sizeBefore = inet.allNeuronsTopologic.size();
            // insertion (= topological) order makes the first neuron of each group its etalon
            LinkedHashMap<Neurons, ValueList> isoValues = new LinkedHashMap<Neurons, ValueList>();
            this.isoIteration(inet, allWeights, queryNeuron, isoValues);
            Map<Neurons, Neurons> etalonMap = this.mergeNeurons(inet, isoValues);
            LinkedHashSet<Neurons> etalons = new LinkedHashSet<Neurons>(etalonMap.values());

            List<Neurons> outputNeurons = outputs.stream().map(q -> q.neuron).collect(Collectors.toList());
            NetworkReducing.supervisedNetReconstruction(inet, outputNeurons);

            int sizeAfter = inet.allNeuronsTopologic.size();
            this.allNeuronCount += sizeBefore;
            this.compressedNeuronCount += sizeAfter;
            LOG.info("IsoValue neuron compression from " + sizeBefore + " down to " + sizeAfter + "(etalon values: " + etalons.size() + ")");
            if (etalons.size() > sizeAfter) {
                LOG.warning("There are more iso-values than neurons after compression (some unique parts have been pruned out!) = lossy compression");
            } else if (!this.settings.structuralIsoCompression && etalons.size() < sizeAfter - 1) {
                LOG.warning("There are more neurons than iso-values (some neurons have not been pruned despite having the same value) - e.g. output neurons.");
            }
        } finally {
            Iterator<Value> originals = originalValues.iterator();
            for (Weight weight : allWeights) {
                weight.value = originals.next();
            }
            this.timing.toc();
        }
        return inet;
    }

    /** Convenience overload for a single output query neuron. */
    @Override
    public NeuralNetwork reduce(DetailedNetwork<State.Structure> inet, QueryNeuron outputStart) {
        return this.reduce(inet, Collections.singletonList(outputStart));
    }

    @Override
    public void finish() {
        this.timing.finish();
    }

    /**
     * Builds the query neuron used to drive evaluation.
     *
     * <p>With a single output, this wraps that output's neuron. With several outputs, it wraps a
     * standalone "dummy" atom neuron fixed to {@link Value#ZERO}.
     */
    private QueryNeuron evaluationQuery(DetailedNetwork<State.Structure> inet, List<QueryNeuron> outputs) {
        if (outputs.size() > 1) {
            States.ComputationStateStandard dummyState = new States.ComputationStateStandard(null, Transformation.Singletons.identity);
            dummyState.setValue(Value.ZERO);
            AtomNeuron<States.ComputationStateStandard> dummy = new AtomNeuron<States.ComputationStateStandard>("dummy", -1, dummyState);
            return new QueryNeuron("", -1, 1.0, dummy, inet);
        }
        return new QueryNeuron("", -1, 1.0, outputs.get(0).neuron, inet);
    }

    /**
     * Groups neurons by identical value lists and picks an etalon per group, namely the first
     * neuron of that group in {@code isoValues} iteration order. It then rewires the network with
     * either the structural or the unchecked strategy, depending on the settings.
     *
     * @return a mapping from every neuron in {@code isoValues} to its group's etalon
     */
    private Map<Neurons, Neurons> mergeNeurons(DetailedNetwork<State.Structure> inet, Map<Neurons, ValueList> isoValues) {
        Map<ValueList, List<Neurons>> isoNeurons = new HashMap<ValueList, List<Neurons>>();
        for (Map.Entry<Neurons, ValueList> entry : isoValues.entrySet()) {
            isoNeurons.computeIfAbsent(entry.getValue(), k -> new ArrayList<>()).add(entry.getKey());
        }
        Map<Neurons, Neurons> etalonMap = new HashMap<Neurons, Neurons>();
        for (List<Neurons> group : isoNeurons.values()) {
            Neurons etalon = group.get(0);
            for (Neurons member : group) {
                etalonMap.put(member, etalon);
            }
        }
        if (this.settings.structuralIsoCompression) {
            this.oversafeCompression(inet, isoValues, isoNeurons, etalonMap);
        } else {
            this.unsafeCompression(inet, etalonMap);
        }
        return etalonMap;
    }

    /**
     * Redirects every consumer of a non-etalon neuron to that neuron's etalon, without any
     * structural check. Neurons without an etalon entry are left untouched.
     */
    private void unsafeCompression(DetailedNetwork<State.Structure> inet, Map<Neurons, Neurons> etalonMap) {
        for (BaseNeuron neuron : inet.allNeuronsTopologic) {
            Neurons etalonReplacement = etalonMap.get(neuron);
            // no etalon (e.g. zero repetitions) -> nothing to replace with; etalon itself -> keep
            if (etalonReplacement == null || etalonReplacement == neuron) {
                continue;
            }
            Iterator<Neurons> outputs = inet.getOutputs(neuron);
            if (outputs == null) {
                continue;
            }
            while (outputs.hasNext()) {
                Neurons output = outputs.next();
                inet.replaceInput((BaseNeuron) output, neuron, etalonReplacement);
                inet.outputMapping.remove(neuron);
            }
        }
    }

    /**
     * Like {@link #unsafeCompression}, but only redirects a neuron's consumers when
     * {@link #equivalent} confirms the neuron and its etalon are structurally identical.
     * Each iso-group is processed once, when its first member is reached in topological order.
     */
    private void oversafeCompression(DetailedNetwork<State.Structure> inet, Map<Neurons, ValueList> isoValues, Map<ValueList, List<Neurons>> isoNeurons, Map<Neurons, Neurons> etalonMap) {
        for (BaseNeuron neuron : inet.allNeuronsTopologic) {
            Neurons etalonReplacement = etalonMap.get(neuron);
            ValueList valueList = isoValues.get(neuron);
            List<Neurons> equivalentNeurons = isoNeurons.get(valueList);
            if (equivalentNeurons != null && equivalentNeurons.size() > 1) {
                for (Neurons sameNeuron : equivalentNeurons) {
                    Iterator<Neurons> outputs = inet.getOutputs((BaseNeuron) sameNeuron);
                    if (outputs == null) {
                        continue;
                    }
                    if (!this.equivalent(inet, sameNeuron, etalonReplacement)) {
                        LOG.warning("Trying to replace a neuron with a structurally non-equivalent etalon!");
                        ++this.preventedByIsoCheck;
                        continue;
                    }
                    while (outputs.hasNext()) {
                        Neurons output = outputs.next();
                        inet.replaceInput((BaseNeuron) output, sameNeuron, etalonReplacement);
                        inet.outputMapping.remove(sameNeuron);
                    }
                }
            }
            // the group has been handled; later members must not process it again
            isoNeurons.remove(valueList);
        }
    }

    /**
     * Checks whether neurons {@code a} and {@code b} are structurally interchangeable.
     *
     * <p>They are equivalent if they are equal, or if both have no inputs and carry equal values.
     * Otherwise both must be weighted with the same multiset of (weight, input) pairs, or both
     * must be unweighted with the same multiset of inputs.
     */
    public boolean equivalent(DetailedNetwork<State.Structure> inet, Neurons<Neurons, State.Neural> a, Neurons<Neurons, State.Neural> b) {
        if (a.equals(b)) {
            return true;
        }
        NeuralNetwork network = (NeuralNetwork) inet;
        if (!network.getInputs(a).hasNext() && !network.getInputs(b).hasNext()) {
            Value valA = a.getComputationView(-1).getValue();
            Value valB = b.getComputationView(-1).getValue();
            return valA.equals(valB);
        }
        if (a instanceof WeightedNeuron && b instanceof WeightedNeuron) {
            return sameMultiset(weightedInputs(inet, (WeightedNeuron) a), weightedInputs(inet, (WeightedNeuron) b));
        }
        if (a instanceof WeightedNeuron || b instanceof WeightedNeuron) {
            return false;
        }
        return sameMultiset(plainInputs(network, a), plainInputs(network, b));
    }

    /** Collects the (weight, input neuron) pairs feeding a weighted neuron. */
    private static List<Pair<Weight, Neurons>> weightedInputs(DetailedNetwork<State.Structure> inet, WeightedNeuron neuron) {
        Pair inputs = inet.getInputs(neuron);
        Iterator neuronIterator = (Iterator) inputs.r;
        Iterator weightIterator = (Iterator) inputs.s;
        List<Pair<Weight, Neurons>> result = new ArrayList<Pair<Weight, Neurons>>();
        while (neuronIterator.hasNext()) {
            Neurons input = (Neurons) neuronIterator.next();
            Weight weight = (Weight) weightIterator.next();
            result.add(new Pair<Weight, Neurons>(weight, input));
        }
        return result;
    }

    /** Collects the input neurons of an unweighted neuron. */
    private static List<Neurons> plainInputs(NeuralNetwork network, Neurons<Neurons, State.Neural> neuron) {
        Iterator<Neurons> inputs = network.getInputs(neuron);
        List<Neurons> result = new ArrayList<Neurons>();
        while (inputs.hasNext()) {
            result.add(inputs.next());
        }
        return result;
    }

    /** Returns true if both lists contain the same elements with the same multiplicities. */
    private static boolean sameMultiset(List<?> left, List<?> right) {
        if (left.size() != right.size()) {
            return false;
        }
        List<Object> remaining = new ArrayList<Object>(right);
        for (Object element : left) {
            if (!remaining.remove(element)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Runs {@link #repetitions} randomised evaluations.
     *
     * <p>Each round re-initialises all weights, resets the state cache, invalidates from the
     * query neuron and evaluates. It then appends each neuron's (rounded) value to that neuron's
     * {@link ValueList} in {@code isoValues}.
     */
    private void isoIteration(DetailedNetwork<State.Structure> inet, List<Weight> allWeights, QueryNeuron queryNeuron, Map<Neurons, ValueList> isoValues) {
        for (int i = 0; i < this.repetitions; ++i) {
            for (Weight weight : allWeights) {
                weight.init(this.valueInitializer);
            }
            inet.initializeStatesCache(-1);
            this.invalidation.process(inet, queryNeuron.neuron);
            this.evaluation.evaluate(queryNeuron);
            for (BaseNeuron neuron : inet.allNeuronsTopologic) {
                Value value = neuron.getComputationView(-1).getValue();
                isoValues.computeIfAbsent(neuron, k -> new ValueList()).add(value);
            }
        }
    }

    /**
     * Fixed-length list of one neuron's rounded values, one per repetition. It is used as a hash
     * key, so it must not be modified after {@link #hashCode()} has first been called.
     */
    private class ValueList {
        int length;
        Value[] values;
        int index = 0;
        private int cachedHash;
        private boolean hashComputed = false;

        public ValueList() {
            this.length = IsoValueNetworkCompressor.this.repetitions;
            this.values = new Value[this.length];
        }

        public void add(Value value) {
            this.values[this.index++] = this.roundUp(value);
        }

        /**
         * Returns a copy of {@code value} with every component rounded HALF_UP to
         * {@code decimals} places. Non-finite components (NaN, +/-Infinity) are copied unchanged,
         * because {@link BigDecimal} cannot represent them.
         */
        public Value roundUp(Value value) {
            Value rounded = value.getForm();
            Iterator<?> iterator = value.iterator();
            int i = 0;
            while (iterator.hasNext()) {
                double next = (Double) iterator.next();
                double component = Double.isFinite(next)
                        ? new BigDecimal(next).setScale(IsoValueNetworkCompressor.this.decimals, RoundingMode.HALF_UP).doubleValue()
                        : next;
                rounded.set(i, component);
                ++i;
            }
            return rounded;
        }

        @Override
        public int hashCode() {
            if (!this.hashComputed) {
                this.cachedHash = Arrays.hashCode(this.values);
                this.hashComputed = true;
            }
            return this.cachedHash;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof ValueList)) {
                return false;
            }
            return Arrays.equals(this.values, ((ValueList) obj).values);
        }
    }
}

