/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.building;

import cz.cvut.fel.ida.algebra.functions.Aggregation;
import cz.cvut.fel.ida.algebra.functions.transformation.joint.AtIndex;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.logic.Literal;
import cz.cvut.fel.ida.logic.constructs.example.ValuedFact;
import cz.cvut.fel.ida.logic.constructs.template.components.BodyAtom;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundHeadRule;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundRule;
import cz.cvut.fel.ida.logic.constructs.template.components.WeightedRule;
import cz.cvut.fel.ida.neural.networks.structure.building.NeuralBuilder;
import cz.cvut.fel.ida.neural.networks.structure.building.NeuronMaps;
import cz.cvut.fel.ida.neural.networks.structure.building.builders.StatesBuilder;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralSets;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.BaseNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.Neurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.WeightedNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AggregationNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomFact;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomNeurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.FactNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.NegationNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.RuleNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.RuleNeurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.SplittableAggregationNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.WeightedAtomNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.WeightedRuleNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.types.DetailedNetwork;
import cz.cvut.fel.ida.neural.networks.structure.components.types.TopologicNetwork;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.LinkedMapping;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.NeuronMapping;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.WeightedNeuronMapping;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.generic.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.InputMismatchException;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import org.jetbrains.annotations.NotNull;

/**
 * Builds the neurons of a ground neural network from grounded template rules and facts, connects
 * them, and finalizes the result into a {@link DetailedNetwork}.
 *
 * <p>Neurons are created through {@code neuralBuilder.neuronFactory}, whose {@link NeuronMaps} hold
 * the neurons created so far. In {@link #loadNeuronsFromRules}, when an atom or aggregation neuron
 * already exists in those maps, it is marked as shared. Its additional inputs are then recorded in
 * {@code NeuronMaps.extraInputMapping} (an input "overmapping") instead of being added to the
 * neuron itself.
 */
public class NeuralNetBuilder {
    private static final Logger LOG = Logger.getLogger(NeuralNetBuilder.class.getName());

    /** Supplies the neuron factory (and its NeuronMaps), the network factory and the states builder. */
    public NeuralBuilder neuralBuilder;

    /** Read for dropout, parent counting, post-processing and parallel-training switches. */
    private final Settings settings;

    public NeuralNetBuilder(Settings settings, NeuralBuilder neuralBuilder) {
        this.neuralBuilder = neuralBuilder;
        this.settings = settings;
    }

    public NeuralNetBuilder(Settings settings) {
        this.neuralBuilder = new NeuralBuilder(settings);
        this.settings = settings;
    }

