/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.search.objective;

import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.application.learning.weight.search.objective.ObjectiveFunction;
import org.linqs.psl.config.ConfigBundle;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.MathUtils;

/**
 * An {@link ObjectiveFunction} that scores the current inference results with a discrete
 * classification statistic (F1, accuracy, precision, or recall) over the training map.
 *
 * <p>Both the inferred value (the {@link RandomVariableAtom} key) and the truth value (the
 * {@link ObservedAtom} value) are discretized with the same threshold: a value
 * {@code >= threshold} counts as positive. {@link #compute} returns {@code 1 - statistic}, so a
 * perfect score maps to 0.
 */
public class DiscreteObjective
implements ObjectiveFunction {
    public static final String STAT_F1 = "F1";
    public static final String STAT_ACCURACY = "accuracy";
    public static final String STAT_PRECISION = "precision";
    public static final String STAT_RECALL = "recall";
    public static final String CONFIG_PREFIX = "discreteobjective";
    public static final String STAT_KEY = "discreteobjective.statistic";
    public static final String STAT_DEFAULT = "F1";
    public static final String THRESHOLD_KEY = "discreteobjective.truththreshold";
    public static final double THRESHOLD_DEFAULT = 0.5;

    /** Every supported statistic, in its canonical spelling. */
    private static final String[] KNOWN_STATS = {STAT_F1, STAT_ACCURACY, STAT_PRECISION, STAT_RECALL};

    /** Canonical name of the statistic to report (one of {@link #KNOWN_STATS}). */
    private final String stat;
    /** Values {@code >=} this are treated as positive, for both inferred and truth values. */
    private final double threshold;

    /**
     * Builds the objective from configuration ({@link #STAT_KEY}, {@link #THRESHOLD_KEY}).
     *
     * <p>The statistic name is matched case-insensitively. The name used to be upper-cased
     * here, which made every statistic except F1 unselectable, because the other constants
     * are lower-case.
     *
     * @param config the configuration to read the statistic and threshold from
     * @throws IllegalArgumentException if the configured statistic or threshold is invalid
     */
    public DiscreteObjective(ConfigBundle config) {
        this(config.getString(STAT_KEY, STAT_DEFAULT), config.getDouble(THRESHOLD_KEY, THRESHOLD_DEFAULT));
    }

    /**
     * Creates an objective for the given statistic and truth threshold.
     *
     * @param stat the statistic to optimize: one of {@link #STAT_F1}, {@link #STAT_ACCURACY},
     *     {@link #STAT_PRECISION}, {@link #STAT_RECALL}, matched case-insensitively
     * @param threshold discretization threshold in [0, 1]
     * @throws IllegalArgumentException if {@code stat} is null or unknown, or if
     *     {@code threshold} is outside [0, 1] or NaN
     */
    public DiscreteObjective(String stat, double threshold) {
        this.stat = canonicalStat(stat);
        // Negated range check so that NaN (for which every comparison is false) is rejected too.
        if (!(threshold >= 0.0 && threshold <= 1.0)) {
            throw new IllegalArgumentException("Threshold must be in [0, 1], found: " + threshold);
        }
        this.threshold = threshold;
    }

    /**
     * Maps a user-supplied statistic name to its canonical constant, ignoring case.
     *
     * @throws IllegalArgumentException if {@code stat} is null or not a known statistic
     */
    private static String canonicalStat(String stat) {
        for (String known : KNOWN_STATS) {
            if (known.equalsIgnoreCase(stat)) {
                return known;
            }
        }
        throw new IllegalArgumentException("Unknown statistic: " + stat);
    }

    /**
     * Builds a confusion matrix over the training map and returns {@code 1 - statistic}.
     *
     * <p>Only {@code trainingMap} is read; the rule and incompatibility arguments are part of
     * the {@link ObjectiveFunction} signature and are unused here.
     *
     * @return {@code 1.0} minus the configured statistic
     */
    @Override
    public double compute(List<WeightedRule> mutableRules, double[] observedIncompatibility, double[] expectedIncompatibility, TrainingMap trainingMap) {
        int tp = 0;
        int fp = 0;
        int tn = 0;
        int fn = 0;
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : trainingMap.getTrainingMap().entrySet()) {
            // The value (observed atom) holds the truth; the key (random variable atom) holds the inferred value.
            boolean expected = entry.getValue().getValue() >= threshold;
            boolean predicted = entry.getKey().getValue() >= threshold;
            if (predicted) {
                if (expected) {
                    tp++;
                } else {
                    fp++;
                }
            } else {
                if (expected) {
                    fn++;
                } else {
                    tn++;
                }
            }
        }

        switch (stat) {
            case STAT_F1:
                return 1.0 - computeF1(tp, fp, tn, fn);
            case STAT_ACCURACY:
                return 1.0 - computeAccuracy(tp, fp, tn, fn);
            case STAT_PRECISION:
                return 1.0 - computePrecision(tp, fp, tn, fn);
            case STAT_RECALL:
                return 1.0 - computeRecall(tp, fp, tn, fn);
            default:
                // Unreachable: the constructor only stores canonical statistic names.
                throw new IllegalArgumentException("Unknown statistic: " + stat);
        }
    }

    /**
     * Harmonic mean of precision and recall.
     *
     * @return the F1 score, or 0.0 when precision and recall are both zero
     */
    public double computeF1(int tp, int fp, int tn, int fn) {
        double precision = computePrecision(tp, fp, tn, fn);
        double recall = computeRecall(tp, fp, tn, fn);
        if (MathUtils.isZero(precision + recall)) {
            return 0.0;
        }
        return 2.0 * (precision * recall) / (precision + recall);
    }

    /**
     * Fraction of all examples that were classified correctly.
     *
     * @return {@code (tp + tn) / total}, or 0.0 when there are no examples (instead of NaN),
     *     matching how precision and recall treat a zero denominator
     */
    public double computeAccuracy(int tp, int fp, int tn, int fn) {
        int total = tp + fp + tn + fn;
        if (total == 0) {
            return 0.0;
        }
        return (double)(tp + tn) / (double)total;
    }

    /**
     * Fraction of positive predictions that are truly positive.
     *
     * @return {@code tp / (tp + fp)}, or 0.0 when nothing was predicted positive
     */
    public double computePrecision(int tp, int fp, int tn, int fn) {
        if (tp + fp == 0) {
            return 0.0;
        }
        return (double)tp / (double)(tp + fp);
    }

    /**
     * Fraction of truly positive examples that were predicted positive.
     *
     * @return {@code tp / (tp + fn)}, or 0.0 when there are no true positives in the truth
     */
    public double computeRecall(int tp, int fp, int tn, int fn) {
        if (tp + fn == 0) {
            return 0.0;
        }
        return (double)tp / (double)(tp + fn);
    }
}

