/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.bool;

import java.util.Random;
import org.linqs.psl.application.util.GroundRules;
import org.linqs.psl.config.ConfigBundle;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.inspector.ReasonerInspector;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTerm;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTermStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Estimates marginal truth values for Boolean (0/1) atoms by repeatedly resampling
 * constraint blocks.
 *
 * <p>Each sweep visits every non-empty block of a {@link ConstraintBlockerTermStore}. For each
 * block it scores every candidate assignment and draws one with probability proportional to
 * {@code exp(-energy)}. The energy is the sum of {@code weight * incompatibility} over the
 * block's incident ground rules.
 * <ul>
 *   <li>Candidate {@code k} sets atom {@code k} to 1 and every other atom in the block to 0.</li>
 *   <li>Blocks that are not "exactly one" also get a candidate that sets every atom to 0.</li>
 * </ul>
 *
 * <p>The first {@value #NUM_BURN_IN_KEY} sweeps are discarded. When inference finishes, each
 * atom is set to its mean value over the remaining sweeps.
 */
public class BooleanMCSat extends Reasoner {
    private static final Logger log = LoggerFactory.getLogger(BooleanMCSat.class);

    public static final String CONFIG_PREFIX = "booleanmcsat";

    /** Config key: total number of sweeps, including burn-in. Must be positive. */
    public static final String NUM_SAMPLES_KEY = "booleanmcsat.numsamples";
    public static final int NUM_SAMPLES_DEFAULT = 2500;

    /** Config key: number of leading sweeps to discard. Must be in {@code [0, numSamples)}. */
    public static final String NUM_BURN_IN_KEY = "booleanmcsat.numburnin";
    public static final int NUM_BURN_IN_DEFAULT = 500;

    private final Random rand = new Random();
    private final int numSamples;
    private final int numBurnIn;

    /**
     * Reads and validates the sampling parameters.
     *
     * @throws IllegalArgumentException if the sample count is not positive, the burn-in is
     *     negative, or the burn-in is not strictly less than the sample count
     */
    public BooleanMCSat(ConfigBundle config) {
        super(config);

        numSamples = config.getInt(NUM_SAMPLES_KEY, NUM_SAMPLES_DEFAULT);
        if (numSamples <= 0) {
            throw new IllegalArgumentException("Number of samples must be positive.");
        }

        numBurnIn = config.getInt(NUM_BURN_IN_KEY, NUM_BURN_IN_DEFAULT);
        // A burn-in of zero is valid (average every sweep); only negative values are rejected.
        if (numBurnIn < 0) {
            throw new IllegalArgumentException("Number of burn in samples must be non-negative.");
        }
        if (numBurnIn >= numSamples) {
            throw new IllegalArgumentException("Number of burn in samples must be less than number of samples.");
        }
    }

    /**
     * Runs the sampler and writes the estimated marginals back into the atoms.
     *
     * <p>When an inspector is set, it is consulted after every post-burn-in sweep. It may stop
     * sampling early; the marginals are then averaged over the sweeps actually collected.
     *
     * @throws IllegalArgumentException if {@code termStore} is not a
     *     {@link ConstraintBlockerTermStore}
     */
    @Override
    public void optimize(TermStore termStore) {
        if (!(termStore instanceof ConstraintBlockerTermStore)) {
            throw new IllegalArgumentException("ConstraintBlockerTermStore required.");
        }
        ConstraintBlockerTermStore blocker = (ConstraintBlockerTermStore)termStore;
        blocker.randomlyInitialize();

        int numBlocks = blocker.size();
        // Running sums of post-burn-in samples, one entry per atom.
        double[][] totals = new double[numBlocks][];
        // The chain's current 0/1 assignment. It is kept so the chain can be restored after
        // the inspector temporarily overwrites the atoms with averaged values.
        double[][] state = new double[numBlocks][];
        for (int blockIndex = 0; blockIndex < numBlocks; blockIndex++) {
            int blockSize = blocker.get(blockIndex).size();
            totals[blockIndex] = new double[blockSize];
            state[blockIndex] = new double[blockSize];
        }

        log.info("Beginning inference.");
        int collected = 0;
        for (int sampleIndex = 0; sampleIndex < numSamples; sampleIndex++) {
            for (int blockIndex = 0; blockIndex < numBlocks; blockIndex++) {
                ConstraintBlockerTerm block = blocker.get(blockIndex);
                if (block.size() == 0) {
                    continue;
                }
                resampleBlock(block, state[blockIndex]);
            }

            if (sampleIndex < numBurnIn) {
                continue;
            }

            for (int blockIndex = 0; blockIndex < numBlocks; blockIndex++) {
                double[] blockTotals = totals[blockIndex];
                double[] blockState = state[blockIndex];
                for (int atomIndex = 0; atomIndex < blockTotals.length; atomIndex++) {
                    blockTotals[atomIndex] += blockState[atomIndex];
                }
            }
            collected++;

            if (inspector != null && !reportToInspector(blocker, totals, state, collected, sampleIndex)) {
                log.info("Stopping MCSat iterations on advice from inspector");
                break;
            }
        }
        log.info("Inference complete.");

        // collected >= 1 here: the constructor guarantees numBurnIn < numSamples, and the
        // inspector is only consulted after at least one sample has been collected.
        writeAverages(blocker, totals, collected);
    }

    /**
     * Resamples one block in place and records the chosen 0/1 assignment in
     * {@code blockState}.
     */
    private void resampleBlock(ConstraintBlockerTerm block, double[] blockState) {
        int blockSize = block.size();
        // Non-exactly-one blocks get an extra candidate (index == blockSize) meaning "all atoms 0".
        int numCandidates = block.getExactlyOne() ? blockSize : blockSize + 1;

        double[] energies = new double[numCandidates];
        for (int candidate = 0; candidate < numCandidates; candidate++) {
            assignCandidate(block, candidate);
            energies[candidate] = computeEnergy(block.getIncidentGRs());
        }

        int chosen = sampleCandidate(energies);
        assignCandidate(block, chosen);
        for (int atomIndex = 0; atomIndex < blockSize; atomIndex++) {
            blockState[atomIndex] = (atomIndex == chosen) ? 1.0 : 0.0;
        }
    }

    /**
     * Sets atom {@code candidate} of the block to 1 and all other atoms to 0. A candidate
     * index outside the block sets every atom to 0.
     */
    private static void assignCandidate(ConstraintBlockerTerm block, int candidate) {
        for (int atomIndex = 0; atomIndex < block.size(); atomIndex++) {
            block.getAtoms()[atomIndex].setValue((atomIndex == candidate) ? 1.0 : 0.0);
        }
    }

    /** Returns the sum of {@code weight * incompatibility} over the given ground rules. */
    private static double computeEnergy(WeightedGroundRule[] incidentGRs) {
        double energy = 0.0;
        for (WeightedGroundRule groundRule : incidentGRs) {
            energy += groundRule.getWeight() * groundRule.getIncompatibility();
        }
        return energy;
    }

    /**
     * Draws an index {@code i} with probability proportional to {@code exp(-energies[i])}.
     *
     * <p>Energies are shifted by their minimum before exponentiating. Without this, large
     * energies can underflow every weight to 0 and turn the normalisation into 0/0. After the
     * shift, the lowest-energy candidate always has weight 1.
     */
    private int sampleCandidate(double[] energies) {
        double minEnergy = Double.POSITIVE_INFINITY;
        for (double energy : energies) {
            minEnergy = Math.min(minEnergy, energy);
        }

        double[] weights = new double[energies.length];
        double total = 0.0;
        for (int i = 0; i < energies.length; i++) {
            weights[i] = Math.exp(minEnergy - energies[i]);
            total += weights[i];
        }

        // NaN or infinite energies leave no usable distribution; fall back to a uniform draw.
        if (!(total > 0.0)) {
            log.warn("Degenerate candidate energies; sampling uniformly.");
            return rand.nextInt(energies.length);
        }

        double cutoff = rand.nextDouble() * total;
        double cumulative = 0.0;
        for (int i = 0; i < weights.length; i++) {
            cumulative += weights[i];
            if (cumulative > cutoff) {
                return i;
            }
        }
        // Only reachable if floating-point rounding leaves cumulative just below cutoff.
        return weights.length - 1;
    }

    /**
     * Reports the current estimates to the inspector, then restores the sampled state.
     *
     * <p>The marginal estimates are written into the atoms only while the inspector runs.
     * Afterwards the chain's 0/1 assignment is restored, so subsequent sweeps condition on
     * real samples rather than on averages.
     *
     * @return the inspector's verdict: {@code true} to keep sampling, {@code false} to stop
     */
    private boolean reportToInspector(ConstraintBlockerTermStore blocker, double[][] totals,
            double[][] state, int collected, int sampleIndex) {
        writeAverages(blocker, totals, collected);

        double incompatibility = GroundRules.getTotalWeightedIncompatibility(
                blocker.getGroundRuleStore().getCompatibilityRules());
        double infeasibility = GroundRules.getInfeasibilityNorm(
                blocker.getGroundRuleStore().getConstraintRules());
        boolean keepGoing = inspector.update(this, new MCSatStatus(sampleIndex, incompatibility, infeasibility));

        for (int blockIndex = 0; blockIndex < blocker.size(); blockIndex++) {
            ConstraintBlockerTerm block = blocker.get(blockIndex);
            double[] blockState = state[blockIndex];
            for (int atomIndex = 0; atomIndex < blockState.length; atomIndex++) {
                block.getAtoms()[atomIndex].setValue(blockState[atomIndex]);
            }
        }
        return keepGoing;
    }

    /** Sets every atom to its mean over the {@code collected} post-burn-in samples. */
    private static void writeAverages(ConstraintBlockerTermStore blocker, double[][] totals, int collected) {
        for (int blockIndex = 0; blockIndex < blocker.size(); blockIndex++) {
            ConstraintBlockerTerm block = blocker.get(blockIndex);
            double[] blockTotals = totals[blockIndex];
            for (int atomIndex = 0; atomIndex < blockTotals.length; atomIndex++) {
                block.getAtoms()[atomIndex].setValue(blockTotals[atomIndex] / collected);
            }
        }
    }

    /** This reasoner holds no resources, so there is nothing to release. */
    @Override
    public void close() {
    }

    /** Progress snapshot passed to the inspector after each post-burn-in sweep. */
    private static class MCSatStatus
    extends ReasonerInspector.IterativeReasonerStatus {
        public double incompatibility;
        public double infeasbility;

        public MCSatStatus(int iteration, double incompatibility, double infeasbility) {
            super(iteration);
            this.incompatibility = incompatibility;
            this.infeasbility = infeasbility;
        }

        @Override
        public String toString() {
            return String.format("%s, incompatibility: %f, infeasbility: %f", super.toString(), this.incompatibility, this.infeasbility);
        }
    }
}

