/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.evaluation.statistics.inspector;

import org.linqs.psl.config.ConfigBundle;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.Queries;
import org.linqs.psl.evaluation.statistics.CategoricalPredictionComparator;
import org.linqs.psl.evaluation.statistics.CategoricalPredictionStatistics;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.inspector.DatabaseReasonerInspector;
import org.linqs.psl.reasoner.inspector.ReasonerInspector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reasoner inspector that, on every update, compares the random-variable database against its
 * truth database and logs categorical accuracy and error for each registered predicate that has
 * at least one ground atom in the truth database.
 *
 * <p>For every evaluated predicate the final argument is treated as the category.
 */
public class CategoricalAccuracyInspector
extends DatabaseReasonerInspector {
    private static final Logger log = LoggerFactory.getLogger(CategoricalAccuracyInspector.class);

    /**
     * @param config configuration passed through to {@link DatabaseReasonerInspector}
     */
    public CategoricalAccuracyInspector(ConfigBundle config) {
        super(config);
    }

    /**
     * Logs per-predicate categorical accuracy and error for the current reasoner state.
     *
     * <p>The truth database obtained for this update is closed before returning, including when
     * a comparison throws.
     *
     * @param reasoner the reasoner being inspected (not used directly by this implementation)
     * @param status the reasoner status, included in the log output
     * @return always {@code true}; NOTE(review): presumably tells the reasoner to keep going —
     *     confirm against {@link ReasonerInspector#update}
     */
    @Override
    public boolean update(Reasoner reasoner, ReasonerInspector.ReasonerStatus status) {
        log.info("Reasoner inspection update -- {}", status);

        Database rvDatabase = this.getRandomVariableDatabase();
        Database truthDatabase = this.getTruthDatabase(rvDatabase);
        try {
            CategoricalPredictionComparator comparator = new CategoricalPredictionComparator(rvDatabase);
            comparator.setBaseline(truthDatabase);

            for (StandardPredicate targetPredicate : rvDatabase.getRegisteredPredicates()) {
                // Predicates without any truth atoms have nothing to be evaluated against.
                if (Queries.countAllGroundAtoms(truthDatabase, targetPredicate) == 0) {
                    continue;
                }

                // The category is the predicate's last argument.
                comparator.setCategoryIndexes(new int[]{targetPredicate.getArity() - 1});
                CategoricalPredictionStatistics stats = comparator.compare(targetPredicate);
                log.info("{} -- Accuracy: {}, Error: {}",
                        targetPredicate.getName(), stats.getAccuracy(), (int)stats.getError());
            }
        } finally {
            // Release the truth database even if a comparison above throws.
            truthDatabase.close();
        }

        log.info("Reasoner inspection update complete");
        return true;
    }
}

