/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.computation.training;

import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.inits.ValueInitializer;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.learning.Model;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.QueryNeuron;
import cz.cvut.fel.ida.setup.Settings;
import java.io.Reader;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * A trainable model backed by a flat list of {@link Weight}s.
 *
 * <p>Keeps the full weight list, the subset that is learnable (per {@link Weight#isLearnable}),
 * and, when the configured optimizer is Adam, initializes each weight's {@code velocity} and
 * {@code momentum} slots at construction time.
 */
public class NeuralModel implements Model<QueryNeuron> {
    private static final Logger LOG = Logger.getLogger(NeuralModel.class.getName());
    private transient Settings settings;
    /** Every weight of the model, in the order supplied to the constructor (not copied). */
    public final List<Weight> allWeights;
    /**
     * Set to {@code allWeights.size() - 1}. NOTE(review): this equals the largest {@code Weight.index}
     * only if indices are contiguous from 0 — confirm against callers that use it as an upper bound.
     */
    public final int maxWeightIndex;
    /** The subset of {@link #allWeights} for which {@link Weight#isLearnable} holds. */
    public final transient List<Weight> learnableWeights;
    /** Not set by this class; presumably assigned externally — verify against callers. */
    public Value threshold;
    /** Set only when {@code settings.debugTemplateTraining} is enabled (see the 3-arg constructor). */
    public Consumer<Map<Integer, Weight>> templateDebugCallback;

    /**
     * Creates a model over the given weights.
     *
     * <p>NOTE(review): this constructor calls the overridable {@link #init4Adam(List)}; a subclass
     * override will run before the subclass's own fields are initialized.
     *
     * @param weights  all weights of the model; stored by reference, not copied
     * @param settings settings consulted for the optimizer choice
     */
    public NeuralModel(List<Weight> weights, Settings settings) {
        this.settings = settings;
        this.allWeights = weights;
        this.learnableWeights = NeuralModel.filterLearnable(weights);
        if (settings.getOptimizer() == Settings.OptimizerSet.ADAM) {
            this.init4Adam(weights);
        }
        this.maxWeightIndex = this.allWeights.size() - 1;
    }

    /**
     * Creates a model and, if {@code settings.debugTemplateTraining} is set, installs the given
     * template debug callback; otherwise the callback is ignored and the field stays {@code null}.
     *
     * @param weights                all weights of the model
     * @param templateUpdateCallback callback stored in {@link #templateDebugCallback} when debugging
     * @param settings               settings consulted for the optimizer and debug flag
     */
    public NeuralModel(List<Weight> weights, Consumer<Map<Integer, Weight>> templateUpdateCallback, Settings settings) {
        this(weights, settings);
        if (settings.debugTemplateTraining) {
            this.templateDebugCallback = templateUpdateCallback;
        }
    }

    /**
     * Initializes each weight's Adam state: {@code velocity} and {@code momentum} are both set to
     * {@code weight.value.getForm()} (presumably a zeroed value of the same shape — confirm in Value).
     *
     * @param weights the weights to prepare for the Adam optimizer
     */
    protected void init4Adam(List<Weight> weights) {
        for (Weight weight : weights) {
            weight.velocity = weight.value.getForm();
            weight.momentum = weight.value.getForm();
        }
    }

    /**
     * Returns a new model whose weights are {@link Weight#clone() clones} of this model's weights,
     * built with the same settings (so Adam state is re-initialized if Adam is configured).
     *
     * <p>NOTE(review): {@link #threshold} and {@link #templateDebugCallback} are not copied to the
     * clone — confirm this is intended.
     *
     * @return the cloned model
     */
    public NeuralModel cloneWeights() {
        List<Weight> clonedWeights = this.allWeights.stream().map(Weight::clone).collect(Collectors.toList());
        return new NeuralModel(clonedWeights, this.settings);
    }

    /**
     * Re-initializes every weight (learnable or not) with the given initializer.
     *
     * @param valueInitializer initializer passed to {@link Weight#init}
     */
    public void resetWeights(ValueInitializer valueInitializer) {
        for (Weight weight : this.allWeights) {
            weight.init(valueInitializer);
        }
    }

    /**
     * Replaces the value of each of this model's weights with the value of the weight that has the
     * same {@code index} in {@code otherModel}. The value object is assigned by reference, not copied.
     *
     * <p>All indices are validated before any assignment, so on failure this model is left unchanged.
     *
     * @param otherModel the model to read weight values from
     * @throws IllegalArgumentException if {@code otherModel} has no weight for one of this model's indices
     * @throws IllegalStateException    if {@code otherModel} contains duplicate weight indices
     */
    public void loadWeightValues(NeuralModel otherModel) {
        Map<Integer, Weight> otherWeights = otherModel.mapWeightsToIds();
        // Validate first so a mismatch cannot leave this model half-updated.
        for (Weight weight : this.allWeights) {
            if (!otherWeights.containsKey(weight.index)) {
                throw new IllegalArgumentException("Source model has no weight with index "
                        + weight.index + " (weight name: " + weight.name + ")");
            }
        }
        for (Weight weight : this.allWeights) {
            weight.value = otherWeights.get(weight.index).value;
        }
    }

    /** Currently a no-op. TODO: not implemented. */
    public void dropoutWeights() {
    }

    /**
     * Returns a new list containing only the weights for which {@link Weight#isLearnable} holds,
     * preserving the input order.
     *
     * @param allWeights the weights to filter
     * @return the learnable weights
     */
    public static List<Weight> filterLearnable(List<Weight> allWeights) {
        return allWeights.stream().filter(Weight::isLearnable).collect(Collectors.toList());
    }

    /**
     * Maps each weight's {@code index} to the weight itself.
     *
     * @return a new map from weight index to weight
     * @throws IllegalStateException if two weights share the same index ({@code Collectors.toMap})
     */
    public Map<Integer, Weight> mapWeightsToIds() {
        return this.allWeights.stream().collect(Collectors.toMap(w -> w.index, w -> w));
    }

    /**
     * Maps each weight's {@code name} to the weight itself.
     *
     * @return a new map from weight name to weight
     * @throws IllegalStateException if two weights share the same name ({@code Collectors.toMap})
     */
    public Map<String, Weight> mapWeightsToNames() {
        return this.allWeights.stream().collect(Collectors.toMap(w -> w.name, w -> w));
    }

    /**
     * Currently a no-op; both arguments are ignored. TODO: not implemented.
     *
     * @param tensorflow source to read weights from (unused)
     * @param mapping    name-to-weight mapping (unused)
     */
    public void importWeights(Reader tensorflow, Map<String, Weight> mapping) {
    }

    /** Not implemented: always returns {@code null}. */
    @Override
    public String getName() {
        return null;
    }

    /** Not implemented: always returns {@code null}; the query is ignored. */
    @Override
    public Value evaluate(QueryNeuron query) {
        return null;
    }

    /**
     * Returns the internal weight list itself (not a copy).
     *
     * @return {@link #allWeights}
     */
    @Override
    public List<Weight> getAllWeights() {
        return this.allWeights;
    }
}

