/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.search;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.application.learning.weight.search.objective.LossObjective;
import org.linqs.psl.application.learning.weight.search.objective.ObjectiveFunction;
import org.linqs.psl.config.ConfigBundle;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.admm.ADMMReasoner;
import org.linqs.psl.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Weight learning by exhaustive search over a grid of candidate weights.
 *
 * Every mutable rule is assigned, in turn, each of the configured possible weights.
 * For each such assignment (a "location" on the grid) the configured objective function
 * is evaluated, and the weights of the location with the lowest objective are kept.
 *
 * A location is encoded as a {@link #DELIM}-separated string of indexes into
 * {@link #possibleWeights}, one index per mutable rule.
 */
public class GridSearch extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(GridSearch.class);

    /** Prefix shared by all configuration keys of this class. */
    public static final String CONFIG_PREFIX = "gridsearch";

    /** Configuration key for the {@link #DELIM}-separated list of candidate weights. */
    public static final String POSSIBLE_WEIGHTS_KEY = "gridsearch.weights";
    public static final String POSSIBLE_WEIGHTS_DEFAULT = "0.001:0.01:0.1:1:10";

    /** Configuration key for the class name of the objective function to minimize. */
    public static final String OBJECTIVE_KEY = "gridsearch.objective";
    public static final String OBJECTIVE_DEFAULT = LossObjective.class.getName();

    /** Configuration key for the iteration count handed to an ADMM reasoner. */
    public static final String ADMM_ITERATIONS_KEY = "gridsearch.admmiterations";
    public static final int ADMM_ITERATIONS_DEFAULT = 100;

    /** Separator used both in the weights configuration value and in location strings. */
    public static final String DELIM = ":";

    /** Candidate weights each mutable rule may take. Never empty. */
    protected final double[] possibleWeights;

    /** Iteration count passed to the reasoner when it is an ADMMReasoner. Always at least 1. */
    protected final int admmIterations;

    /** The location currently being inspected, or null before the search has started. */
    protected String currentLocation;

    /** Total number of locations on the grid: possibleWeights.length ^ mutableRules.size(). */
    protected int gridSize;

    /** Upper bound on the number of locations doLearn() will inspect. */
    protected int numLocations;

    protected ObjectiveFunction objectiveFunction;

    /** Objective value recorded for every location inspected so far, keyed by location string. */
    protected Map<String, Double> objectives;

    /**
     * Construct a grid search over the rules of the given model.
     */
    public GridSearch(Model model, Database rvDB, Database observedDB, ConfigBundle config) {
        this(model.getRules(), rvDB, observedDB, config);
    }

    /**
     * Construct a grid search over the given rules.
     *
     * @throws IllegalArgumentException if the configured weight list is empty
     *         or the configured ADMM iteration count is less than one.
     */
    public GridSearch(List<Rule> rules, Database rvDB, Database observedDB, ConfigBundle config) {
        super(rules, rvDB, observedDB, false, config);

        this.possibleWeights = StringUtils.splitDouble(config.getString(POSSIBLE_WEIGHTS_KEY, POSSIBLE_WEIGHTS_DEFAULT), DELIM);
        if (this.possibleWeights.length == 0) {
            throw new IllegalArgumentException("No weights provided for grid search.");
        }

        this.objectiveFunction = (ObjectiveFunction)config.getNewObject(OBJECTIVE_KEY, OBJECTIVE_DEFAULT);

        // Use the declared default rather than a duplicated literal so the two cannot drift apart.
        this.admmIterations = config.getInt(ADMM_ITERATIONS_KEY, ADMM_ITERATIONS_DEFAULT);
        if (this.admmIterations < 1) {
            throw new IllegalArgumentException("Need at least one iteration for grid search.");
        }

        this.currentLocation = null;

        // NOTE: the narrowing cast saturates at Integer.MAX_VALUE when the grid is too large for an int.
        this.gridSize = (int)Math.pow(this.possibleWeights.length, this.mutableRules.size());
        this.numLocations = this.gridSize;

        this.objectives = new HashMap<String, Double>();
    }

    /**
     * Run the parent's ground model initialization and, if the reasoner is an
     * ADMMReasoner, hand it the configured iteration count via setMaxIter().
     */
    @Override
    protected void initGroundModel() {
        super.initGroundModel();

        if (this.reasoner instanceof ADMMReasoner) {
            ((ADMMReasoner)this.reasoner).setMaxIter(this.admmIterations);
        }
    }

    /**
     * Inspect up to numLocations locations, record the objective of each one in
     * {@link #objectives}, and leave the mutable rules set to the weights of the
     * location with the lowest objective.
     *
     * The search ends early if chooseNextLocation() returns false.
     * If no location was inspected at all, the rule weights are left untouched.
     */
    @Override
    protected void doLearn() {
        final int ruleCount = this.mutableRules.size();

        double bestObjective = -1.0;
        double[] bestWeights = new double[ruleCount];
        // Tracks whether bestWeights holds a real result (and not just the zero-initialized array).
        boolean foundBest = false;

        this.computeObservedIncompatibility();

        double[] weights = new double[ruleCount];

        for (int iteration = 0; iteration < this.numLocations; ++iteration) {
            if (!this.chooseNextLocation()) {
                log.debug("Stopping search.");
                break;
            }

            log.debug("Iteration {} / {} ({}) -- Inspecting location {}", iteration, this.numLocations, this.gridSize, this.currentLocation);

            // Apply the weights of the current location to the rules.
            this.getWeights(weights);
            for (int i = 0; i < ruleCount; ++i) {
                ((WeightedRule)this.mutableRules.get(i)).setWeight(weights[i]);
            }

            this.setDefaultRandomVariables();
            this.computeExpectedIncompatibility();

            double objective = this.objectiveFunction.compute(this.mutableRules, this.observedIncompatibility, this.expectedIncompatibility, this.trainingMap);
            this.objectives.put(this.currentLocation, Double.valueOf(objective));

            // The first inspected location is always accepted; later ones must strictly improve.
            if (!foundBest || objective < bestObjective) {
                foundBest = true;
                bestObjective = objective;
                System.arraycopy(weights, 0, bestWeights, 0, ruleCount);
            }

            log.debug("Location {} -- objective: {}", this.currentLocation, objective);
        }

        if (!foundBest) {
            // Writing back the zero-initialized bestWeights would silently zero every rule weight.
            log.warn("Grid search inspected no locations, leaving rule weights unchanged.");
            return;
        }

        for (int i = 0; i < ruleCount; ++i) {
            ((WeightedRule)this.mutableRules.get(i)).setWeight(bestWeights[i]);
        }
    }

    /**
     * Decode the current location into concrete weights.
     *
     * @param weights output array; slot i receives the candidate weight selected
     *        for mutable rule i by the current location.
     */
    protected void getWeights(double[] weights) {
        int[] indexes = StringUtils.splitInt(this.currentLocation, DELIM);
        assert (indexes.length == this.mutableRules.size());

        for (int i = 0; i < this.mutableRules.size(); ++i) {
            weights[i] = this.possibleWeights[indexes[i]];
        }
    }

    /**
     * Advance currentLocation to the next location on the grid.
     *
     * The first call starts at the all-zeros location. Every later call increments
     * the index vector like an odometer: the last index is bumped first, and an index
     * that reaches possibleWeights.length rolls over to zero and carries into the
     * index before it.
     *
     * @return true if a location was chosen and the search should continue.
     *         This implementation always returns true; doLearn() bounds the search by numLocations.
     */
    protected boolean chooseNextLocation() {
        // Start at all zeros (a fresh int array is zero-filled).
        if (this.currentLocation == null) {
            this.currentLocation = StringUtils.join(new int[this.mutableRules.size()], DELIM);
            return true;
        }

        int[] indexes = StringUtils.splitInt(this.currentLocation, DELIM);
        assert (indexes.length == this.mutableRules.size());

        for (int i = this.mutableRules.size() - 1; i >= 0; --i) {
            indexes[i]++;

            if (indexes[i] != this.possibleWeights.length) {
                // No rollover, so there is nothing to carry.
                break;
            }

            // Rollover: reset this index and carry into the previous one.
            indexes[i] = 0;
        }

        this.currentLocation = StringUtils.join(indexes, DELIM);
        return true;
    }
}