    /**
     * Creates (or reuses) the atom neuron of {@code head}, plus the aggregation and rule neurons of
     * all its ground rules. It then links them as head atom &lt;- aggregation &lt;- rule neurons.
     *
     * <p>If the head atom neuron already exists, it is marked as shared. Its new inputs are recorded
     * in {@code neuronMaps.extraInputMapping}; the same happens for already existing (non-splittable)
     * aggregation neurons. Rule neurons of splittable aggregations are not created here; see
     * {@link #loadSplittableNeuronsFromRules}. Body inputs of rule neurons are connected later by
     * {@link #connectAllNeurons}.
     *
     * @param head           the ground head literal of all the given rules
     * @param rules          ground head rules mapped to their groundings
     * @param createdNeurons collector of the neurons newly created by this call
     */
    public void loadNeuronsFromRules(Literal head, LinkedHashMap<GroundHeadRule, Collection<GroundRule>> rules, NeuralSets createdNeurons) {
        NeuronMaps neuronMaps = this.neuralBuilder.neuronFactory.neuronMaps;
        boolean newAtomNeuron = false;
        boolean weightedAtomNeuron = false;
        AtomNeurons headAtomNeuron = neuronMaps.atomNeurons.get(head);
        if (headAtomNeuron == null) {
            if (rules.isEmpty()) {
                // Without any rule there is no lifted head to create the atom neuron from.
                LOG.severe("No rules given to create the atom neuron for: " + String.valueOf(head));
                return;
            }
            newAtomNeuron = true;
            GroundHeadRule liftedRule = null;
            for (Map.Entry<GroundHeadRule, Collection<GroundRule>> entry : rules.entrySet()) {
                liftedRule = entry.getKey();
                Collection<GroundRule> groundings = entry.getValue();
                if (!groundings.isEmpty() && !head.equals(groundings.iterator().next().groundHead)) {
                    LOG.severe("Ground heads corresponding to the same atom neuron are different!");
                }
                // A single non-unit rule weight makes the whole atom neuron weighted.
                if (!liftedRule.weightedRule.getWeight().equals(Weight.unitWeight)) {
                    weightedAtomNeuron = true;
                }
            }
            // The neuron is created from the lifted head of the last rule iterated.
            if (weightedAtomNeuron) {
                headAtomNeuron = this.neuralBuilder.neuronFactory.createWeightedAtomNeuron(liftedRule.weightedRule.getHead(), head);
                createdNeurons.weightedAtomNeurons.add((WeightedAtomNeuron) headAtomNeuron);
            } else {
                headAtomNeuron = this.neuralBuilder.neuronFactory.createUnweightedAtomNeuron(liftedRule.weightedRule.getHead(), head);
                createdNeurons.atomNeurons.add((AtomNeuron) headAtomNeuron);
            }
            if (headAtomNeuron.getComputationView(0).getFcnState().getInputMask() != null) {
                neuronMaps.containsMasking = true;
            }
        } else {
            headAtomNeuron.setShared(true);
            if (!rules.isEmpty()) {
                // Start (or copy) an overmapping seeded with the neuron's current inputs.
                if (headAtomNeuron instanceof WeightedNeuron) {
                    weightedAtomNeuron = true;
                    WeightedNeuronMapping previous = (WeightedNeuronMapping) neuronMaps.extraInputMapping.get(headAtomNeuron);
                    if (previous != null) {
                        neuronMaps.extraInputMapping.put(headAtomNeuron, new WeightedNeuronMapping(previous));
                    } else {
                        neuronMaps.extraInputMapping.put(headAtomNeuron, new WeightedNeuronMapping(headAtomNeuron.getInputs(), ((WeightedNeuron) ((Object) headAtomNeuron)).getWeights()));
                    }
                } else {
                    NeuronMapping previous = (NeuronMapping) neuronMaps.extraInputMapping.get(headAtomNeuron);
                    if (previous != null) {
                        neuronMaps.extraInputMapping.put(headAtomNeuron, new NeuronMapping(previous));
                    } else {
                        neuronMaps.extraInputMapping.put(headAtomNeuron, new NeuronMapping(headAtomNeuron.getInputs()));
                    }
                }
            }
        }
        for (Map.Entry<GroundHeadRule, Collection<GroundRule>> rules2groundings : rules.entrySet()) {
            GroundHeadRule groundHeadRule = rules2groundings.getKey();
            Collection<GroundRule> groundings = rules2groundings.getValue();
            Aggregation aggregation = groundHeadRule.weightedRule.getAggregationFcn();
            boolean splittable = aggregation != null && aggregation.isSplittable();
            boolean newAggNeuron = false;
            AggregationNeuron aggNeuron;
            BaseNeuron aggInputNeuron;
            if (splittable) {
                Pair<AggregationNeuron, BaseNeuron> aggregationNeurons = this.createAggregationNeuron(groundHeadRule, neuronMaps, createdNeurons);
                aggNeuron = (AggregationNeuron) aggregationNeurons.r;
                aggInputNeuron = (BaseNeuron) aggregationNeurons.s;
            } else {
                aggNeuron = neuronMaps.aggNeurons.get(groundHeadRule);
                if (aggNeuron == null) {
                    newAggNeuron = true;
                    aggNeuron = this.neuralBuilder.neuronFactory.createAggNeuron(groundHeadRule);
                    if (aggNeuron.getComputationView(0).getFcnState().getInputMask() != null) {
                        neuronMaps.containsMasking = true;
                    }
                    createdNeurons.aggNeurons.add(aggNeuron);
                } else {
                    aggNeuron.isShared = true;
                    if (!groundings.isEmpty()) {
                        NeuronMapping previous = (NeuronMapping) neuronMaps.extraInputMapping.get(aggNeuron);
                        if (previous != null) {
                            neuronMaps.extraInputMapping.put(aggNeuron, new NeuronMapping(previous));
                        } else {
                            neuronMaps.extraInputMapping.put(aggNeuron, new NeuronMapping(aggNeuron.getInputs()));
                        }
                    }
                }
                aggInputNeuron = aggNeuron;
            }
            if (newAtomNeuron) {
                if (weightedAtomNeuron) {
                    ((WeightedNeuron) ((Object) headAtomNeuron)).addInput(aggInputNeuron, groundHeadRule.weightedRule.getWeight());
                } else {
                    headAtomNeuron.addInput(aggInputNeuron);
                }
            } else {
                LOG.info("Warning-  modifying previous state - Creating input overmapping for this Atom neuron: " + String.valueOf(headAtomNeuron));
                // Mirror the new-neuron branch: link the same input, and add a weight only for weighted
                // heads (an unweighted head's overmapping is a plain NeuronMapping).
                if (weightedAtomNeuron) {
                    WeightedNeuronMapping headMapping = (WeightedNeuronMapping) neuronMaps.extraInputMapping.get(headAtomNeuron);
                    headMapping.addLink(aggInputNeuron);
                    headMapping.addWeight(groundHeadRule.weightedRule.getWeight());
                } else {
                    NeuronMapping headMapping = (NeuronMapping) neuronMaps.extraInputMapping.get(headAtomNeuron);
                    headMapping.addLink(aggInputNeuron);
                }
            }
            if (splittable) {
                // Rule neurons of splittable aggregations are created in loadSplittableNeuronsFromRules.
                continue;
            }
            for (GroundRule grounding : groundings) {
                RuleNeurons ruleNeuron = this.createRuleNeuron(grounding, neuronMaps, createdNeurons);
                if (newAggNeuron) {
                    aggNeuron.addInput(ruleNeuron);
                } else {
                    LOG.info("Warning-  modifying previous state - Creating input overmapping for this Agg neuron: " + String.valueOf(aggNeuron));
                    // Extend the overmapping registered above for this aggregation neuron.
                    NeuronMapping aggMapping = (NeuronMapping) neuronMaps.extraInputMapping.get(aggNeuron);
                    aggMapping.addLink(ruleNeuron);
                }
            }
        }
    }

