/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.admm;

import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import org.linqs.psl.config.ConfigBundle;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.ThreadPool;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.linqs.psl.reasoner.admm.term.LocalVariable;
import org.linqs.psl.reasoner.inspector.ReasonerInspector;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.Parallel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reasoner that minimizes the objective held in an {@link ADMMTermStore} with a consensus form
 * of the alternating direction method of multipliers (ADMM).
 *
 * <p>Each iteration runs two parallel phases over fixed-size blocks handed out by shared counters:
 * first every term updates its Lagrange multipliers and minimizes its local subproblem, then every
 * global (consensus) variable is recomputed from its local copies and clipped to
 * [{@code LOWER_BOUND}, {@code UPPER_BOUND}]. Iteration stops when both the primal and the dual
 * residual are at or below their tolerances, after {@code maxIter} iterations, or when an attached
 * inspector asks to stop.
 */
public class ADMMReasoner extends Reasoner {
    private static final Logger log = LoggerFactory.getLogger(ADMMReasoner.class);

    /** Prefix shared by every configuration key of this reasoner. */
    public static final String CONFIG_PREFIX = "admmreasoner";

    /** Key for the maximum number of ADMM iterations. */
    public static final String MAX_ITER_KEY = CONFIG_PREFIX + ".maxiterations";
    public static final int MAX_ITER_DEFAULT = 25000;

    /** Key for the ADMM step size (penalty parameter); must be positive. */
    public static final String STEP_SIZE_KEY = CONFIG_PREFIX + ".stepsize";
    public static final float STEP_SIZE_DEFAULT = 1.0f;

    /** Key for the absolute stopping tolerance; must be positive. */
    public static final String EPSILON_ABS_KEY = CONFIG_PREFIX + ".epsilonabs";
    public static final float EPSILON_ABS_DEFAULT = 1.0E-5f;

    /** Key for the relative stopping tolerance; must be positive. */
    public static final String EPSILON_REL_KEY = CONFIG_PREFIX + ".epsilonrel";
    public static final float EPSILON_REL_DEFAULT = 0.001f;

    /** Key for the number of worker threads; must be positive. */
    public static final String NUM_THREADS_KEY = CONFIG_PREFIX + ".numthreads";
    public static final int NUM_THREADS_DEFAULT = Parallel.NUM_THREADS;

    // Consensus values are clipped into [LOWER_BOUND, UPPER_BOUND] after every update.
    private static final float LOWER_BOUND = 0.0f;
    private static final float UPPER_BOUND = 1.0f;

    // Smallest block of terms/variables a worker claims at once.
    private static final int MIN_BLOCK_SIZE = 20;

    // Target number of blocks per thread when sizing blocks.
    private static final int TERM_ITERATIONS = 1000;

    // Residuals are trace-logged every LOG_PERIOD iterations.
    private static final int LOG_PERIOD = 50;

    private final float stepSize;
    private final int numThreads;
    private float epsilonRel;
    private float epsilonAbs;
    private float lagrangePenalty;
    private float augmentedLagrangePenalty;
    private int maxIter;

    // Consensus values of the global variables; allocated by optimize(), indexed by global id.
    private float[] consensusValues;

    /**
     * Reads the reasoner settings from {@code config}, falling back to the {@code *_DEFAULT} values.
     *
     * @param config configuration source
     * @throws IllegalArgumentException if the step size, either tolerance, or the thread count is
     *     not positive
     */
    public ADMMReasoner(ConfigBundle config) {
        super(config);

        maxIter = config.getInt(MAX_ITER_KEY, MAX_ITER_DEFAULT);

        stepSize = config.getFloat(STEP_SIZE_KEY, STEP_SIZE_DEFAULT);
        // The step size is used as a divisor when forming consensus values.
        if (stepSize <= 0.0f) {
            throw new IllegalArgumentException("Property " + STEP_SIZE_KEY + " must be positive.");
        }

        epsilonAbs = config.getFloat(EPSILON_ABS_KEY, EPSILON_ABS_DEFAULT);
        if (epsilonAbs <= 0.0f) {
            throw new IllegalArgumentException("Property " + EPSILON_ABS_KEY + " must be positive.");
        }

        epsilonRel = config.getFloat(EPSILON_REL_KEY, EPSILON_REL_DEFAULT);
        if (epsilonRel <= 0.0f) {
            throw new IllegalArgumentException("Property " + EPSILON_REL_KEY + " must be positive.");
        }

        numThreads = config.getInt(NUM_THREADS_KEY, NUM_THREADS_DEFAULT);
        if (numThreads <= 0) {
            throw new IllegalArgumentException("Property " + NUM_THREADS_KEY + " must be positive.");
        }
    }

