/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.building;

import cz.cvut.fel.ida.algebra.functions.Aggregation;
import cz.cvut.fel.ida.algebra.functions.Combination;
import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.functions.transformation.joint.AtIndex;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.logic.Literal;
import cz.cvut.fel.ida.logic.Predicate;
import cz.cvut.fel.ida.logic.constructs.building.factories.WeightFactory;
import cz.cvut.fel.ida.logic.constructs.example.ValuedFact;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundHeadRule;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundRule;
import cz.cvut.fel.ida.logic.constructs.template.components.HeadAtom;
import cz.cvut.fel.ida.logic.constructs.template.components.WeightedRule;
import cz.cvut.fel.ida.neural.networks.structure.building.NeuronMaps;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.States;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AggregationNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomFact;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.FactNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.NegationNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.RuleNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.SplittableAggregationNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.WeightedAtomNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.WeightedRuleNeuron;
import cz.cvut.fel.ida.setup.Settings;
import java.util.logging.Logger;

/**
 * Creates the individual neurons of a neural network (atom, aggregation, rule, fact and negation
 * neurons), configures their computation state from the template element or, when the element does
 * not specify it, from {@link Settings}, and registers each created neuron in {@link #neuronMaps}.
 *
 * <p>Every created neuron receives the next value of an internal counter as its index. The counter
 * is a plain field; its increments are not synchronized.
 *
 * <p>{@link #neuronMaps} is NOT initialized by the constructor; it must be assigned before any
 * {@code create*} method is called, otherwise those methods throw a {@link NullPointerException}.
 */
public class NeuronFactory {
    private static final Logger LOG = Logger.getLogger(NeuronFactory.class.getName());
    private final WeightFactory weightFactory;
    Settings settings;
    /** Sequential index handed to each newly created neuron. */
    private int counter = 0;
    /** Shared, pre-built offset for atom neurons, used when default atom offsets are not learnable. */
    private final Weight atomOffset;
    /** Shared, pre-built offset for rule neurons, used when default rule offsets are not learnable. */
    private final Weight ruleOffset;
    /** Value given to facts that carry no explicit value. */
    private final Value defaultFactValue;
    /** Registry the created neurons are put into; must be set by the caller before use. */
    public NeuronMaps neuronMaps;

    /**
     * @param weightFactory factory used to construct learnable default offsets
     * @param settings      source of default functions, offsets and naming options
     */
    public NeuronFactory(WeightFactory weightFactory, Settings settings) {
        this.weightFactory = weightFactory;
        this.settings = settings;
        this.atomOffset = new Weight(-10, "fixedAtomOffset", new ScalarValue(settings.defaultAtomNeuronOffset), true, true);
        this.ruleOffset = new Weight(-9, "fixedRuleOffset", new ScalarValue(settings.defaultRuleNeuronOffset), true, true);
        this.defaultFactValue = new ScalarValue(settings.defaultFactValue);
    }

    /**
     * Resolves the offset for a neuron whose template element declares none.
     *
     * @param learnable    whether default offsets are configured as learnable; if so, a new weight is
     *                     constructed for every call
     * @param defaultValue the configured default offset value
     * @param fixedOffset  the shared, pre-built offset matching {@code defaultValue}
     * @return a newly constructed weight when {@code learnable}; otherwise {@code fixedOffset} when
     *         {@code defaultValue} is non-zero, or {@code null} when it is zero (no offset)
     */
    private Weight defaultOffset(boolean learnable, double defaultValue, Weight fixedOffset) {
        if (learnable) {
            // NOTE(review): the last construct() flag is true only for a non-zero default; its exact
            // meaning is defined by WeightFactory.construct — confirm there.
            return this.weightFactory.construct(new ScalarValue(defaultValue), false, defaultValue != 0.0);
        }
        return defaultValue != 0.0 ? fixedOffset : null;
    }

    /**
     * Creates an atom neuron with an offset and registers it under {@code groundHead}.
     *
     * @param head       template head supplying the optional combination, transformation and offset
     * @param groundHead ground literal naming the neuron; used as its key in {@code atomNeurons}
     * @return the new neuron
     */
    public WeightedAtomNeuron createWeightedAtomNeuron(HeadAtom head, Literal groundHead) {
        Combination combination = head.getCombination() != null ? head.getCombination() : Combination.getFunction(this.settings.atomNeuronCombination);
        Transformation transformation = head.getTransformation() != null ? head.getTransformation() : Transformation.getFunction(this.settings.atomNeuronTransformation);
        State.Neural.Computation state = State.createBaseState(this.settings, combination, transformation);
        Weight offset = head.getOffset();
        if (offset == null) {
            offset = defaultOffset(this.settings.defaultAtomOffsetsLearnable, this.settings.defaultAtomNeuronOffset, this.atomOffset);
        }
        WeightedAtomNeuron<State.Neural.Computation> atomNeuron = new WeightedAtomNeuron<State.Neural.Computation>(groundHead.toString(), offset, this.counter++, state);
        this.neuronMaps.atomNeurons.put(groundHead, atomNeuron);
        LOG.finest(() -> "Created atom neuron: " + String.valueOf(atomNeuron));
        return atomNeuron;
    }