    /**
     * Creates a fact neuron for every given ground fact. It then adds all fact neurons currently held
     * in the neuron maps (including previously created ones) to {@code createdNeurons}.
     *
     * <p>Side effect: {@code groundFacts} is cleared afterwards.
     *
     * @param groundFacts    ground facts keyed by their literal; cleared by this method
     * @param createdNeurons collector receiving the fact neurons
     */
    public void loadNeuronsFromFacts(Map<Literal, ValuedFact> groundFacts, NeuralSets createdNeurons) {
        for (ValuedFact fact : groundFacts.values()) {
            this.neuralBuilder.neuronFactory.createFactNeuron(fact);
        }
        createdNeurons.factNeurons.addAll(this.neuralBuilder.neuronFactory.neuronMaps.factNeurons.values());
        groundFacts.clear();
    }

    /**
     * Connects every rule neuron in {@code neuronMaps.ruleNeurons} to the neurons of its ground body
     * literals. Rule neurons whose input count already equals their lifted body size are skipped.
     *
     * <p>Each ground body literal is resolved to an atom neuron or, failing that, a fact neuron. For a
     * negated lifted body atom, a negation neuron is inserted in between. Weighted rule neurons
     * receive the body atom's conjunct weight together with each input.
     *
     * @param createdNeurons currently not used by this method
     * @throws InputMismatchException if a ground literal cannot be matched to a lifted body atom
     */
    public void connectAllNeurons(NeuralSets createdNeurons) {
        NeuronMaps neuronMaps = this.neuralBuilder.neuronFactory.neuronMaps;
        for (Map.Entry<GroundRule, RuleNeurons> entry : neuronMaps.ruleNeurons.entrySet()) {
            GroundRule grounding = entry.getKey();
            RuleNeurons ruleNeuron = entry.getValue();
            if (ruleNeuron.inputCount() == grounding.weightedRule.getBody().size()) {
                continue; // already connected
            }
            int j = 0; // position in the lifted body
            for (int i = 0; i < grounding.groundBody.length; ++i) {
                BodyAtom liftedBodyAtom = grounding.weightedRule.getBody().get(j++);
                Literal literal = grounding.groundBody[i];
                // NOTE(review): presumably the grounding may omit some negated body literals, so the
                // lifted body is scanned forward until a predicate with the same name is found.
                if (liftedBodyAtom.isNegated() && liftedBodyAtom.getPredicate() != literal.predicate()) {
                    while (!liftedBodyAtom.getPredicate().name.equals(literal.predicate().name)) {
                        if (j == grounding.weightedRule.getBody().size()) {
                            throw new InputMismatchException("A mismatch between predicates when connecting rule neuron inputs!");
                        }
                        liftedBodyAtom = grounding.weightedRule.getBody().get(j++);
                    }
                }
                Weight weight = liftedBodyAtom.getConjunctWeight();
                AtomFact input = neuronMaps.atomNeurons.get(literal);
                if (input == null) {
                    FactNeuron factNeuron = neuronMaps.factNeurons.get(literal);
                    if (factNeuron == null) {
                        // NOTE(review): a null input is still connected below, as before.
                        LOG.severe("Error: no input found for this neuron!!: " + String.valueOf(literal));
                        LOG.severe("This is likely due to unstable use of negation in the template...");
                    }
                    input = factNeuron;
                }
                if (liftedBodyAtom.isNegated()) {
                    NegationNeuron negationNeuron = this.neuralBuilder.neuronFactory.createNegationNeuron(input, liftedBodyAtom.getNegationActivation());
                    input = negationNeuron;
                }
                if (ruleNeuron instanceof WeightedNeuron) {
                    ((WeightedNeuron) ((Object) ruleNeuron)).addInput(input, weight);
                } else {
                    ((RuleNeuron) ruleNeuron).addInput(input);
                }
            }
        }
    }