    /** Returns the maximum number of iterations {@link #optimize} will run. */
    public int getMaxIter() {
        return maxIter;
    }

    /** Sets the maximum number of iterations {@link #optimize} will run. */
    public void setMaxIter(int maxIter) {
        this.maxIter = maxIter;
    }

    /** Returns the relative stopping tolerance. */
    public float getEpsilonRel() {
        return epsilonRel;
    }

    /** Sets the relative stopping tolerance (not validated here, unlike the constructor). */
    public void setEpsilonRel(float epsilonRel) {
        this.epsilonRel = epsilonRel;
    }

    /** Returns the absolute stopping tolerance. */
    public float getEpsilonAbs() {
        return epsilonAbs;
    }

    /** Sets the absolute stopping tolerance (not validated here, unlike the constructor). */
    public void setEpsilonAbs(float epsilonAbs) {
        this.epsilonAbs = epsilonAbs;
    }

    /**
     * Returns the Lagrangian penalty (sum of multiplier times local-minus-consensus gap) computed
     * in the most recent iteration of {@link #optimize}.
     */
    public float getLagrangianPenalty() {
        return lagrangePenalty;
    }

    /**
     * Returns the augmented Lagrangian penalty (sum of {@code 0.5 * stepSize * gap^2}) computed in
     * the most recent iteration of {@link #optimize}.
     */
    public float getAugmentedLagrangianPenalty() {
        return augmentedLagrangePenalty;
    }

    /**
     * Evaluates a weighted ground rule using the values of its local variable copies.
     *
     * <p>Side effects: for every local variable in the rule's terms, the matching consensus value is
     * overwritten with that local value, and the whole consensus vector is then pushed to the term
     * store via {@code updateVariables}.
     *
     * @param groundRule the rule to evaluate; must be a {@link WeightedGroundRule}
     * @param termStore presumably the same store last passed to {@link #optimize} (consensus values
     *     are indexed by its global variable ids) -- callers should confirm
     * @return the rule's incompatibility after the update
     * @throws IllegalStateException if {@link #optimize} has not run yet
     * @throws ClassCastException if {@code groundRule} is not a {@link WeightedGroundRule}
     */
    public double getDualIncompatibility(GroundRule groundRule, ADMMTermStore termStore) {
        if (consensusValues == null) {
            throw new IllegalStateException("optimize() must be called before getDualIncompatibility().");
        }

        WeightedGroundRule weightedRule = (WeightedGroundRule) groundRule;
        for (Integer termIndex : termStore.getTermIndices(weightedRule)) {
            for (LocalVariable localVariable : termStore.get(termIndex).getVariables()) {
                consensusValues[localVariable.getGlobalId()] = localVariable.getValue();
            }
        }
        termStore.updateVariables(consensusValues);
        return weightedRule.getIncompatibility();
    }

