/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.evaluation.statistics;

import org.linqs.psl.database.Database;
import org.linqs.psl.database.Queries;
import org.linqs.psl.evaluation.statistics.ContinuousPredictionStatistics;
import org.linqs.psl.evaluation.statistics.ResultComparator;
import org.linqs.psl.evaluation.statistics.filter.AtomFilter;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.predicate.StandardPredicate;

import java.util.Objects;

/**
 * Compares the truth values of atoms in a result database against the matching atoms in a
 * baseline (truth) database, accumulating a count plus total absolute and squared error.
 *
 * <p>Result filters are not supported by this comparator.
 */
public class ContinuousPredictionComparator implements ResultComparator {
    /** Database whose atoms are being evaluated. Never null. */
    private final Database result;

    /** Database supplying the reference atoms; null until set via constructor or {@link #setBaseline(Database)}. */
    private Database baseline;

    /**
     * Creates a comparator without a baseline. {@link #setBaseline(Database)} must be called
     * before {@link #compare(StandardPredicate)}.
     *
     * @param result the database whose atoms are evaluated; must not be null
     * @throws NullPointerException if {@code result} is null
     */
    public ContinuousPredictionComparator(Database result) {
        this(result, null);
    }

    /**
     * Creates a comparator over the given result and baseline databases.
     *
     * @param result the database whose atoms are evaluated; must not be null
     * @param baseline the database supplying reference atoms; may be null and set later
     * @throws NullPointerException if {@code result} is null
     */
    public ContinuousPredictionComparator(Database result, Database baseline) {
        this.result = Objects.requireNonNull(result, "result");
        this.baseline = baseline;
    }

    /**
     * Replaces the baseline database used by subsequent calls to {@link #compare(StandardPredicate)}.
     *
     * @param db the new baseline database
     */
    @Override
    public void setBaseline(Database db) {
        this.baseline = db;
    }

    /**
     * Not supported by this comparator.
     *
     * @param filter ignored
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setResultFilter(AtomFilter filter) {
        throw new UnsupportedOperationException(
                "ContinuousPredictionComparator does not support result filters.");
    }

    /**
     * Compares every baseline atom of {@code predicate} with the result atom that has the same
     * predicate and arguments.
     *
     * <p>A pair contributes to the statistics only when the baseline atom is an
     * {@link ObservedAtom} and the result atom is not an {@link ObservedAtom}.
     *
     * @param predicate the predicate whose atoms are compared
     * @return statistics built from the number of compared pairs and the summed absolute and
     *     squared differences of their truth values
     * @throws IllegalStateException if no baseline database has been set
     */
    public ContinuousPredictionStatistics compare(StandardPredicate predicate) {
        if (baseline == null) {
            throw new IllegalStateException(
                    "No baseline database set; call setBaseline() before compare().");
        }

        int count = 0;
        double absoluteError = 0.0;
        double squaredError = 0.0;

        for (GroundAtom truthAtom : Queries.getAllAtoms(baseline, predicate)) {
            // Only observed baseline atoms are used as reference values.
            if (!(truthAtom instanceof ObservedAtom)) {
                continue;
            }

            // NOTE(review): assumes Database.getAtom never returns null for these arguments — confirm.
            GroundAtom resultAtom = result.getAtom(truthAtom.getPredicate(), truthAtom.getArguments());

            // Observed atoms in the result are excluded (presumably fixed evidence, not predictions).
            if (resultAtom instanceof ObservedAtom) {
                continue;
            }

            double difference = truthAtom.getValue() - resultAtom.getValue();
            count++;
            absoluteError += Math.abs(difference);
            squaredError += difference * difference;
        }

        return new ContinuousPredictionStatistics(count, absoluteError, squaredError);
    }
}