    /**
     * Collects the query neurons, creates the {@link DetailedNetwork} and initializes its states.
     *
     * <p>Query literals that match a fact neuron are logged and skipped. Unmatched query literals are
     * skipped, and are logged only when there is exactly one query literal. If
     * {@code queryMatchingLiterals} is null, null query neurons are passed to the network factory.
     * Optional steps run depending on the settings: dropout states, parent counting / output mapping,
     * and shared-state marking for parallel training.
     *
     * @param id                    identifier of the network (used in log messages and the network)
     * @param createdNeurons        neurons created for this network
     * @param queryMatchingLiterals ground query literals, may be null
     * @return the finalized network
     */
    public DetailedNetwork finalizeStoredNetwork(String id, NeuralSets createdNeurons, List<Literal> queryMatchingLiterals) throws RuntimeException {
        NeuronMaps neuronMaps = this.getNeuronMaps();
        ArrayList<AtomNeurons> queryNeurons = null;
        if (queryMatchingLiterals != null) {
            queryNeurons = new ArrayList<AtomNeurons>();
            for (Literal queryMatchingLiteral : queryMatchingLiterals) {
                AtomNeurons qn = neuronMaps.atomNeurons.get(queryMatchingLiteral);
                if (qn != null) {
                    queryNeurons.add(qn);
                    continue;
                }
                if (neuronMaps.factNeurons.containsKey(queryMatchingLiteral)) {
                    LOG.severe("Quering directly facts, rather than inferred atoms - there is no learning possible for this sample query: " + String.valueOf(queryMatchingLiteral));
                    continue;
                }
                if (queryMatchingLiterals.size() > 1) {
                    continue;
                }
                LOG.severe("Query: [" + String.valueOf(queryMatchingLiteral) + "] was not matched anywhere in the neural network " + id + " - Cannot calculate its output!");
                LOG.warning(" -> This most likely means that the template is wrong as there is no proof-path from the example to the query");
                LOG.warning("   -> Check all the predicate signatures etc. to make sure the template matches your examples and that there is at least 1 inference chain to the query");
            }
        }
        DetailedNetwork neuralNetwork = this.neuralBuilder.networkFactory.createDetailedNetwork(queryNeurons, createdNeurons, id, neuronMaps.extraInputMapping);
        LOG.fine("DetailedNetwork created.");
        StatesBuilder statesBuilder = this.neuralBuilder.statesBuilder;
        statesBuilder.initializeStates(neuralNetwork);
        LOG.fine("Neuron dimensions inferred.");
        if (this.settings.dropoutRate > 0.0) {
            statesBuilder.setupDropoutStates(neuralNetwork);
        }
        if (neuronMaps.containsMasking) {
            neuralNetwork.containsInputMasking = true;
        }
        if (neuralNetwork.extraInputMapping != null && !neuralNetwork.extraInputMapping.isEmpty()) {
            statesBuilder.addLinkedInputsToNetworkStates(neuralNetwork);
        }
        if (this.settings.parentCounting || this.settings.neuralNetsPostProcessing) {
            neuralNetwork.outputMapping = this.calculateOutputs(neuralNetwork);
            if (this.settings.parentCounting) {
                statesBuilder.setupParentStateNumbers(neuralNetwork);
            }
        }
        if (this.settings.parallelTraining) {
            int sharedNeuronsCount = statesBuilder.makeSharedStatesRecursively(neuralNetwork);
            LOG.fine("Shared neurons marked.");
            neuralNetwork.setSharedNeuronsCount(sharedNeuronsCount);
        }
        return neuralNetwork;
    }

