/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.search.objective;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.application.learning.weight.search.objective.ObjectiveFunction;
import org.linqs.psl.config.ConfigBundle;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.WeightedRule;

/**
 * Search objective that scores the current values of the training map's random variable atoms
 * by their average per-atom error against the paired observed atoms.
 *
 * <p>Two statistics are supported: mean absolute error ({@link #STAT_MAE}) and mean squared
 * error ({@link #STAT_MSE}, the default).
 */
public class ContinuousObjective
implements ObjectiveFunction {
    public static final String STAT_MAE = "MAE";
    public static final String STAT_MSE = "MSE";
    public static final String CONFIG_PREFIX = "continuousobjective";
    /** Config key naming the statistic to use (matched case-insensitively). */
    public static final String STAT_KEY = CONFIG_PREFIX + ".statistic";
    public static final String STAT_DEFAULT = STAT_MSE;

    /** Normalized (upper-case) statistic name; always {@link #STAT_MAE} or {@link #STAT_MSE}. */
    private final String stat;

    /**
     * Creates an objective that uses the given statistic.
     *
     * @param stat the statistic name, {@link #STAT_MAE} or {@link #STAT_MSE} (case-insensitive)
     * @throws NullPointerException if {@code stat} is null
     * @throws IllegalArgumentException if {@code stat} is not a known statistic
     */
    public ContinuousObjective(String stat) {
        this.stat = normalizeStat(stat);
    }

    /**
     * Creates an objective whose statistic is read from {@link #STAT_KEY} in the given config,
     * falling back to {@link #STAT_DEFAULT}.
     *
     * @param config the configuration to read from
     * @throws IllegalArgumentException if the configured statistic is not a known statistic
     */
    public ContinuousObjective(ConfigBundle config) {
        this(config.getString(STAT_KEY, STAT_DEFAULT));
    }

    /**
     * Validates a statistic name and returns its canonical upper-case form.
     * Locale.ROOT keeps the casing independent of the JVM's default locale.
     */
    private static String normalizeStat(String stat) {
        Objects.requireNonNull(stat, "stat");
        String normalized = stat.toUpperCase(Locale.ROOT);
        if (!normalized.equals(STAT_MAE) && !normalized.equals(STAT_MSE)) {
            throw new IllegalArgumentException("Unknown continuous statistic: " + stat);
        }
        return normalized;
    }

    /**
     * Computes the mean absolute or mean squared error between each random variable atom in the
     * training map and its paired observed atom.
     *
     * @param mutableRules not used by this objective
     * @param observedIncompatibility not used by this objective
     * @param expectedIncompatibility not used by this objective
     * @param trainingMap supplies the (random variable atom, observed atom) pairs to compare
     * @return the average per-atom error; NaN when the training map is empty (0.0 / 0)
     */
    @Override
    public double compute(List<WeightedRule> mutableRules, double[] observedIncompatibility, double[] expectedIncompatibility, TrainingMap trainingMap) {
        Map<RandomVariableAtom, ObservedAtom> truth = trainingMap.getTrainingMap();
        boolean square = STAT_MSE.equals(stat);

        double error = 0.0;
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : truth.entrySet()) {
            double diff = entry.getKey().getValue() - entry.getValue().getValue();
            error += square ? diff * diff : Math.abs(diff);
        }

        // NOTE(review): an empty training map yields NaN here; callers comparing objective
        // values should be aware of this.
        return error / truth.size();
    }
}