    /**
     * Runs parallel ADMM on the given term store and writes the final consensus values back to it.
     *
     * @param baseTermStore the store to optimize; must be an {@link ADMMTermStore}
     * @throws IllegalArgumentException if {@code baseTermStore} is not an {@link ADMMTermStore}
     * @throws RuntimeException if a synchronization barrier is interrupted or broken
     */
    @Override
    public void optimize(TermStore baseTermStore) {
        if (!(baseTermStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException("ADMMReasoner requires an ADMMTermStore");
        }
        ADMMTermStore termStore = (ADMMTermStore) baseTermStore;

        int numTerms = termStore.size();
        int numGlobalVariables = termStore.getNumGlobalVariables();
        log.debug("Performing optimization with {} variables and {} terms.", numGlobalVariables, numTerms);

        consensusValues = new float[numGlobalVariables];

        // Size blocks so each thread gets roughly TERM_ITERATIONS blocks of terms, but never fewer
        // than MIN_BLOCK_SIZE items per block.
        int blockSize = Math.max(MIN_BLOCK_SIZE, (int) ((double) numTerms / (double) numThreads / TERM_ITERATIONS));
        log.trace("Using a block size of {}.", blockSize);

        // Exact integer ceiling: a float-based ceil can under-count blocks once sizes exceed float
        // precision, silently skipping the trailing terms or variables.
        SyncCounter termCounter = new SyncCounter(ceilDiv(numTerms, blockSize));
        SyncCounter variableCounter = new SyncCounter(ceilDiv(numGlobalVariables, blockSize));

        ADMMTask[] tasks = new ADMMTask[numThreads];
        // Workers only: separates the term phase from the consensus phase inside one iteration.
        CyclicBarrier termUpdateCompleteBarrier = new CyclicBarrier(numThreads);
        // Workers plus this thread: mark the start and end of each iteration.
        CyclicBarrier workerStartBarrier = new CyclicBarrier(numThreads + 1);
        CyclicBarrier workerEndBarrier = new CyclicBarrier(numThreads + 1);

        ThreadPool threadPool = new ThreadPool();
        for (int i = 0; i < numThreads; i++) {
            tasks[i] = new ADMMTask(stepSize, termUpdateCompleteBarrier, workerStartBarrier, workerEndBarrier,
                    termCounter, variableCounter, termStore, consensusValues, blockSize);
            threadPool.submit(tasks[i]);
        }

        float primalRes = Float.POSITIVE_INFINITY;
        float dualRes = Float.POSITIVE_INFINITY;
        float epsilonPrimal = 0.0f;
        float epsilonDual = 0.0f;
        float epsilonAbsTerm = (float) (Math.sqrt(termStore.getNumLocalVariables()) * epsilonAbs);

        // NOTE(review): if anything in this loop throws (e.g. the inspector), the shutdown sequence
        // below is skipped and workers stay parked on workerStartBarrier; consider a finally block.

        // Number of iterations that have fully completed (also the 1-based id of the current one).
        int iteration = 0;
        while ((primalRes > epsilonPrimal || dualRes > epsilonDual) && iteration < maxIter) {
            iteration++;

            termCounter.reset();
            variableCounter.reset();
            awaitBarrier(workerStartBarrier);
            awaitBarrier(workerEndBarrier);

            // Workers have finished the iteration; fold their partial sums together.
            primalRes = 0.0f;
            dualRes = 0.0f;
            float axNorm = 0.0f;
            float bzNorm = 0.0f;
            float ayNorm = 0.0f;
            lagrangePenalty = 0.0f;
            augmentedLagrangePenalty = 0.0f;
            for (ADMMTask task : tasks) {
                primalRes += task.primalResInc;
                dualRes += task.dualResInc;
                axNorm += task.axNormInc;
                bzNorm += task.bzNormInc;
                ayNorm += task.ayNormInc;
                lagrangePenalty += task.lagrangePenalty;
                augmentedLagrangePenalty += task.augmentedLagrangePenalty;
            }

            primalRes = (float) Math.sqrt(primalRes);
            dualRes = (float) (stepSize * Math.sqrt(dualRes));
            epsilonPrimal = (float) (epsilonAbsTerm + epsilonRel * Math.max(Math.sqrt(axNorm), Math.sqrt(bzNorm)));
            epsilonDual = (float) (epsilonAbsTerm + epsilonRel * Math.sqrt(ayNorm));

            if (iteration % LOG_PERIOD == 0) {
                log.trace("Residuals at iteration {} -- Primal: {} -- Dual: {}", iteration, primalRes, dualRes);
                log.trace("--------- Epsilon primal: {} -- Epsilon dual: {}", epsilonPrimal, epsilonDual);
            }

            if (inspector != null) {
                log.debug("Updating random variable atoms with consensus values for inspector");
                termStore.updateVariables(consensusValues);
                if (!inspector.update(this, new ADMMStatus(iteration, primalRes, dualRes))) {
                    log.info("Stopping ADMM iterations on advice from inspector");
                    break;
                }
            }
        }

        // Release the workers one final time with done set so they leave their run loops.
        for (ADMMTask task : tasks) {
            task.done = true;
        }
        awaitBarrier(workerStartBarrier);
        threadPool.shutdownAndWait();

        log.info("Optimization completed in {} iterations. Primal res.: {}, Dual res.: {}", iteration, primalRes, dualRes);
        termStore.updateVariables(consensusValues);
    }

    /** Holds no resources between calls; each {@link #optimize} shuts down its own worker pool. */
    @Override
    public void close() {
    }

    /**
     * Waits on {@code barrier}, wrapping checked exceptions in {@link RuntimeException}.
     * The interrupt flag is restored before rethrowing so callers can still observe the interrupt.
     */
    private static void awaitBarrier(CyclicBarrier barrier) {
        try {
            barrier.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (BrokenBarrierException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns {@code ceil(numerator / denominator)} for a non-negative numerator and positive denominator. */
    private static int ceilDiv(int numerator, int denominator) {
        // Widen to long so numerator + denominator - 1 cannot overflow.
        return (int) (((long) numerator + denominator - 1) / denominator);
    }

    /** Per-iteration status reported to the inspector: iteration number plus both residuals. */
    private static class ADMMStatus extends ReasonerInspector.IterativeReasonerStatus {
        public double primalResidual;
        public double dualResidual;

        public ADMMStatus(int iteration, double primalResidual, double dualResidual) {
            super(iteration);
            this.primalResidual = primalResidual;
            this.dualResidual = dualResidual;
        }

        @Override
        public String toString() {
            return String.format("%s, primal: %f, dual: %f", super.toString(), primalResidual, dualResidual);
        }
    }

    /** Thread-safe counter handing out block indices {@code 0 .. max-1}, then {@code -1}. */
    private static class SyncCounter {
        private final int max;
        private int count;

        public SyncCounter(int max) {
            this.max = max;
            this.count = 0;
        }

        /** Returns the next unclaimed block index, or -1 once all {@code max} have been claimed. */
        public synchronized int next() {
            if (count >= max) {
                return -1;
            }
            return count++;
        }

        /** Makes all block indices available again for the next iteration. */
        public synchronized void reset() {
            count = 0;
        }
    }

    /**
     * Worker that, once per iteration, updates blocks of terms and then blocks of consensus
     * variables, and publishes its partial residual sums for the coordinating thread.
     */
    private static class ADMMTask implements Runnable {
        // Set by the coordinating thread before the final start barrier to make the worker exit.
        private volatile boolean done = false;

        private final float stepSize;
        private final int blockSize;
        private final SyncCounter termCounter;
        private final SyncCounter variableCounter;
        private final float[] consensusValues;
        private final ADMMTermStore termStore;
        private final CyclicBarrier termUpdateCompleteBarrier;
        private final CyclicBarrier workerStartBarrier;
        private final CyclicBarrier workerEndBarrier;

        // Partial sums for the current iteration; read by the coordinating thread after it passes
        // workerEndBarrier, which orders those reads after these writes.
        private float primalResInc = 0.0f;
        private float dualResInc = 0.0f;
        private float axNormInc = 0.0f;
        private float bzNormInc = 0.0f;
        private float ayNormInc = 0.0f;
        private float lagrangePenalty = 0.0f;
        private float augmentedLagrangePenalty = 0.0f;

        ADMMTask(float stepSize, CyclicBarrier termUpdateCompleteBarrier, CyclicBarrier workerStartBarrier,
                CyclicBarrier workerEndBarrier, SyncCounter termCounter, SyncCounter variableCounter,
                ADMMTermStore termStore, float[] consensusValues, int blockSize) {
            this.stepSize = stepSize;
            this.termUpdateCompleteBarrier = termUpdateCompleteBarrier;
            this.workerStartBarrier = workerStartBarrier;
            this.workerEndBarrier = workerEndBarrier;
            this.termCounter = termCounter;
            this.variableCounter = variableCounter;
            this.termStore = termStore;
            this.consensusValues = consensusValues;
            this.blockSize = blockSize;
        }

        @Override
        public void run() {
            int numTerms = termStore.size();
            int numVariables = termStore.getNumGlobalVariables();

            while (true) {
                awaitBarrier(workerStartBarrier);
                if (done) {
                    break;
                }

                updateTerms(numTerms);
                // Every term must be updated before any consensus value is recomputed.
                awaitBarrier(termUpdateCompleteBarrier);
                updateConsensus(numVariables);

                awaitBarrier(workerEndBarrier);
            }
        }

        /** Phase 1: claims term blocks and, per term, updates its multipliers then minimizes it. */
        private void updateTerms(int numTerms) {
            for (int blockIndex = termCounter.next(); blockIndex != -1; blockIndex = termCounter.next()) {
                int start = blockIndex * blockSize;
                int end = start + Math.min(blockSize, numTerms - start);
                for (int termIndex = start; termIndex < end; termIndex++) {
                    termStore.get(termIndex).updateLagrange(stepSize, consensusValues);
                    termStore.get(termIndex).minimize(stepSize, consensusValues);
                }
            }
        }

        /** Phase 2: resets the partial sums, then claims variable blocks and updates each variable. */
        private void updateConsensus(int numVariables) {
            primalResInc = 0.0f;
            dualResInc = 0.0f;
            axNormInc = 0.0f;
            bzNormInc = 0.0f;
            ayNormInc = 0.0f;
            lagrangePenalty = 0.0f;
            augmentedLagrangePenalty = 0.0f;

            for (int blockIndex = variableCounter.next(); blockIndex != -1; blockIndex = variableCounter.next()) {
                int start = blockIndex * blockSize;
                int end = start + Math.min(blockSize, numVariables - start);
                for (int variableIndex = start; variableIndex < end; variableIndex++) {
                    updateVariable(variableIndex);
                }
            }
        }

        /** Recomputes one clipped consensus value and adds its residual and penalty contributions. */
        private void updateVariable(int variableIndex) {
            int numLocalVariables = termStore.getLocalVariables(variableIndex).size();
            if (numLocalVariables == 0) {
                // 0 / 0 would make this consensus value NaN, which poisons every residual and makes
                // the convergence test pass vacuously. Leave the value untouched instead.
                return;
            }

            float total = 0.0f;
            for (int i = 0; i < numLocalVariables; i++) {
                LocalVariable localVariable = termStore.getLocalVariables(variableIndex).get(i);
                total += localVariable.getValue() + localVariable.getLagrange() / stepSize;
                axNormInc += localVariable.getValue() * localVariable.getValue();
                ayNormInc += localVariable.getLagrange() * localVariable.getLagrange();
            }

            float newConsensusValue = total / numLocalVariables;
            newConsensusValue = Math.max(Math.min(newConsensusValue, UPPER_BOUND), LOWER_BOUND);

            float diff = consensusValues[variableIndex] - newConsensusValue;
            dualResInc += diff * diff * numLocalVariables;
            bzNormInc += newConsensusValue * newConsensusValue * numLocalVariables;
            consensusValues[variableIndex] = newConsensusValue;

            for (int i = 0; i < numLocalVariables; i++) {
                LocalVariable localVariable = termStore.getLocalVariables(variableIndex).get(i);
                float gap = localVariable.getValue() - newConsensusValue;
                primalResInc += gap * gap;
                lagrangePenalty += localVariable.getLagrange() * gap;
                augmentedLagrangePenalty += 0.5 * stepSize * Math.pow(gap, 2.0);
            }
        }
    }
}