    /**
     * Inverts the input relation of the network: for every neuron, it collects the neurons that take
     * it as an input. Iteration over a neuron's inputs stops at the first null input.
     *
     * @param network the network whose neurons are traversed in {@code allNeuronsTopologic} order
     * @return map from each child neuron to a mapping of its parent (output) neurons
     */
    public Map<BaseNeuron, LinkedMapping> calculateOutputs(TopologicNetwork<State.Structure> network) {
        HashMap<BaseNeuron, LinkedMapping> outputMapping = new HashMap<BaseNeuron, LinkedMapping>();
        for (BaseNeuron<Neurons, State.Neural> parent : network.allNeuronsTopologic) {
            BaseNeuron child;
            Iterator<Neurons> inputs = network.getInputs(parent);
            while (inputs.hasNext() && (child = (BaseNeuron) inputs.next()) != null) {
                LinkedMapping parentMapping = outputMapping.computeIfAbsent(child, f -> new NeuronMapping());
                parentMapping.addLink(parent);
            }
        }
        return outputMapping;
    }

    /**
     * Returns the splittable atom neuron feeding {@code splitAggNeuron} for {@code groundHead}. The
     * neuron is created and registered in {@code inputOrder} if it does not exist yet.
     */
    private BaseNeuron getOrCreateSplittableInput(SplittableAggregationNeuron splitAggNeuron, Literal groundHead) {
        BaseNeuron aggInputNeuron = splitAggNeuron.inputOrder.get(groundHead);
        if (aggInputNeuron == null) {
            aggInputNeuron = this.neuralBuilder.neuronFactory.createSplittableAtomNeuron(groundHead, splitAggNeuron);
            // NOTE(review): decompiler artifact - casting a Literal to String would throw a
            // ClassCastException at runtime; verify the declared key type of inputOrder.
            splitAggNeuron.inputOrder.put((String) ((Object) groundHead), (AtomNeuron) aggInputNeuron);
        }
        return aggInputNeuron;
    }

