/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.building;

import cz.cvut.fel.ida.algebra.functions.Aggregation;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.logic.Clause;
import cz.cvut.fel.ida.logic.Literal;
import cz.cvut.fel.ida.logic.constructs.building.factories.WeightFactory;
import cz.cvut.fel.ida.logic.constructs.example.LogicSample;
import cz.cvut.fel.ida.logic.constructs.example.QueryAtom;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundHeadRule;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundRule;
import cz.cvut.fel.ida.logic.grounding.GroundTemplate;
import cz.cvut.fel.ida.logic.grounding.GroundingSample;
import cz.cvut.fel.ida.logic.subsumption.Matching;
import cz.cvut.fel.ida.neural.networks.structure.building.NeuralNetBuilder;
import cz.cvut.fel.ida.neural.networks.structure.building.NeuralProcessingSample;
import cz.cvut.fel.ida.neural.networks.structure.building.NeuronMaps;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralNetwork;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralSets;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.QueryNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomNeurons;
import cz.cvut.fel.ida.neural.networks.structure.components.types.DetailedNetwork;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.exporting.Exportable;
import cz.cvut.fel.ida.utils.generic.Timing;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.InputMismatchException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;

/**
 * Turns grounded logical samples into neural computation graphs.
 *
 * <p>For every query of a sample, the literals of the ground template that match the query are looked up. The neurons
 * reachable from them (through ground rules and ground facts) are then created and connected by the
 * {@link NeuralNetBuilder}. A sample whose query has no head atom is neuralized "blindly", i.e. from all ground rules
 * still present in its {@link NeuronMaps}. The same happens when {@code settings.forceFullNetworks} is set.
 *
 * <p>NOTE(review): the builder and the statistics fields are reassigned/mutated on every call, so a single instance
 * does not look safe to share between threads. Confirm before using it concurrently.
 */
