/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight;

import java.util.List;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.config.ConfigBundle;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for weight learners that run a fixed number of perceptron-style steps over the
 * weights of the mutable rules, optionally averaging the weights across all steps.
 *
 * <p>On every step each mutable rule's weight is moved by
 * {@code stepSize * (expected - observed - l2 * weight - l1) / scale} and then clamped at zero.
 * How the expected incompatibility and the loss are obtained is delegated to
 * {@code computeExpectedIncompatibility()} and {@code computeLoss()}, which are not defined in
 * this class.
 */
public abstract class VotedPerceptron
extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(VotedPerceptron.class);

    /** Prefix shared by all configuration keys of this learner. */
    public static final String CONFIG_PREFIX = "votedperceptron";

    /** Coefficient of the {@code 0.5 * l2 * w^2} penalty; must be non-negative. */
    public static final String L2_REGULARIZATION_KEY = "votedperceptron.l2regularization";
    public static final double L2_REGULARIZATION_DEFAULT = 0.0;

    /** Coefficient of the {@code l1 * |w|} penalty; must be non-negative. */
    public static final String L1_REGULARIZATION_KEY = "votedperceptron.l1regularization";
    public static final double L1_REGULARIZATION_DEFAULT = 0.0;

    /** Base step size; must be positive. */
    public static final String STEP_SIZE_KEY = "votedperceptron.stepsize";
    public static final double STEP_SIZE_DEFAULT = 1.0;

    /** If true, the step size at iteration {@code t} (0-based) is {@code base / (t + 1)}. */
    public static final String STEP_SCHEDULE_KEY = "votedperceptron.schedule";
    public static final boolean STEP_SCHEDULE_DEFAULT = true;

    /** If true, each rule's step is divided by its ground-rule count (at least 1). */
    public static final String SCALE_GRADIENT_KEY = "votedperceptron.scalegradient";
    public static final boolean SCALE_GRADIENT_DEFAULT = true;

    /** If true, the final weights are the average of the weights after every step. */
    public static final String AVERAGE_STEPS_KEY = "votedperceptron.averagesteps";
    public static final boolean AVERAGE_STEPS_DEFAULT = true;

    /** Number of learning steps; must be positive. */
    public static final String NUM_STEPS_KEY = "votedperceptron.numsteps";
    public static final int NUM_STEPS_DEFAULT = 25;

    protected final double baseStepSize;
    protected final int numSteps;
    protected final double l2Regularization;
    protected final double l1Regularization;
    protected final boolean scheduleStepSize;
    protected final boolean scaleGradient;
    protected final boolean averageSteps;

    /** Loss recorded on the most recent learning step; infinite until learning runs. */
    private double currentLoss = Double.POSITIVE_INFINITY;

    /**
     * Creates the learner and reads its hyperparameters from {@code config}.
     *
     * @param rules the model's rules, forwarded to the superclass
     * @param rvDB database of the random variables, forwarded to the superclass
     * @param observedDB database of the observed (truth) values, forwarded to the superclass
     * @param supportsLatentVariables forwarded to the superclass
     * @param config configuration bundle; the {@code votedperceptron.*} keys are read from it
     * @throws IllegalArgumentException if the step size or number of steps is not positive, or
     *     if either regularization coefficient is negative
     */
    public VotedPerceptron(List<Rule> rules, Database rvDB, Database observedDB, boolean supportsLatentVariables, ConfigBundle config) {
        super(rules, rvDB, observedDB, supportsLatentVariables, config);

        this.baseStepSize = config.getDouble(STEP_SIZE_KEY, STEP_SIZE_DEFAULT);
        if (this.baseStepSize <= 0.0) {
            throw new IllegalArgumentException("Step size must be positive.");
        }

        this.numSteps = config.getInt(NUM_STEPS_KEY, NUM_STEPS_DEFAULT);
        if (this.numSteps <= 0) {
            throw new IllegalArgumentException("Number of steps must be positive.");
        }

        this.l2Regularization = config.getDouble(L2_REGULARIZATION_KEY, L2_REGULARIZATION_DEFAULT);
        if (this.l2Regularization < 0.0) {
            throw new IllegalArgumentException("L2 regularization parameter must be non-negative.");
        }

        this.l1Regularization = config.getDouble(L1_REGULARIZATION_KEY, L1_REGULARIZATION_DEFAULT);
        if (this.l1Regularization < 0.0) {
            throw new IllegalArgumentException("L1 regularization parameter must be non-negative.");
        }

        this.scheduleStepSize = config.getBoolean(STEP_SCHEDULE_KEY, STEP_SCHEDULE_DEFAULT);
        this.scaleGradient = config.getBoolean(SCALE_GRADIENT_KEY, SCALE_GRADIENT_DEFAULT);
        this.averageSteps = config.getBoolean(AVERAGE_STEPS_KEY, AVERAGE_STEPS_DEFAULT);
    }

    /**
     * Runs {@code numSteps} update steps over the mutable rules' weights.
     *
     * <p>Observed incompatibilities and the scaling factors are computed once up front. On each
     * step, the expected incompatibilities and the loss are recomputed. Every weight then moves by
     * the scaled, step-sized update and is clamped at zero. If {@code averageSteps} is set, each
     * weight is finally replaced by its mean over all steps.
     */
    @Override
    protected void doLearn() {
        int numRules = this.mutableRules.size();
        double[] avgWeights = new double[numRules];

        this.computeObservedIncompatibility();
        this.setDefaultRandomVariables();
        double[] scalingFactor = this.computeScalingFactor();

        for (int step = 0; step < this.numSteps; ++step) {
            log.debug("Starting iteration {}", step);
            this.computeExpectedIncompatibility();
            // The loss is measured for the weights entering this step, i.e. before the update below.
            this.currentLoss = this.computeLoss();
            double stepSize = this.getStepSize(step);

            for (int i = 0; i < numRules; ++i) {
                WeightedRule rule = this.mutableRules.get(i);
                double weight = rule.getWeight();

                // Expected minus observed incompatibility, minus the derivatives of the penalties
                // used by computeRegularizer(): d/dw (0.5 * l2 * w^2) = l2 * w, and d/dw (l1 * |w|) = l1
                // for the non-negative weights this loop maintains.
                double direction = this.expectedIncompatibility[i] - this.observedIncompatibility[i]
                        - this.l2Regularization * weight - this.l1Regularization;
                // Evaluation order (divide, then multiply) matches the original computation exactly.
                double currentStep = direction / scalingFactor[i] * stepSize;

                log.debug("Step of {} for rule {}", currentStep, rule);
                log.debug(" --- Expected incomp.: {}, Truth incomp.: {}", this.expectedIncompatibility[i], this.observedIncompatibility[i]);

                // Weights are kept non-negative.
                weight = Math.max(weight + currentStep, 0.0);
                avgWeights[i] += weight;
                rule.setWeight(weight);
            }
        }

        if (this.averageSteps) {
            for (int i = 0; i < numRules; ++i) {
                this.mutableRules.get(i).setWeight(avgWeights[i] / this.numSteps);
            }
        }
    }

    /**
     * Computes the regularization penalty over the current mutable-rule weights.
     *
     * @return {@code 0.5 * l2 * sum(w^2) + l1 * sum(|w|)}, or exactly 0 when both coefficients
     *     are zero
     */
    protected double computeRegularizer() {
        if (this.l1Regularization == 0.0 && this.l2Regularization == 0.0) {
            return 0.0;
        }

        double l2 = 0.0;
        double l1 = 0.0;
        for (WeightedRule rule : this.mutableRules) {
            double weight = rule.getWeight();
            l2 += Math.pow(weight, 2.0);
            l1 += Math.abs(weight);
        }
        return 0.5 * this.l2Regularization * l2 + this.l1Regularization * l1;
    }

    /**
     * Returns the loss recorded on the most recent step of {@link #doLearn()}.
     *
     * <p>The value is computed before that step's weight update, and before any averaging, so it
     * does not necessarily correspond to the final weights.
     *
     * @return the last recorded loss, or {@link Double#POSITIVE_INFINITY} if learning has not run
     */
    public double getLoss() {
        return this.currentLoss;
    }

    /**
     * Returns the step size for the given iteration.
     *
     * @param iteration 0-based iteration index
     * @return {@code baseStepSize / (iteration + 1)} when step scheduling is enabled, otherwise
     *     {@code baseStepSize}
     */
    protected double getStepSize(int iteration) {
        if (this.scheduleStepSize) {
            return this.baseStepSize / (double) (iteration + 1);
        }
        return this.baseStepSize;
    }

    /**
     * Computes the per-rule divisor applied to each step.
     *
     * <p>When {@code scaleGradient} is enabled, a rule's divisor is its ground-rule count (at least
     * 1, so rules without groundings are not divided by zero). When it is disabled, every divisor
     * is 1 and steps are left unscaled.
     *
     * @return one divisor per mutable rule, in the same order as {@code mutableRules}
     */
    protected double[] computeScalingFactor() {
        double[] factor = new double[this.mutableRules.size()];
        for (int i = 0; i < factor.length; ++i) {
            if (this.scaleGradient) {
                factor[i] = Math.max(1.0, (double) this.groundRuleStore.count(this.mutableRules.get(i)));
            } else {
                factor[i] = 1.0;
            }
        }
        return factor;
    }
}