    /**
     * Points the splittable atom neuron for {@code groundHead} at the most recently added input of
     * {@code aggregationNeuron} (index {@code inputCount() - 1}).
     */
    private void initSplittableAggregationNeuronIndex(AggregationNeuron aggregationNeuron, Literal groundHead) {
        SplittableAggregationNeuron splitAggNeuron = (SplittableAggregationNeuron) aggregationNeuron;
        BaseNeuron aggInputNeuron = this.getOrCreateSplittableInput(splitAggNeuron, groundHead);
        ((AtIndex) aggInputNeuron.getTransformation()).setIndex(aggregationNeuron.inputCount() - 1);
    }

    /**
     * Gets or creates the splittable aggregation neuron for {@code groundHeadRule}. The neuron is
     * keyed by the ground head with its aggregable terms masked. The method also returns the
     * splittable atom neuron for the unmasked ground head.
     *
     * @return pair of (aggregation neuron, its input neuron for this ground head)
     */
    private Pair<AggregationNeuron, BaseNeuron> createAggregationNeuron(GroundHeadRule groundHeadRule, NeuronMaps neuronMaps, NeuralSets createdNeurons) {
        Aggregation aggregation = groundHeadRule.weightedRule.getAggregationFcn();
        Literal literal = groundHeadRule.groundHead.maskTerms(aggregation.aggregableTerms());
        GroundHeadRule newGroundHeadRule = new GroundHeadRule(groundHeadRule.weightedRule, literal);
        AggregationNeuron aggNeuron = neuronMaps.aggNeurons.get(newGroundHeadRule);
        if (aggNeuron == null) {
            aggNeuron = this.neuralBuilder.neuronFactory.createSplittableAggNeuron(newGroundHeadRule);
            if (aggNeuron.getComputationView(0).getFcnState().getInputMask() != null) {
                neuronMaps.containsMasking = true;
            }
            createdNeurons.aggNeurons.add(aggNeuron);
        }
        BaseNeuron aggInputNeuron = this.getOrCreateSplittableInput((SplittableAggregationNeuron) aggNeuron, groundHeadRule.groundHead);
        return new Pair<AggregationNeuron, BaseNeuron>(aggNeuron, aggInputNeuron);
    }

    /**
     * Creates a (weighted, if the rule has weights) rule neuron for {@code grounding}. If a neuron
     * already exists in the neuron maps, it logs an inconsistency and returns that neuron.
     */
    private RuleNeurons createRuleNeuron(GroundRule grounding, NeuronMaps neuronMaps, NeuralSets createdNeurons) {
        RuleNeurons ruleNeuron = neuronMaps.ruleNeurons.get(grounding);
        if (ruleNeuron != null) {
            LOG.severe("Inconsistency - Specific rule neuron already contained in neuronmap!! This should never happen...");
            return ruleNeuron;
        }
        if (grounding.weightedRule.detectWeights()) {
            ruleNeuron = this.neuralBuilder.neuronFactory.createWeightedRuleNeuron(grounding);
            createdNeurons.weightedRuleNeurons.add((WeightedRuleNeuron) ruleNeuron);
        } else {
            ruleNeuron = this.neuralBuilder.neuronFactory.createRuleNeuron(grounding);
            createdNeurons.ruleNeurons.add((RuleNeuron) ruleNeuron);
        }
        if (ruleNeuron.getComputationView(0).getFcnState().getInputMask() != null) {
            neuronMaps.containsMasking = true;
        }
        return ruleNeuron;
    }