public class Neuralizer
implements Exportable {
    private static final Logger LOG = Logger.getLogger(Neuralizer.class.getName());
    private transient Settings settings;
    public transient NeuralNetBuilder neuralNetBuilder;
    /** Counts of the neurons created by the most recent {@code neuralize} call. */
    public NeuralSets.NeuronCounter neuronCounts;
    /** Total number of query neurons produced by {@code neuralize} calls so far. */
    int queryNeuronsCreated;
    /** Number of head literals whose ground rules were turned into neurons by query-driven creation. */
    int groundRulesProcessed;
    /** Number of {@code neuralize} calls that started building a network. */
    int networksCreated;
    /** Timer that is started ({@code tic}) and stopped ({@code toc}) around each {@code neuralize} call. */
    public Timing timing;

    /**
     * Creates a neuralizer with a fresh {@link NeuralNetBuilder}.
     *
     * @param settings global settings, shared with the builder and the created samples
     */
    public Neuralizer(Settings settings) {
        this.settings = settings;
        this.neuralNetBuilder = new NeuralNetBuilder(settings);
        this.timing = new Timing();
    }

    /**
     * Creates a neuralizer whose builder uses the given weight factory.
     *
     * @param settings      global settings
     * @param weightFactory weight factory to install into the underlying neural builder
     */
    public Neuralizer(Settings settings, WeightFactory weightFactory) {
        this(settings);
        this.neuralNetBuilder.neuralBuilder.weightFactory = weightFactory;
    }

    /**
     * Neuralizes a batch of samples that share one ground template into a single network.
     *
     * <p>Each sample contributes one {@link NeuralProcessingSample} per template literal that matches its query. A
     * sample whose query has a head atom but matches nothing still contributes one sample for that head literal. Its
     * query neuron then carries no atom neuron if nothing gets built for that literal. If no query matches at all and
     * the first sample has no head atom, a single "blind" sample covering the whole template is returned.
     *
     * @param groundTemplate the ground template shared by all {@code samples}
     * @param samples        the samples to neuralize; the neuron maps of the first one are used (and created if absent)
     * @return the created processing samples; an empty list when {@code samples} is empty
     * @throws InputMismatchException if a {@code null} literal is matched for some query
     */
    public List<NeuralProcessingSample> neuralize(GroundTemplate groundTemplate, List<GroundingSample> samples) throws RuntimeException {
        if (samples.isEmpty()) {
            // Previously this crashed with an IndexOutOfBoundsException on samples.get(0).
            LOG.warning("No samples to neuralize.");
            return new ArrayList<>();
        }
        this.timing.tic();
        ++this.networksCreated;
        GroundingSample groundingSample = samples.get(0);
        NeuronMaps neuronMaps = this.obtainNeuronMaps(groundingSample, samples);
        this.neuralNetBuilder.setNeuronMaps(neuronMaps);
        NeuralSets createdNeurons = new NeuralSets();

        // Parallel lists: queryExpandedLiterals.get(i) is a literal to query for origSamples.get(i).
        List<Literal> queryExpandedLiterals = new ArrayList<>();
        List<LogicSample> origSamples = new ArrayList<>();
        boolean noMatch = true;
        for (LogicSample logicSample : samples) {
            QueryAtom query = (QueryAtom) logicSample.query;
            List<Literal> foundQueries = this.getQueryMatchingLiterals(query, groundTemplate);
            if (foundQueries.isEmpty()) {
                LOG.warning("Query [" + String.valueOf(query.headAtom) + "] not matched anywhere in the template!");
                if (query.headAtom != null) {
                    // Keep the unmatched head literal so that the sample still yields a query neuron.
                    queryExpandedLiterals.add(query.headAtom.literal);
                    origSamples.add(logicSample);
                }
                continue;
            }
            noMatch = false;
            for (Literal foundQuery : foundQueries) {
                if (foundQuery == null) {
                    throw new InputMismatchException("Null query matched for this sample: " + String.valueOf(logicSample));
                }
                queryExpandedLiterals.add(foundQuery);
                origSamples.add(logicSample);
            }
        }
        if (noMatch) {
            if (((QueryAtom) groundingSample.query).headAtom == null) {
                List<NeuralProcessingSample> blind = Collections.singletonList(this.blindSample(groundingSample, neuronMaps, createdNeurons));
                // Stop the timer and record counts on this exit too (previously skipped).
                this.finishNeuralization(createdNeurons);
                return blind;
            }
            LOG.severe("Not a single query was matched anywhere in the template for any of: " + String.valueOf(samples));
        }

        DetailedNetwork neuralNetwork = this.buildNetwork(groundTemplate, queryExpandedLiterals, neuronMaps, createdNeurons);

        NeuronMaps builtMaps = this.neuralNetBuilder.getNeuronMaps();
        List<NeuralProcessingSample> result = new ArrayList<>(queryExpandedLiterals.size());
        boolean anyNeuronFound = false;
        for (int i = 0; i < queryExpandedLiterals.size(); ++i) {
            LogicSample logicSample = origSamples.get(i);
            QueryAtom queryAtom = (QueryAtom) logicSample.query;
            AtomNeurons atomNeuron = builtMaps.atomNeurons.get(queryExpandedLiterals.get(i));
            if (atomNeuron != null) {
                anyNeuronFound = true;
            } else if (queryExpandedLiterals.size() <= 1) {
                LOG.severe("No neural inference network created for " + queryAtom.toString());
            } else if (logicSample.target.greaterThan(Value.ZERO) && this.settings.trainOnlineResultsType != Settings.ResultsType.REGRESSION) {
                LOG.warning("Unable to infer a positively labeled sample " + String.valueOf(logicSample));
            }
            QueryNeuron queryNeuron = new QueryNeuron(queryAtom.ID + ":" + queryAtom.headAtom.toString(), queryAtom.position, queryAtom.importance, atomNeuron, neuralNetwork);
            result.add(new NeuralProcessingSample(logicSample.target, queryNeuron, logicSample.type, this.settings));
        }
        if (!anyNeuronFound) {
            LOG.severe("No neural inference network created for any query of the sample " + String.valueOf(result));
        }
        // Keep the statistics consistent with the single-sample overload.
        this.queryNeuronsCreated += result.size();
        this.finishNeuralization(createdNeurons);
        return result;
    }

    /**
     * Neuralizes a single grounded sample.
     *
     * <p>If the sample's query has no head atom, the whole template is neuralized and one "blind" sample with target
     * {@code 0.0} is returned. Otherwise one processing sample is returned per template literal matching the query.
     *
     * @param groundingSample the grounded sample; its neuron maps are created and stored on it if absent
     * @return the created processing samples (possibly empty if the query matched nothing)
     * @throws RuntimeException if the query matched nothing while {@code settings.queriesAlignedWithExamples} is set
     */
    public List<NeuralProcessingSample> neuralize(GroundingSample groundingSample) throws RuntimeException {
        this.timing.tic();
        ++this.networksCreated;
        NeuronMaps neuronMaps = this.obtainNeuronMaps(groundingSample, Collections.singletonList(groundingSample));
        this.neuralNetBuilder.setNeuronMaps(neuronMaps);
        NeuralSets createdNeurons = new NeuralSets();
        if (((QueryAtom) groundingSample.query).headAtom == null) {
            List<NeuralProcessingSample> blind = Collections.singletonList(this.blindSample(groundingSample, neuronMaps, createdNeurons));
            // Stop the timer and record counts on this exit too (previously skipped).
            this.finishNeuralization(createdNeurons);
            return blind;
        }
        List<QueryNeuron> queryNeurons = this.supervisedNeuralization(groundingSample, neuronMaps, createdNeurons);
        this.queryNeuronsCreated += queryNeurons.size();
        if (queryNeurons.isEmpty()) {
            LOG.severe("No inference network created for " + String.valueOf(groundingSample.query));
        }
        List<NeuralProcessingSample> samples = queryNeurons.stream()
                .map(queryNeuron -> new NeuralProcessingSample(groundingSample.target, queryNeuron, groundingSample.type, this.settings))
                .collect(Collectors.toList());
        this.finishNeuralization(createdNeurons);
        return samples;
    }

    /**
     * Returns the neuron maps stored on {@code representative}. If none exist yet, creates them from its ground
     * template and attaches the new maps to every sample in {@code sharingSamples}.
     */
    private NeuronMaps obtainNeuronMaps(GroundingSample representative, List<GroundingSample> sharingSamples) {
        NeuronMaps neuronMaps = (NeuronMaps) representative.groundingWrap.getNeuronMaps();
        if (neuronMaps != null) {
            return neuronMaps;
        }
        GroundTemplate template = representative.groundingWrap.getGroundTemplate();
        NeuronMaps created = new NeuronMaps(template.groundRules, template.groundFacts);
        sharingSamples.forEach(s -> s.groundingWrap.setNeuronMaps(created));
        return created;
    }

    /**
     * Builds the "blind" processing sample for a sample without a query head atom. It covers the whole template, has
     * target {@code 0.0} and a query neuron without an atom neuron.
     */
    private NeuralProcessingSample blindSample(GroundingSample groundingSample, NeuronMaps neuronMaps, NeuralSets createdNeurons) {
        DetailedNetwork neuralNetwork = this.blindNeuralization(groundingSample.groundingWrap.getGroundTemplate(), neuronMaps, createdNeurons);
        QueryNeuron queryNeuron = new QueryNeuron(groundingSample.getId(), 0, 0.0, null, neuralNetwork);
        return new NeuralProcessingSample(new ScalarValue(0.0), queryNeuron, groundingSample.type, this.settings);
    }

    /**
     * Builds the network either from the whole template ({@code settings.forceFullNetworks}) or only from the neurons
     * reachable from {@code queryLiterals}.
     */
    private DetailedNetwork buildNetwork(GroundTemplate groundTemplate, List<Literal> queryLiterals, NeuronMaps neuronMaps, NeuralSets createdNeurons) {
        if (this.settings.forceFullNetworks) {
            return this.blindNeuralization(groundTemplate, neuronMaps, createdNeurons);
        }
        this.neuralNetBuilder = this.loadAllNeuronsStartingFromQueryLiterals(groundTemplate, queryLiterals, neuronMaps, createdNeurons);
        return this.getDetailedNetwork(neuronMaps, createdNeurons, groundTemplate, queryLiterals);
    }

    /** Records the neuron counts of the finished call and stops the timer started by {@code neuralize}. */
    private void finishNeuralization(NeuralSets createdNeurons) {
        this.neuronCounts = createdNeurons.getCounts();
        this.timing.toc();
    }

    /**
     * Creates the network for a sample whose query has a head atom and returns one query neuron per matching literal.
     *
     * @return the query neurons; an empty list if nothing matched and either the head atom is {@code null} or
     *         {@code settings.queriesAlignedWithExamples} is off
     * @throws RuntimeException if nothing matched while {@code settings.queriesAlignedWithExamples} is on
     */
    private List<QueryNeuron> supervisedNeuralization(GroundingSample groundingSample, NeuronMaps neuronMaps, NeuralSets createdNeurons) throws RuntimeException {
        QueryAtom queryAtom = (QueryAtom) groundingSample.query;
        GroundTemplate groundTemplate = groundingSample.groundingWrap.getGroundTemplate();
        List<Literal> queryMatchingLiterals = this.getQueryMatchingLiterals(queryAtom, groundTemplate);
        if (queryMatchingLiterals.isEmpty()) {
            String err = "Query [" + String.valueOf(queryAtom.headAtom) + "] not matched anywhere in the template. Cannot perform neural training.";
            LOG.severe(err);
            if (queryAtom.headAtom == null || !this.settings.queriesAlignedWithExamples) {
                return new ArrayList<>();
            }
            throw new RuntimeException(err);
        }
        LOG.finer("Obtained QueryMatchingLiterals: " + String.valueOf(queryMatchingLiterals));
        // queryMatchingLiterals is non-empty here, so only forceFullNetworks decides on a full (blind) network.
        DetailedNetwork neuralNetwork = this.buildNetwork(groundTemplate, queryMatchingLiterals, neuronMaps, createdNeurons);
        return this.getQueryNeurons(queryAtom, this.neuralNetBuilder.getNeuronMaps(), neuralNetwork, queryMatchingLiterals);
    }

    /**
     * Creates neurons for every ground rule group left in {@code neuronMaps} and then finalizes the network.
     * The ground rules are consumed: {@code neuronMaps.groundRules} is cleared afterwards.
     */
    private DetailedNetwork blindNeuralization(GroundTemplate groundTemplate, NeuronMaps neuronMaps, NeuralSets currentNeuralSets) throws RuntimeException {
        for (Map.Entry<Literal, LinkedHashMap<GroundHeadRule, Collection<GroundRule>>> entry : neuronMaps.groundRules.entrySet()) {
            this.neuralNetBuilder.loadNeuronsFromRules(entry.getKey(), entry.getValue(), currentNeuralSets);
        }
        neuronMaps.groundRules.clear();
        return this.getDetailedNetwork(neuronMaps, currentNeuralSets, groundTemplate, null);
    }

    /**
     * Adds fact neurons (when the factory has none yet, or in sequential grounding mode), connects all created neurons
     * and finalizes the stored network named after the template.
     *
     * @param queryMatchingLiterals the query literals passed to the builder; {@code null} for blind networks
     */
    private DetailedNetwork getDetailedNetwork(NeuronMaps neuronMaps, NeuralSets createdNeurons, GroundTemplate groundTemplate, List<Literal> queryMatchingLiterals) throws RuntimeException {
        if (this.neuralNetBuilder.neuralBuilder.neuronFactory.neuronMaps.factNeurons.isEmpty() || this.settings.groundingMode == Settings.GroundingMode.SEQUENTIAL) {
            this.neuralNetBuilder.loadNeuronsFromFacts(neuronMaps.groundFacts, createdNeurons);
        }
        LOG.fine("Neurons created: " + String.valueOf(this.neuralNetBuilder.getNeuronMaps()));
        this.neuralNetBuilder.connectAllNeurons(createdNeurons);
        LOG.fine("All neurons connected.");
        DetailedNetwork neuralNetwork = this.neuralNetBuilder.finalizeStoredNetwork(groundTemplate.getName(), createdNeurons, queryMatchingLiterals);
        LOG.fine("Final neural network created: " + String.valueOf(neuralNetwork));
        return neuralNetwork;
    }

    /**
     * Depth-first creation of neurons for {@code literal} and everything its ground rules depend on.
     *
     * <p>Each visited literal's rule group is removed from {@code neuronMaps.groundRules}, so every group is processed
     * at most once. When a head rule has a splittable aggregation (and we are not already in splittable mode), the
     * recursion continues with the head masked on the aggregable terms instead of descending into the rule bodies.
     *
     * @param closedSet  literals already visited in this traversal
     * @param splittable whether {@code literal} is a masked literal of a splittable aggregation
     * @throws RuntimeException if a ground rule contains a {@code null} body atom
     */
    private void recursiveNeuronsCreation(@NotNull Literal literal, Set<Literal> closedSet, NeuronMaps neuronMaps, NeuralSets currentNeuralSets, boolean splittable) {
        if (!closedSet.add(literal)) {
            return;
        }
        LinkedHashMap<GroundHeadRule, Collection<GroundRule>> ruleMap = neuronMaps.groundRules.remove(literal);
        if (ruleMap == null) {
            return;
        }
        if (splittable) {
            this.neuralNetBuilder.loadSplittableNeuronsFromRules(literal, ruleMap, currentNeuralSets);
        } else {
            this.neuralNetBuilder.loadNeuronsFromRules(literal, ruleMap, currentNeuralSets);
        }
        ++this.groundRulesProcessed;
        for (Map.Entry<GroundHeadRule, Collection<GroundRule>> entry : ruleMap.entrySet()) {
            GroundHeadRule headRule = entry.getKey();
            Aggregation aggregation = headRule.weightedRule.getAggregationFcn();
            if (!splittable && aggregation != null && aggregation.isSplittable()) {
                Literal maskedLiteral = headRule.groundHead.maskTerms(aggregation.aggregableTerms());
                this.recursiveNeuronsCreation(maskedLiteral, closedSet, neuronMaps, currentNeuralSets, true);
                continue;
            }
            for (GroundRule grounding : entry.getValue()) {
                for (Literal bodyAtom : grounding.groundBody) {
                    if (bodyAtom == null) {
                        throw new RuntimeException("Encoutered a null ground body atom in " + String.valueOf(grounding));
                    }
                    this.recursiveNeuronsCreation(bodyAtom, closedSet, neuronMaps, currentNeuralSets, false);
                }
            }
        }
    }

    /**
     * Finds the literals of the ground template that answer the query.
     *
     * <p>A ground query (no variables) matches itself if it heads some ground rule or is a ground fact. Querying a
     * fact directly is logged as severe. A non-ground query matches every rule head with the same predicate that it
     * subsumes.
     *
     * @return the matching literals; empty if the query has no head atom or nothing matches
     */
    @NotNull
    protected List<Literal> getQueryMatchingLiterals(QueryAtom queryAtom, @NotNull GroundTemplate groundTemplate) {
        if (queryAtom.headAtom == null) {
            return new ArrayList<Literal>(0);
        }
        Literal queryLiteral = queryAtom.headAtom.literal;
        if (!queryLiteral.containsVariable()) {
            List<Literal> queries = new ArrayList<>();
            if (groundTemplate.groundRules.containsKey(queryLiteral)) {
                queries.add(queryLiteral);
            } else if (groundTemplate.groundFacts.containsKey(queryLiteral)) {
                LOG.severe("Quering directly facts with " + String.valueOf(queryLiteral));
                queries.add(queryLiteral);
            }
            return queries;
        }
        Matching matching = new Matching();
        List<Literal> queryLiterals = new ArrayList<>();
        for (Map.Entry<Literal, LinkedHashMap<GroundHeadRule, Collection<GroundRule>>> entry : groundTemplate.groundRules.entrySet()) {
            Literal head = entry.getKey();
            if (queryLiteral.predicate().equals(head.predicate())
                    && matching.subsumption(new Clause(queryLiteral), new Clause(head)).booleanValue()) {
                queryLiterals.add(head);
            }
        }
        return queryLiterals;
    }

    /**
     * Creates the neurons reachable from each of {@code queryLiterals}, sharing one visited-set across all of them.
     *
     * @return the (same) neural net builder, for chaining
     */
    protected NeuralNetBuilder loadAllNeuronsStartingFromQueryLiterals(GroundTemplate groundTemplate, List<Literal> queryLiterals, NeuronMaps neuronMaps, NeuralSets currentNeuralSets) {
        Set<Literal> closedSet = new HashSet<>();
        for (Literal queryLiteral : queryLiterals) {
            // recursiveNeuronsCreation marks queryLiteral as visited itself.
            this.recursiveNeuronsCreation(queryLiteral, closedSet, neuronMaps, currentNeuralSets, false);
        }
        return this.neuralNetBuilder;
    }

    /**
     * Wraps the atom neuron of each matched literal into a {@link QueryNeuron} named {@code "<query id>:<head atom>"}.
     * A missing atom neuron is logged and the query neuron is created with a {@code null} atom neuron.
     */
    @NotNull
    protected List<QueryNeuron> getQueryNeurons(QueryAtom queryAtom, NeuronMaps neuronMaps, NeuralNetwork neuralNetwork, List<Literal> queryMatchingLiterals) {
        List<QueryNeuron> queryNeurons = new ArrayList<>(queryMatchingLiterals.size());
        for (Literal queryLiteral : queryMatchingLiterals) {
            AtomNeurons atomNeuron = neuronMaps.atomNeurons.get(queryLiteral);
            if (atomNeuron == null) {
                LOG.severe("Query not matched!");
            }
            QueryNeuron queryNeuron = new QueryNeuron(queryAtom.ID + ":" + queryAtom.headAtom.toString(), queryAtom.position, queryAtom.importance, atomNeuron, neuralNetwork);
            queryNeurons.add(queryNeuron);
        }
        return queryNeurons;
    }
}

