/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.admm.term;

import cern.colt.matrix.tfloat.FloatMatrix2D;
import cern.colt.matrix.tfloat.algo.decomposition.DenseFloatCholeskyDecomposition;
import cern.colt.matrix.tfloat.impl.DenseFloatMatrix2D;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Semaphore;
import org.linqs.psl.reasoner.admm.term.ADMMObjectiveTerm;
import org.linqs.psl.reasoner.admm.term.LocalVariable;
import org.linqs.psl.reasoner.term.WeightedTerm;
import org.linqs.psl.util.HashCode;

/**
 * Base class for ADMM objective terms built around a weighted squared
 * hyperplane, i.e. a term involving {@code weight * (coeffs^T * x - constant)^2}.
 *
 * <p>The core operation, {@link #minWeightedSquaredHyperplane(float, float[])},
 * solves the linear system
 * {@code (2 * weight * coeffs * coeffs^T + stepSize * I) x = b}
 * for the local variables. Terms with one or two variables are solved in
 * closed form; larger terms use a Cholesky factor of the system matrix.
 *
 * <p>Cholesky factors are shared between all terms through a static cache
 * keyed by the system matrix. Access to that cache is serialized by a
 * semaphore. Per-instance state (weight, cached factor) is not synchronized.
 */
public abstract class SquaredHyperplaneTerm
extends ADMMObjectiveTerm
implements WeightedTerm {
    /** Hyperplane coefficients, one per local variable (same order as {@code variables}). */
    protected final List<Float> coeffs;
    /** Constant (offset) of the hyperplane. */
    protected final float constant;
    /** Non-negative weight of this term. */
    protected float weight;
    /** Lower-triangular Cholesky factor of the system matrix, or null if it must be (re)computed. */
    private FloatMatrix2D L;
    /** The step size that was used to build {@link #L}; only meaningful while L is non-null. */
    private float stepSizeOfL;
    /** Cache of Cholesky factors shared by all terms; guarded by {@link #matrixSemaphore}. */
    private static final Map<DenseFloatMatrix2DWithHashcode, FloatMatrix2D> lCache = new HashMap<DenseFloatMatrix2DWithHashcode, FloatMatrix2D>();
    /** Binary semaphore serializing every read and write of {@link #lCache}. */
    private static final Semaphore matrixSemaphore = new Semaphore(1);

    /**
     * Creates a squared hyperplane term.
     *
     * @param variables the local variables of this term
     * @param coeffs the hyperplane coefficients, one per variable
     * @param constant the hyperplane constant
     * @param weight the weight of the term, must not be negative
     * @throws IllegalArgumentException if {@code weight} is negative
     */
    SquaredHyperplaneTerm(List<LocalVariable> variables, List<Float> coeffs, float constant, float weight) {
        super(variables);
        assert (variables.size() == coeffs.size());
        this.coeffs = coeffs;
        this.constant = constant;
        this.L = null;
        if (weight < 0.0f) {
            throw new IllegalArgumentException("Only non-negative weights are supported.");
        }
        this.setWeight(weight);
    }

    /**
     * Ensures {@link #L} holds the Cholesky factor of
     * {@code 2 * weight * coeffs * coeffs^T + stepSize * I}.
     * Does nothing if a factor is already present.
     *
     * @param stepSize the ADMM step size added to the diagonal
     * @throws RuntimeException if the thread is interrupted while waiting for the cache
     */
    private void computeL(float stepSize) {
        if (this.L != null) {
            return;
        }

        // Build the symmetric system matrix; it doubles as the cache key.
        int size = this.variables.size();
        DenseFloatMatrix2DWithHashcode matrix = new DenseFloatMatrix2DWithHashcode(size, size);
        for (int i = 0; i < size; ++i) {
            float coeffI = this.coeffs.get(i).floatValue();
            matrix.setQuick(i, i, 2.0f * this.weight * coeffI * coeffI + stepSize);
            for (int j = i + 1; j < size; ++j) {
                float offDiagonal = 2.0f * this.weight * coeffI * this.coeffs.get(j).floatValue();
                matrix.setQuick(i, j, offDiagonal);
                matrix.setQuick(j, i, offDiagonal);
            }
        }

        try {
            matrixSemaphore.acquire();
        }
        catch (InterruptedException ex) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted constructing matrix", ex);
        }
        try {
            // The lookup must happen while holding the permit: HashMap is not
            // safe to read while another thread may be inserting into it.
            FloatMatrix2D factor = lCache.get(matrix);
            if (factor == null) {
                factor = new DenseFloatCholeskyDecomposition(matrix).getL();
                lCache.put(matrix, factor);
            }
            this.L = factor;
            this.stepSizeOfL = stepSize;
        }
        finally {
            // Always hand the permit back, even if the decomposition throws.
            matrixSemaphore.release();
        }
    }

    /**
     * Sets the weight of this term and discards the cached Cholesky factor,
     * since the system matrix depends on the weight.
     *
     * @param weight the new weight
     */
    @Override
    public void setWeight(float weight) {
        this.weight = weight;
        this.L = null;
    }

    /**
     * Overwrites the value of every local variable with the solution {@code x} of
     * {@code (2 * weight * coeffs * coeffs^T + stepSize * I) x = b}, where
     * {@code b_i = stepSize * (consensus_i - lagrange_i / stepSize) + 2 * weight * coeffs_i * constant}.
     * This is the stationarity condition of
     * {@code weight * (coeffs^T x - constant)^2 + (stepSize / 2) * ||x - consensus + lagrange / stepSize||^2}.
     *
     * @param stepSize the ADMM step size
     * @param consensusValues consensus values indexed by each variable's global id
     */
    protected void minWeightedSquaredHyperplane(float stepSize, float[] consensusValues) {
        int size = this.variables.size();

        // Store the right-hand side b in the variables; it is transformed in place below.
        for (int i = 0; i < size; ++i) {
            LocalVariable variable = (LocalVariable)this.variables.get(i);
            float value = stepSize * (consensusValues[variable.getGlobalId()] - variable.getLagrange() / stepSize);
            value += 2.0f * this.weight * this.coeffs.get(i).floatValue() * this.constant;
            variable.setValue(value);
        }

        // One variable: the system is a single scalar equation.
        if (size == 1) {
            LocalVariable variable = (LocalVariable)this.variables.get(0);
            float coeff = this.coeffs.get(0).floatValue();
            variable.setValue(variable.getValue() / (2.0f * this.weight * coeff * coeff + stepSize));
            return;
        }

        // Two variables: eliminate by hand on the symmetric 2x2 system
        //   [a0   a1b0] [x0]   [v0]
        //   [a1b0 b1  ] [x1] = [v1]
        if (size == 2) {
            LocalVariable variable0 = (LocalVariable)this.variables.get(0);
            LocalVariable variable1 = (LocalVariable)this.variables.get(1);
            float coeff0 = this.coeffs.get(0).floatValue();
            float coeff1 = this.coeffs.get(1).floatValue();
            float a0 = 2.0f * this.weight * coeff0 * coeff0 + stepSize;
            float b1 = 2.0f * this.weight * coeff1 * coeff1 + stepSize;
            float a1b0 = 2.0f * this.weight * coeff0 * coeff1;
            variable1.setValue(variable1.getValue() - a1b0 * variable0.getValue() / a0);
            variable1.setValue(variable1.getValue() / (b1 - a1b0 * a1b0 / a0));
            variable0.setValue((variable0.getValue() - a1b0 * variable1.getValue()) / a0);
            return;
        }

        // General case: use the Cholesky factor. A factor built for a different
        // step size belongs to a different system matrix and must not be reused.
        if (this.L != null && this.stepSizeOfL != stepSize) {
            this.L = null;
        }
        if (this.L == null) {
            this.computeL(stepSize);
        }

        // Forward substitution: solve L * y = b.
        for (int i = 0; i < size; ++i) {
            LocalVariable rowVariable = (LocalVariable)this.variables.get(i);
            for (int j = 0; j < i; ++j) {
                rowVariable.setValue(rowVariable.getValue() - this.L.getQuick(i, j) * ((LocalVariable)this.variables.get(j)).getValue());
            }
            rowVariable.setValue(rowVariable.getValue() / this.L.getQuick(i, i));
        }

        // Backward substitution: solve L^T * x = y (L^T(i, j) is read as L(j, i)).
        for (int i = size - 1; i >= 0; --i) {
            LocalVariable rowVariable = (LocalVariable)this.variables.get(i);
            for (int j = size - 1; j > i; --j) {
                rowVariable.setValue(rowVariable.getValue() - this.L.getQuick(j, i) * ((LocalVariable)this.variables.get(j)).getValue());
            }
            rowVariable.setValue(rowVariable.getValue() / this.L.getQuick(i, i));
        }
    }

    /**
     * A dense float matrix with a memoized, content-based hash code so it can be
     * used as a key of the Cholesky cache.
     *
     * <p>Declared static on purpose: keys live in a static map, and a non-static
     * inner class would keep its enclosing term reachable for as long as the key
     * is cached.
     */
    private static class DenseFloatMatrix2DWithHashcode
    extends DenseFloatMatrix2D {
        private static final long serialVersionUID = -8102931034927566306L;
        /** True when an entry changed through setQuick since the hash was last computed. */
        private boolean needsNewHashcode;
        /** The memoized hash; valid only while needsNewHashcode is false. */
        private int hashcode;

        public DenseFloatMatrix2DWithHashcode(int rows, int columns) {
            super(rows, columns);
            this.hashcode = 0;
            this.needsNewHashcode = true;
        }

        /** Sets a cell and marks the memoized hash as stale. */
        @Override
        public void setQuick(int row, int column, float value) {
            this.needsNewHashcode = true;
            super.setQuick(row, column, value);
        }

        /** Returns a hash over all cells in row-major order, recomputed only after a setQuick. */
        @Override
        public int hashCode() {
            if (this.needsNewHashcode) {
                this.hashcode = 17;
                for (int i = 0; i < this.rows(); ++i) {
                    for (int j = 0; j < this.columns(); ++j) {
                        this.hashcode = HashCode.build(this.hashcode, Float.valueOf(this.getQuick(i, j)));
                    }
                }
                this.needsNewHashcode = false;
            }
            return this.hashcode;
        }
    }
}