    /**
     * Connects the rule neurons of splittable aggregations to their previously created aggregation
     * neuron. Groundings are grouped by ground head. A single grounding is fed directly into the
     * aggregation neuron; several groundings are first combined by a new non-splittable aggregation
     * neuron. After each group is added, the corresponding splittable atom neuron is pointed at the
     * new input index.
     *
     * @param head           ground head used to look up the splittable aggregation neuron
     * @param rules          ground head rules mapped to their groundings; non-splittable ones are ignored
     * @param createdNeurons collector of newly created rule neurons
     */
    public void loadSplittableNeuronsFromRules(Literal head, LinkedHashMap<GroundHeadRule, Collection<GroundRule>> rules, NeuralSets createdNeurons) {
        NeuronMaps neuronMaps = this.neuralBuilder.neuronFactory.neuronMaps;
        for (Map.Entry<GroundHeadRule, Collection<GroundRule>> rules2groundings : rules.entrySet()) {
            GroundHeadRule groundHeadRule = rules2groundings.getKey();
            Aggregation aggregation = groundHeadRule.weightedRule.getAggregationFcn();
            if (aggregation == null || !aggregation.isSplittable()) {
                continue;
            }
            GroundHeadRule newGroundHeadRule = new GroundHeadRule(groundHeadRule.weightedRule, head);
            AggregationNeuron aggNeuron = neuronMaps.aggNeurons.get(newGroundHeadRule);
            if (aggNeuron == null) {
                // Nothing to attach the rule neurons to; skip instead of failing with an NPE.
                LOG.severe("Warning - splittable aggregation neuron has not been created before connecting rules");
                continue;
            }
            // Insertion-ordered so the input order (and thus AtIndex indices) is deterministic.
            Map<Literal, List<GroundRule>> ruleGroups = new LinkedHashMap<Literal, List<GroundRule>>();
            for (GroundRule groundRule : rules2groundings.getValue()) {
                ruleGroups.computeIfAbsent(groundRule.groundHead, k -> new ArrayList<GroundRule>()).add(groundRule);
            }
            for (Map.Entry<Literal, List<GroundRule>> entry : ruleGroups.entrySet()) {
                Literal literal = entry.getKey();
                List<GroundRule> ruleGroundings = entry.getValue();
                if (ruleGroundings.size() == 1) {
                    aggNeuron.addInput(this.createRuleNeuron(ruleGroundings.get(0), neuronMaps, createdNeurons));
                } else {
                    GroundHeadRule headRule = new GroundHeadRule(groundHeadRule.weightedRule, new Literal("_" + literal.predicateName(), literal.termList()));
                    // Copy the rule so clearing the aggregation does not affect the shared lifted rule.
                    headRule.weightedRule = new WeightedRule(headRule.weightedRule);
                    headRule.weightedRule.setAggregationFcn(null);
                    AggregationNeuron aggregationNeuron = this.neuralBuilder.neuronFactory.createAggNeuron(headRule);
                    if (aggregationNeuron.getComputationView(0).getFcnState().getInputMask() != null) {
                        neuronMaps.containsMasking = true;
                    }
                    // NOTE(review): unlike other aggregation neurons, this one is not added to
                    // createdNeurons.aggNeurons - confirm it is reached by the network construction.
                    for (GroundRule grounding : ruleGroundings) {
                        aggregationNeuron.addInput(this.createRuleNeuron(grounding, neuronMaps, createdNeurons));
                    }
                    aggNeuron.addInput(aggregationNeuron);
                }
                this.initSplittableAggregationNeuronIndex(aggNeuron, literal);
            }
        }
    }

    public NeuronMaps getNeuronMaps() {
        return this.neuralBuilder.neuronFactory.neuronMaps;
    }

    public void setNeuronMaps(NeuronMaps neuronMaps) {
        this.neuralBuilder.neuronFactory.neuronMaps = neuronMaps;
    }
}

