/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.evaluation.statistics;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.Queries;
import org.linqs.psl.evaluation.statistics.CategoricalPredictionStatistics;
import org.linqs.psl.evaluation.statistics.PredictionComparator;
import org.linqs.psl.evaluation.statistics.filter.AtomFilter;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.term.Constant;

/**
 * A {@link PredictionComparator} for categorical (multi-class style) predictions.
 *
 * <p>Some argument positions of a predicate are designated as "category" positions.
 * Atoms of the result database that agree on every non-category argument are
 * grouped together, and within each group the atom with the highest value is kept
 * as the predicted category. A baseline atom counts as a hit when it is one of
 * these kept atoms, and as a miss otherwise.
 */
public class CategoricalPredictionComparator implements PredictionComparator {
    private final Database result;
    private Database baseline;
    private final Set<Integer> catIndexSet;

    /**
     * Creates a comparator with no baseline and no category indexes.
     * Both must be supplied (see {@link #setBaseline(Database)} and
     * {@link #setCategoryIndexes(int[])}) before {@link #compare(StandardPredicate)}
     * can succeed.
     *
     * @param result the database holding the predictions
     */
    public CategoricalPredictionComparator(Database result) {
        this(result, null, null);
    }

    /**
     * @param result the database holding the predictions
     * @param baseline the database holding the truth atoms (may be null and set later)
     * @param categoryIndexes the argument positions that are categories
     *        (may be null, which is treated as "none")
     */
    public CategoricalPredictionComparator(Database result, Database baseline, int[] categoryIndexes) {
        this.result = result;
        this.baseline = baseline;
        this.catIndexSet = new HashSet<Integer>();
        this.replaceCategoryIndexes(categoryIndexes);
    }

    /**
     * Replaces the current category indexes.
     *
     * @param categoryIndexes the new category argument positions;
     *        null clears the current indexes
     */
    public void setCategoryIndexes(int[] categoryIndexes) {
        this.replaceCategoryIndexes(categoryIndexes);
    }

    /**
     * Clears the stored category indexes and stores the given ones (duplicates collapse).
     * Kept private so the constructor does not call an overridable method.
     */
    private void replaceCategoryIndexes(int[] categoryIndexes) {
        this.catIndexSet.clear();
        if (categoryIndexes == null) {
            return;
        }

        for (int catIndex : categoryIndexes) {
            this.catIndexSet.add(catIndex);
        }
    }

    @Override
    public void setBaseline(Database db) {
        this.baseline = db;
    }

    /**
     * Result filters are not supported by this comparator.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setResultFilter(AtomFilter af) {
        throw new UnsupportedOperationException();
    }

    /**
     * Counts how many baseline atoms of the given predicate were selected as the
     * best category in the result database.
     *
     * <p>Only baseline atoms that are {@link ObservedAtom}s with a value of at
     * least 1.0 are considered.
     *
     * @param predicate the predicate to compare on
     * @return the hit and miss counts
     * @throws IllegalStateException if no category indexes are set, if there are
     *         at least as many category indexes as the predicate has arguments,
     *         or if no baseline database has been set
     */
    @Override
    public CategoricalPredictionStatistics compare(StandardPredicate predicate) {
        if (this.catIndexSet.isEmpty()) {
            throw new IllegalStateException("No category indexes have been provided.");
        }

        if (this.catIndexSet.size() >= predicate.getArity()) {
            throw new IllegalStateException(String.format(
                    "Too many category indexes for %s. Found: %d, Max: %d.",
                    predicate.getName(), this.catIndexSet.size(), predicate.getArity() - 1));
        }

        // The single-argument constructor leaves the baseline unset; fail with a
        // clear message instead of failing somewhere inside the query.
        if (this.baseline == null) {
            throw new IllegalStateException("No baseline database has been provided.");
        }

        int hits = 0;
        int misses = 0;
        Set<GroundAtom> bestCats = this.getBestCategories(predicate);

        for (GroundAtom truthAtom : Queries.getAllAtoms(this.baseline, predicate)) {
            // Only fully-true observed atoms are treated as truth.
            if (!(truthAtom instanceof ObservedAtom) || truthAtom.getValue() < 1.0) {
                continue;
            }

            if (bestCats.contains(truthAtom)) {
                ++hits;
            } else {
                ++misses;
            }
        }

        return new CategoricalPredictionStatistics(hits, misses);
    }

    /**
     * Builds the set of best-valued result atoms, one per distinct combination of
     * non-category arguments.
     *
     * @return the best atoms; empty when the result database has no atoms for the predicate
     */
    private Set<GroundAtom> getBestCategories(StandardPredicate predicate) {
        // Root of a nested map keyed by successive non-category arguments.
        // Stays null until the first atom is inserted.
        Map<Constant, Object> bestCats = null;

        for (GroundAtom atom : Queries.getAllAtoms(this.result, predicate)) {
            // compare() guarantees at least one non-category argument,
            // so the root node is always a map and never a bare atom.
            @SuppressWarnings("unchecked")
            Map<Constant, Object> updated = (Map<Constant, Object>)this.putBestCats(bestCats, atom, 0);
            bestCats = updated;
        }

        Set<GroundAtom> rtn = new HashSet<GroundAtom>();
        // No atoms at all means no tree was built; there is nothing to collect.
        if (bestCats != null) {
            this.collectBestCats(bestCats, rtn);
        }

        return rtn;
    }

    /**
     * Inserts an atom into the nested map rooted at currentNode, descending one
     * non-category argument per level. Leaves are the best atom seen so far for
     * that argument combination; on ties the earlier atom is kept.
     *
     * @param currentNode the existing node at this depth (a map, an atom at the
     *        leaf level, or null if nothing has been stored here yet)
     * @param atom the atom being inserted
     * @param argIndex the argument position being examined
     * @return the node that should be stored at this depth
     */
    private Object putBestCats(Object currentNode, GroundAtom atom, int argIndex) {
        assert (argIndex <= atom.getArity());

        // Check for the end of the arguments before the category skip so that a
        // category index equal to the arity cannot push the recursion past the end.
        if (argIndex == atom.getArity()) {
            if (currentNode == null) {
                return atom;
            }

            GroundAtom oldBest = (GroundAtom)currentNode;
            return (atom.getValue() > oldBest.getValue()) ? atom : oldBest;
        }

        // Category arguments do not contribute a level to the tree.
        if (this.catIndexSet.contains(argIndex)) {
            return this.putBestCats(currentNode, atom, argIndex + 1);
        }

        Map<Constant, Object> children;
        if (currentNode == null) {
            children = new HashMap<Constant, Object>();
        } else {
            // Non-leaf nodes are only ever created by this method as Map<Constant, Object>.
            @SuppressWarnings("unchecked")
            Map<Constant, Object> existing = (Map<Constant, Object>)currentNode;
            children = existing;
        }

        Constant arg = atom.getArguments()[argIndex];
        children.put(arg, this.putBestCats(children.get(arg), atom, argIndex + 1));
        return children;
    }

    /**
     * Recursively gathers every leaf atom of the nested map into the given set.
     *
     * @param bestCats the (non-null) node to walk
     * @param result the set that receives the leaf atoms
     */
    private void collectBestCats(Map<Constant, Object> bestCats, Set<GroundAtom> result) {
        for (Object value : bestCats.values()) {
            if (value instanceof GroundAtom) {
                result.add((GroundAtom)value);
                continue;
            }

            // Anything that is not an atom is a nested map built by putBestCats.
            @SuppressWarnings("unchecked")
            Map<Constant, Object> child = (Map<Constant, Object>)value;
            this.collectBestCats(child, result);
        }
    }
}

