/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.em;

import java.util.List;
import org.linqs.psl.application.learning.weight.VotedPerceptron;
import org.linqs.psl.config.ConfigBundle;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for weight learning with latent variables via a hard
 * Expectation-Maximization loop built on top of {@link VotedPerceptron}.
 *
 * <p>Each EM iteration runs an E-step ({@link #eStep()}, by default
 * {@code computeLatentMPEState()}) followed by an M-step ({@link #mStep()}, by
 * default the parent's voted-perceptron learning). Iteration stops after
 * {@link #iterations} rounds, or earlier once the L2 norm of the change in rule
 * weights produced by one round is at most {@link #tolerance}.
 */
public abstract class ExpectationMaximization
extends VotedPerceptron {
    private static final Logger log = LoggerFactory.getLogger(ExpectationMaximization.class);

    /** Prefix shared by all configuration keys of this class. */
    public static final String CONFIG_PREFIX = "em";

    /** Config key for the maximum number of EM iterations. */
    public static final String ITER_KEY = CONFIG_PREFIX + ".iterations";
    /** Default maximum number of EM iterations. */
    public static final int ITER_DEFAULT = 10;

    /**
     * Config key controlling whether the step-size schedule restarts at the
     * beginning of every M-step ({@code true}) or continues across EM iterations
     * ({@code false}).
     */
    public static final String RESET_SCHEDULE_KEY = CONFIG_PREFIX + ".resetschedule";
    /** Default for {@link #RESET_SCHEDULE_KEY}. */
    public static final boolean RESET_SCHEDULE_DEFAULT = true;

    /** Config key for the convergence tolerance on the L2 norm of the weight change. */
    public static final String TOLERANCE_KEY = CONFIG_PREFIX + ".tolerance";
    /** Default convergence tolerance. */
    public static final double TOLERANCE_DEFAULT = 0.001;

    /** Maximum number of EM iterations to run. */
    protected final int iterations;
    /** EM stops once the L2 norm of a round's weight change is at most this value. */
    protected final double tolerance;
    /** Whether the step-size schedule restarts every M-step. */
    protected final boolean resetSchedule;
    /** Zero-based index of the EM iteration currently running. */
    protected int emIteration;

    /**
     * Creates an EM learner and reads its settings from {@code config}.
     *
     * @param rules the rules whose weights are learned
     * @param rvDB database holding the random variables
     * @param observedDB database holding the observed (training) values
     * @param config configuration bundle; the {@code em.*} keys are read here
     */
    public ExpectationMaximization(List<Rule> rules, Database rvDB, Database observedDB, ConfigBundle config) {
        // NOTE(review): the literal `true` is a VotedPerceptron constructor flag,
        // presumably enabling latent-variable support — confirm against the parent.
        super(rules, rvDB, observedDB, true, config);
        this.iterations = config.getInt(ITER_KEY, ITER_DEFAULT);
        this.tolerance = config.getDouble(TOLERANCE_KEY, TOLERANCE_DEFAULT);
        this.resetSchedule = config.getBoolean(RESET_SCHEDULE_KEY, RESET_SCHEDULE_DEFAULT);
    }

    /**
     * Runs alternating E- and M-steps until the weights converge or the
     * iteration budget is exhausted, logging loss and objective after each round.
     */
    @Override
    protected void doLearn() {
        double[] previousWeights = new double[this.mutableRules.size()];
        for (int i = 0; i < previousWeights.length; i++) {
            previousWeights[i] = ((WeightedRule) this.mutableRules.get(i)).getWeight();
        }

        // emIteration is a field (not a local) because getStepSize() reads it.
        for (this.emIteration = 0; this.emIteration < this.iterations; this.emIteration++) {
            log.debug("Beginning EM iteration {} of {}", this.emIteration, this.iterations);

            this.eStep();
            this.mStep();

            double change = this.updateWeightsAndComputeChange(previousWeights);
            double loss = this.getLoss();
            double regularizer = this.computeRegularizer();
            double objective = loss + regularizer;
            log.info("Finished EM iteration {} with m-step norm {}. Loss: {}, regularizer: {}, objective: {}",
                    this.emIteration, change, loss, regularizer, objective);

            if (change <= this.tolerance) {
                log.info("EM converged.");
                break;
            }
        }
    }

    /**
     * Computes the L2 distance between {@code previousWeights} and the current
     * weights of the mutable rules, then overwrites {@code previousWeights} with
     * the current weights so the next round compares against them.
     *
     * @param previousWeights weights recorded after the previous round; updated in place
     * @return Euclidean norm of the per-rule weight differences
     */
    private double updateWeightsAndComputeChange(double[] previousWeights) {
        double squaredChange = 0.0;
        for (int i = 0; i < this.mutableRules.size(); i++) {
            double currentWeight = ((WeightedRule) this.mutableRules.get(i)).getWeight();
            double delta = previousWeights[i] - currentWeight;
            squaredChange += delta * delta;
            previousWeights[i] = currentWeight;
        }
        return Math.sqrt(squaredChange);
    }

    /** E-step: infers the latent variables given the current weights. */
    protected void eStep() {
        this.computeLatentMPEState();
    }

    /** M-step: re-learns the rule weights with the parent's learning procedure. */
    protected void mStep() {
        super.doLearn();
    }

    /**
     * Returns the step size for an inner (M-step) iteration.
     *
     * <p>When step-size scheduling is on and {@link #resetSchedule} is
     * {@code false}, the schedule uses a global step counter that keeps running
     * across EM iterations. Otherwise the parent's schedule is used, keyed on
     * the inner iteration, so that it restarts at the start of every M-step.
     *
     * @param innerIteration zero-based iteration index within the current M-step
     * @return the step size to use
     */
    @Override
    protected double getStepSize(int innerIteration) {
        if (this.scheduleStepSize && !this.resetSchedule) {
            return this.baseStepSize / (double) (this.emIteration * this.numSteps + innerIteration + 1);
        }
        // Pass the inner iteration (not emIteration) so the schedule resets per M-step.
        return super.getStepSize(innerIteration);
    }
}