    /**
     * Creates an atom neuron without an offset and registers it under {@code groundHead}.
     *
     * @param head       template head supplying the optional combination and transformation
     * @param groundHead ground literal naming the neuron; used as its key in {@code atomNeurons}
     * @return the new neuron
     */
    public AtomNeuron createUnweightedAtomNeuron(HeadAtom head, Literal groundHead) {
        Combination combination = head.getCombination() != null ? head.getCombination() : Combination.getFunction(this.settings.atomNeuronCombination);
        Transformation transformation = head.getTransformation() != null ? head.getTransformation() : Transformation.getFunction(this.settings.atomNeuronTransformation);
        State.Neural.Computation state = State.createBaseState(this.settings, combination, transformation);
        AtomNeuron<State.Neural.Computation> atomNeuron = new AtomNeuron<State.Neural.Computation>(groundHead.toString(), this.counter++, state);
        this.neuronMaps.atomNeurons.put(groundHead, atomNeuron);
        LOG.finest(() -> "Created atom neuron: " + String.valueOf(atomNeuron));
        return atomNeuron;
    }

    /**
     * Creates an aggregation neuron for a ground-head rule and registers it in {@code aggNeurons}.
     * The aggregation comes from the rule, or from {@code settings.aggNeuronAggregation} if unset;
     * the state has no transformation.
     *
     * @param groundHeadRule the rule with its ground head
     * @return the new neuron
     */
    public AggregationNeuron createAggNeuron(GroundHeadRule groundHeadRule) {
        WeightedRule weightedRule = groundHeadRule.weightedRule;
        Aggregation aggregation = weightedRule.getAggregationFcn() != null ? weightedRule.getAggregationFcn() : Aggregation.getFunction(this.settings.aggNeuronAggregation);
        State.Neural.Computation state = State.createBaseState(this.settings, aggregation, null);
        AggregationNeuron<State.Neural.Computation> aggregationNeuron = new AggregationNeuron<State.Neural.Computation>(this.settings.fullAggNeuronStrings ? groundHeadRule.toFullString() : weightedRule.getOriginalString(), this.counter++, state);
        this.neuronMaps.aggNeurons.put(groundHeadRule, aggregationNeuron);
        LOG.finest(() -> "Created aggregation neuron: " + String.valueOf(aggregationNeuron));
        return aggregationNeuron;
    }

    /**
     * Same as {@link #createAggNeuron(GroundHeadRule)} but creates a {@link SplittableAggregationNeuron}.
     *
     * @param groundHeadRule the rule with its ground head
     * @return the new neuron
     */
    public SplittableAggregationNeuron createSplittableAggNeuron(GroundHeadRule groundHeadRule) {
        WeightedRule weightedRule = groundHeadRule.weightedRule;
        Aggregation aggregation = weightedRule.getAggregationFcn() != null ? weightedRule.getAggregationFcn() : Aggregation.getFunction(this.settings.aggNeuronAggregation);
        State.Neural.Computation state = State.createBaseState(this.settings, aggregation, null);
        SplittableAggregationNeuron<State.Neural.Computation> aggregationNeuron = new SplittableAggregationNeuron<State.Neural.Computation>(this.settings.fullAggNeuronStrings ? groundHeadRule.toFullString() : weightedRule.getOriginalString(), this.counter++, state);
        this.neuronMaps.aggNeurons.put(groundHeadRule, aggregationNeuron);
        LOG.finest(() -> "Created splittable aggregation neuron: " + String.valueOf(aggregationNeuron));
        return aggregationNeuron;
    }

    /**
     * Creates an atom neuron reading from a splittable aggregation neuron. The neuron uses the
     * default atom combination with an {@link AtIndex} transformation and is registered under a copy
     * of {@code groundHead} whose predicate name is prefixed with {@code "_"}.
     *
     * @param groundHead                  literal whose predicate, negation and terms are copied
     * @param splittableAggregationNeuron neuron added as the single input of the new atom neuron
     * @return the new neuron
     */
    public AtomNeuron createSplittableAtomNeuron(Literal groundHead, SplittableAggregationNeuron splittableAggregationNeuron) {
        Combination combination = Combination.getFunction(this.settings.atomNeuronCombination);
        AtIndex transformation = new AtIndex();
        State.Neural.Computation state = State.createBaseState(this.settings, combination, transformation);
        Literal head = new Literal(new Predicate("_" + groundHead.predicate().name, groundHead.predicate().arity), groundHead.isNegated(), groundHead.termList());
        AtomNeuron<State.Neural.Computation> atomNeuron = new AtomNeuron<State.Neural.Computation>(head.toString(), this.counter++, state);
        this.neuronMaps.atomNeurons.put(head, atomNeuron);
        atomNeuron.addInput(splittableAggregationNeuron);
        LOG.finest(() -> "Created splittable atom neuron: " + String.valueOf(atomNeuron));
        return atomNeuron;
    }

