/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.application.ModelApplication;
import org.linqs.psl.application.groundrulestore.GroundRuleStore;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.application.util.Grounding;
import org.linqs.psl.config.ConfigBundle;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.AtomManager;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.model.rule.misc.GroundValueConstraint;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.term.TermGenerator;
import org.linqs.psl.reasoner.term.TermStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for applications that learn rule weights.
 *
 * <p>Random variable atoms are read from {@code rvDB}, and their labels come from the
 * observed atoms in {@code observedDB}. {@link #learn()} does three things in order:
 * <ol>
 *   <li>Grounds the model.</li>
 *   <li>If latent variables are supported, grounds a second "latent" model. In it, every
 *       labeled random variable is pinned to its observed value by a
 *       {@link GroundValueConstraint}.</li>
 *   <li>Delegates the actual learning to {@link #doLearn()}.</li>
 * </ol>
 * Subclasses can use the helpers here to compute MPE states, per-rule incompatibility totals
 * and a loss.
 */
public abstract class WeightLearningApplication implements ModelApplication {
    private static final Logger log = LoggerFactory.getLogger(WeightLearningApplication.class);

    // Configuration keys and defaults. Each component is instantiated by name through the
    // ConfigBundle.
    public static final String CONFIG_PREFIX = "weightlearning";
    public static final String REASONER_KEY = "weightlearning.reasoner";
    public static final String REASONER_DEFAULT = "org.linqs.psl.reasoner.admm.ADMMReasoner";
    public static final String GROUND_RULE_STORE_KEY = "weightlearning.groundrulestore";
    public static final String GROUND_RULE_STORE_DEFAULT = "org.linqs.psl.application.groundrulestore.MemoryGroundRuleStore";
    public static final String TERM_STORE_KEY = "weightlearning.termstore";
    public static final String TERM_STORE_DEFAULT = "org.linqs.psl.reasoner.admm.term.ADMMTermStore";
    public static final String TERM_GENERATOR_KEY = "weightlearning.termgenerator";
    public static final String TERM_GENERATOR_DEFAULT = "org.linqs.psl.reasoner.admm.term.ADMMTermGenerator";

    /** Source of component class names (reasoner, stores, term generator). */
    protected ConfigBundle config;

    /**
     * Controls latent-variable handling.
     * If false, {@link #initGroundModel()} rejects training data that contains latent variables.
     * If true, {@link #learn()} also builds the latent ground model.
     */
    protected boolean supportsLatentVariables;

    /** Database the atom manager is built over (see {@link #createAtomManager()}). */
    protected Database rvDB;

    /** Database whose observed atoms supply the training labels. */
    protected Database observedDB;

    /** Atom manager produced by {@link #createAtomManager()}. */
    protected PersistedAtomManager atomManager;

    /** Every rule passed to the constructor, in the original order. */
    protected List<Rule> allRules;

    /**
     * The subset of {@link #allRules} that are {@link WeightedRule}s.
     * Index i here lines up with index i of the incompatibility arrays.
     */
    protected List<WeightedRule> mutableRules;

    /** Per-mutable-rule incompatibility totals, with random variables set to their labels. */
    protected double[] observedIncompatibility;

    /** Per-mutable-rule incompatibility totals at the current MPE state. */
    protected double[] expectedIncompatibility;

    /** Pairs random variable atoms with their observed labels and records unlabeled (latent) ones. */
    protected TrainingMap trainingMap;

    protected Reasoner reasoner;

    /** Ground rules of the full model. */
    protected GroundRuleStore groundRuleStore;

    /**
     * Ground rules of the latent model: the full grounding plus one value constraint per
     * labeled random variable. Only populated when {@link #supportsLatentVariables} is true.
     */
    protected GroundRuleStore latentGroundRuleStore;

    protected TermGenerator termGenerator;

    /** Objective terms generated from {@link #groundRuleStore}. */
    protected TermStore termStore;

    /** Objective terms generated from {@link #latentGroundRuleStore}. */
    protected TermStore latentTermStore;

    /**
     * Sets up the learner.
     *
     * @param rules all rules of the model; the {@link WeightedRule}s among them are also
     *     collected into {@link #mutableRules}
     * @param rvDB database used to build the atom manager
     * @param observedDB database providing observed (label) atoms
     * @param supportsLatentVariables whether the subclass can handle unlabeled random variables
     * @param config configuration used to instantiate the reasoner, stores and term generator
     */
    public WeightLearningApplication(List<Rule> rules, Database rvDB, Database observedDB,
            boolean supportsLatentVariables, ConfigBundle config) {
        this.rvDB = rvDB;
        this.observedDB = observedDB;
        this.supportsLatentVariables = supportsLatentVariables;
        this.config = config;

        this.allRules = new ArrayList<Rule>();
        this.mutableRules = new ArrayList<WeightedRule>();
        for (Rule rule : rules) {
            allRules.add(rule);
            if (rule instanceof WeightedRule) {
                mutableRules.add((WeightedRule) rule);
            }
        }

        // One slot per weighted rule, parallel to mutableRules.
        this.observedIncompatibility = new double[mutableRules.size()];
        this.expectedIncompatibility = new double[mutableRules.size()];
    }

    /**
     * Grounds the model and, if supported, the latent model. Then runs {@link #doLearn()}.
     */
    public void learn() {
        initGroundModel();
        if (supportsLatentVariables) {
            initLatentGroundModel();
        }
        doLearn();
    }

    /** Performs the actual weight learning; called by {@link #learn()} after grounding. */
    protected abstract void doLearn();

    /**
     * Builds the full ground model. The steps are:
     * <ol>
     *   <li>Instantiate the reasoner, stores and term generator from {@link #config}.</li>
     *   <li>Create the atom manager.</li>
     *   <li>Build the training map.</li>
     *   <li>Ground all rules.</li>
     *   <li>Generate objective terms.</li>
     * </ol>
     *
     * @throws IllegalArgumentException if latent variables are not supported but the training
     *     map contains some
     */
    protected void initGroundModel() {
        reasoner = (Reasoner) config.getNewObject(REASONER_KEY, REASONER_DEFAULT);
        groundRuleStore = (GroundRuleStore) config.getNewObject(GROUND_RULE_STORE_KEY, GROUND_RULE_STORE_DEFAULT);
        termStore = (TermStore) config.getNewObject(TERM_STORE_KEY, TERM_STORE_DEFAULT);
        termGenerator = (TermGenerator) config.getNewObject(TERM_GENERATOR_KEY, TERM_GENERATOR_DEFAULT);

        atomManager = createAtomManager();
        ensureTargets();
        trainingMap = new TrainingMap(atomManager, observedDB);

        if (!supportsLatentVariables) {
            Set<RandomVariableAtom> latentVariables = trainingMap.getLatentVariables();
            if (!latentVariables.isEmpty()) {
                throw new IllegalArgumentException(String.format(
                        "All RandomVariableAtoms must have corresponding ObservedAtoms, found %d latent variables. "
                        + "Latent variables are not supported by this WeightLearningApplication (%s). "
                        + "Example latent variable: [%s].",
                        latentVariables.size(), getClass().getName(), latentVariables.iterator().next()));
            }
        }

        log.info("Grounding out model.");
        int groundCount = Grounding.groundAll(allRules, (AtomManager) atomManager, groundRuleStore);

        log.debug("Initializing objective terms for {} ground rules.", groundCount);
        int termCount = termGenerator.generateTerms(groundRuleStore, termStore);
        log.debug("Generated {} objective terms from {} ground rules.", termCount, groundCount);
    }

    /**
     * Builds the latent ground model. This is a fresh grounding of all rules, plus a
     * {@link GroundValueConstraint} fixing each labeled random variable to its observed value.
     * Objective terms are then generated into {@link #latentTermStore}.
     *
     * <p>Must run after {@link #initGroundModel()}, because it reuses the atom manager, training
     * map and term generator created there.
     */
    protected void initLatentGroundModel() {
        latentGroundRuleStore = (GroundRuleStore) config.getNewObject(GROUND_RULE_STORE_KEY, GROUND_RULE_STORE_DEFAULT);
        latentTermStore = (TermStore) config.getNewObject(TERM_STORE_KEY, TERM_STORE_DEFAULT);

        log.info("Grounding out latent model.");
        int groundCount = Grounding.groundAll(allRules, (AtomManager) atomManager, latentGroundRuleStore);

        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : trainingMap.getTrainingMap().entrySet()) {
            latentGroundRuleStore.addGroundRule(new GroundValueConstraint(entry.getKey(), entry.getValue().getValue()));
        }
        // Count the value constraints just added together with the grounded rules.
        groundCount += trainingMap.getTrainingMap().size();

        log.debug("Initializing latent objective terms for {} ground rules.", groundCount);
        int termCount = termGenerator.generateTerms(latentGroundRuleStore, latentTermStore);
        log.debug("Generated {} latent objective terms from {} ground rules.", termCount, groundCount);
    }

    /** Refreshes term weights from the full ground model, then optimizes it with the reasoner. */
    protected void computeMPEState() {
        termGenerator.updateWeights(groundRuleStore, termStore);
        reasoner.optimize(termStore);
    }

    /** Refreshes term weights from the latent ground model, then optimizes it with the reasoner. */
    protected void computeLatentMPEState() {
        termGenerator.updateWeights(latentGroundRuleStore, latentTermStore);
        reasoner.optimize(latentTermStore);
    }

    /**
     * Sets every labeled random variable to its observed value. Then recomputes
     * {@link #observedIncompatibility} from the full ground model.
     */
    protected void computeObservedIncompatibility() {
        setLabeledRandomVariables();
        sumIncompatibilities(observedIncompatibility);
    }

    /**
     * Computes the MPE state of the full model. Then recomputes
     * {@link #expectedIncompatibility} from the resulting atom values.
     */
    protected void computeExpectedIncompatibility() {
        computeMPEState();
        sumIncompatibilities(expectedIncompatibility);
    }

    /**
     * Clears {@code totals}, then fills slot i with the summed incompatibility of the ground
     * rules of {@code mutableRules.get(i)} in {@link #groundRuleStore}.
     * Shared by the observed and expected computations.
     */
    private void sumIncompatibilities(double[] totals) {
        for (int i = 0; i < totals.length; i++) {
            totals[i] = 0.0;
        }
        for (int i = 0; i < mutableRules.size(); i++) {
            for (GroundRule groundRule : groundRuleStore.getGroundRules(mutableRules.get(i))) {
                totals[i] += ((WeightedGroundRule) groundRule).getIncompatibility();
            }
        }
    }

    /**
     * Returns the sum over mutable rules of {@code weight_i * (observed_i - expected_i)}.
     * Uses whatever the incompatibility arrays currently hold, so call
     * {@link #computeObservedIncompatibility()} and {@link #computeExpectedIncompatibility()}
     * first to refresh them.
     */
    protected double computeLoss() {
        double loss = 0.0;
        for (int i = 0; i < mutableRules.size(); i++) {
            loss += mutableRules.get(i).getWeight() * (observedIncompatibility[i] - expectedIncompatibility[i]);
        }
        return loss;
    }

    /** Closes any stores and the reasoner that were created, then drops all references. */
    @Override
    public void close() {
        if (groundRuleStore != null) {
            groundRuleStore.close();
            groundRuleStore = null;
        }
        if (latentGroundRuleStore != null) {
            latentGroundRuleStore.close();
            latentGroundRuleStore = null;
        }
        if (termStore != null) {
            termStore.close();
            termStore = null;
        }
        if (latentTermStore != null) {
            latentTermStore.close();
            latentTermStore = null;
        }
        if (reasoner != null) {
            reasoner.close();
            reasoner = null;
        }
        termGenerator = null;
        trainingMap = null;
        atomManager = null;
        rvDB = null;
        observedDB = null;
        config = null;
    }

    /** Assigns each random variable in the training map the value of its observed atom. */
    protected void setLabeledRandomVariables() {
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : trainingMap.getTrainingMap().entrySet()) {
            entry.getKey().setValue(entry.getValue().getValue());
        }
    }

    /** Resets every labeled and every latent random variable to 0.0. */
    protected void setDefaultRandomVariables() {
        for (RandomVariableAtom atom : trainingMap.getTrainingMap().keySet()) {
            atom.setValue(0.0);
        }
        for (RandomVariableAtom atom : trainingMap.getLatentVariables()) {
            atom.setValue(0.0);
        }
    }

    /** Creates the atom manager used for grounding. Subclasses may override this. */
    protected PersistedAtomManager createAtomManager() {
        return new PersistedAtomManager(rvDB);
    }

    /**
     * Looks up every observed atom of every open (non-closed) predicate in
     * {@link #observedDB}. Any random variable atom found this way is set to 0.0.
     * Persisted atoms are committed at the end.
     *
     * <p>NOTE(review): presumably ensures each labeled atom exists as a target in the atom
     * manager before the {@link TrainingMap} is built. Confirm this against the
     * PersistedAtomManager implementation.
     */
    private void ensureTargets() {
        for (StandardPredicate predicate : observedDB.getRegisteredPredicates()) {
            if (observedDB.isClosed(predicate)) {
                continue;
            }

            for (ObservedAtom observedAtom : observedDB.getAllGroundObservedAtoms(predicate)) {
                GroundAtom otherAtom = atomManager.getAtom(observedAtom.getPredicate(), observedAtom.getArguments());
                if (otherAtom instanceof ObservedAtom) {
                    continue;
                }

                RandomVariableAtom rvAtom = (RandomVariableAtom) otherAtom;
                rvAtom.setValue(0.0);
            }
        }
        atomManager.commitPersistedAtoms();
    }
}