    /**
     * Creates a rule neuron without an offset and registers it in {@code ruleNeurons}.
     *
     * @param groundRule ground rule supplying the optional combination and transformation
     * @return the new neuron
     */
    public RuleNeuron createRuleNeuron(GroundRule groundRule) {
        WeightedRule weightedRule = groundRule.weightedRule;
        Combination combination = weightedRule.getCombination() != null ? weightedRule.getCombination() : Combination.getFunction(this.settings.ruleNeuronCombination);
        Transformation transformation = weightedRule.getTransformation() != null ? weightedRule.getTransformation() : Transformation.getFunction(this.settings.ruleNeuronTransformation);
        State.Neural.Computation state = State.createBaseState(this.settings, combination, transformation);
        RuleNeuron<State.Neural.Computation> ruleNeuron = new RuleNeuron<State.Neural.Computation>(this.settings.fullRuleNeuronStrings ? groundRule.toFullString() : weightedRule.getOriginalString(), this.counter++, state);
        this.neuronMaps.ruleNeurons.put(groundRule, ruleNeuron);
        LOG.finest(() -> "Created rule neuron: " + String.valueOf(ruleNeuron));
        return ruleNeuron;
    }

    /**
     * Creates a rule neuron with an offset and registers it in {@code ruleNeurons}. If the rule
     * declares no offset, the default rule offset from settings is used.
     *
     * @param groundRule ground rule supplying the optional combination, transformation and offset
     * @return the new neuron
     */
    public WeightedRuleNeuron createWeightedRuleNeuron(GroundRule groundRule) {
        WeightedRule weightedRule = groundRule.weightedRule;
        Combination combination = weightedRule.getCombination() != null ? weightedRule.getCombination() : Combination.getFunction(this.settings.ruleNeuronCombination);
        Transformation transformation = weightedRule.getTransformation() != null ? weightedRule.getTransformation() : Transformation.getFunction(this.settings.ruleNeuronTransformation);
        Weight offset = weightedRule.getOffset();
        if (offset == null) {
            // Rule neurons must use the rule offset (not the atom offset) as their fixed default.
            offset = defaultOffset(this.settings.defaultRuleOffsetsLearnable, this.settings.defaultRuleNeuronOffset, this.ruleOffset);
        }
        State.Neural.Computation state = State.createBaseState(this.settings, combination, transformation);
        WeightedRuleNeuron<State.Neural.Computation> weightedRuleNeuron = new WeightedRuleNeuron<State.Neural.Computation>(this.settings.fullRuleNeuronStrings ? groundRule.toFullString() : weightedRule.getOriginalString(), offset, this.counter++, state);
        this.neuronMaps.ruleNeurons.put(groundRule, weightedRuleNeuron);
        LOG.finest(() -> "Created weightedRule neuron: " + String.valueOf(weightedRuleNeuron));
        return weightedRuleNeuron;
    }

    /**
     * Returns the fact neuron registered for {@code fact.literal}, creating and registering one if
     * none exists yet. An existing neuron is returned as-is, even if {@code fact} carries a
     * different value or weight.
     *
     * @param fact the fact; a {@code null} value falls back to {@code settings.defaultFactValue}
     * @return the existing or newly created neuron
     */
    public FactNeuron createFactNeuron(ValuedFact fact) {
        FactNeuron existing = this.neuronMaps.factNeurons.get(fact.literal);
        if (existing != null) {
            return existing;
        }
        States.SimpleValue simpleValue = new States.SimpleValue(fact.getValue() == null ? this.defaultFactValue : fact.getValue());
        FactNeuron factNeuron = new FactNeuron(fact.originalString, fact.weight, this.counter++, simpleValue);
        if (fact.weight != null && fact.weight.isLearnable()) {
            // A learnable fact weight makes the fact's value itself learnable.
            factNeuron.hasLearnableValue = true;
            simpleValue.isLearnable = true;
        }
        this.neuronMaps.factNeurons.put(fact.literal, factNeuron);
        LOG.finest(() -> "Created fact neuron: " + String.valueOf(factNeuron));
        return factNeuron;
    }

    /**
     * Creates a negation neuron over {@code atomFact} and adds it to {@code negationNeurons}. The
     * state has no combination.
     *
     * @param atomFact the neuron being negated
     * @param negation transformation to apply; if {@code null}, {@code settings.softNegation} is used
     * @return the new neuron
     */
    public NegationNeuron createNegationNeuron(AtomFact atomFact, Transformation negation) {
        Transformation transformation = negation != null ? negation : Transformation.getFunction(this.settings.softNegation);
        State.Neural.Computation state = State.createBaseState(this.settings, null, transformation);
        NegationNeuron<State.Neural.Computation> negationNeuron = new NegationNeuron<State.Neural.Computation>(atomFact, this.counter++, state);
        this.neuronMaps.negationNeurons.add(negationNeuron);
        LOG.finest(() -> "Created negation neuron: " + String.valueOf(negationNeuron));
        return negationNeuron;
    }
}

