/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.search;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.application.learning.weight.search.RandomGridSearch;
import org.linqs.psl.config.ConfigBundle;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link RandomGridSearch} that spends part of its budget on random "seed" locations and the
 * rest on the grid neighbors of the best-scoring seeds.
 *
 * <p>The search runs in two phases:
 * <ol>
 *   <li>Seed phase: evaluate {@code numSeedLocations} distinct random grid locations.</li>
 *   <li>Explore phase: sort the seeds ascending by objective (via {@link MathUtils#compare};
 *       presumably lower is better — confirm against the parent class), take the first
 *       {@code numExploreLocations} of them, and evaluate every not-yet-evaluated location that
 *       differs from one of them by a single step in a single rule's weight index.</li>
 * </ol>
 *
 * <p>Locations are encoded as colon-separated indexes into {@code possibleWeights}, one index per
 * mutable rule.
 */
public class GuidedRandomGridSearch extends RandomGridSearch {
    private static final Logger log = LoggerFactory.getLogger(GuidedRandomGridSearch.class);

    /** Prefix for all config keys of this class. */
    public static final String CONFIG_PREFIX = "guidedrandomgridsearch";

    /** Config key: number of random locations evaluated before exploring neighbors. */
    public static final String SEED_LOCATIONS_KEY = "guidedrandomgridsearch.seedlocations";
    public static final int SEED_LOCATIONS_DEFAULT = 25;

    /** Config key: number of best seed locations whose neighbors are explored. */
    public static final String EXPLORE_LOCATIONS_KEY = "guidedrandomgridsearch.explorelocations";
    public static final int EXPLORE_LOCATIONS_DEFAULT = 10;

    private final int numSeedLocations;
    private final int numExploreLocations;

    /** Neighbor locations still waiting to be evaluated during the explore phase. */
    private final Set<String> toExplore;

    public GuidedRandomGridSearch(Model model, Database rvDB, Database observedDB, ConfigBundle config) {
        this(model.getRules(), rvDB, observedDB, config);
    }

    /**
     * @throws IllegalArgumentException if the configured seed or explore location count is below 1.
     */
    public GuidedRandomGridSearch(List<Rule> rules, Database rvDB, Database observedDB, ConfigBundle config) {
        super(rules, rvDB, observedDB, config);

        this.numSeedLocations = config.getInt(SEED_LOCATIONS_KEY, SEED_LOCATIONS_DEFAULT);
        if (this.numSeedLocations < 1) {
            throw new IllegalArgumentException("Need at least one location to start the search.");
        }

        this.numExploreLocations = config.getInt(EXPLORE_LOCATIONS_KEY, EXPLORE_LOCATIONS_DEFAULT);
        if (this.numExploreLocations < 1) {
            throw new IllegalArgumentException("Need at least one explore location.");
        }

        // Each explored seed contributes at most two neighbors per mutable rule (one step up and one
        // step down in that rule's weight index; see addNeighbors), so this bounds how many locations
        // can ever be evaluated. Long arithmetic keeps a large rule count from overflowing int.
        long maxReachable = this.numSeedLocations
                + (long)this.numExploreLocations * 2L * this.mutableRules.size();
        this.numLocations = (int)Math.min((long)this.numLocations, maxReachable);

        // numLocations may already be below numSeedLocations (e.g. a small grid); HashSet rejects a
        // negative initial capacity, so clamp the hint at zero.
        this.toExplore = new HashSet<String>(Math.max(0, this.numLocations - this.numSeedLocations));
    }

    /**
     * Picks the next location to evaluate and stores it in {@code currentLocation}.
     *
     * @return false when there is nothing left to evaluate, true otherwise.
     */
    @Override
    protected boolean chooseNextLocation() {
        if (this.objectives.size() < this.numSeedLocations) {
            return chooseSeedLocation();
        }

        // Exactly when the seed phase has just finished, build the explore set.
        if (this.objectives.size() == this.numSeedLocations) {
            startExplorePhase();
        }

        if (this.toExplore.isEmpty()) {
            return false;
        }

        this.currentLocation = this.toExplore.iterator().next();
        this.toExplore.remove(this.currentLocation);
        return true;
    }

    /** Picks a random, not-yet-evaluated location; returns false if the whole grid is evaluated. */
    private boolean chooseSeedLocation() {
        // One weight index per mutable rule, each in [0, possibleWeights.length). Without this
        // guard the rejection loop below would spin forever once every location is evaluated.
        // Assumes randomConfiguration() draws from exactly this grid — TODO confirm.
        double gridSize = Math.pow(this.possibleWeights.length, this.mutableRules.size());
        if (this.objectives.size() >= gridSize) {
            return false;
        }

        do {
            this.currentLocation = this.randomConfiguration();
        } while (this.objectives.containsKey(this.currentLocation));
        return true;
    }

    /** Queues the neighbors of the best seeds, skipping locations that were already evaluated. */
    private void startExplorePhase() {
        List<Map.Entry<String, Double>> seeds =
                new ArrayList<Map.Entry<String, Double>>(this.objectives.entrySet());
        Collections.sort(seeds, new Comparator<Map.Entry<String, Double>>() {
            @Override
            public int compare(Map.Entry<String, Double> a, Map.Entry<String, Double> b) {
                return MathUtils.compare(a.getValue(), b.getValue());
            }
        });

        int numToExplore = Math.min(this.numExploreLocations, seeds.size());
        for (int i = 0; i < numToExplore; i++) {
            log.trace("Adding neighbors for {}.", seeds.get(i));
            addNeighbors(seeds.get(i).getKey());
        }

        // Never re-evaluate a location that was already scored during the seed phase.
        this.toExplore.removeAll(this.objectives.keySet());
        log.debug("Seed phase complete, starting explore phase with {} locations.", this.toExplore.size());
    }

    /**
     * Adds to {@code toExplore} every in-range location that differs from {@code location} by +1 or
     * -1 in exactly one rule's weight index.
     */
    private void addNeighbors(String location) {
        int[] indexes = StringUtils.splitInt(location, ":");
        assert indexes.length == this.mutableRules.size();

        int maxIndex = this.possibleWeights.length - 1;
        for (int i = 0; i < this.mutableRules.size(); i++) {
            // Mutate in place, record the neighbor, then restore the original index.
            if (indexes[i] < maxIndex) {
                indexes[i]++;
                this.toExplore.add(StringUtils.join(indexes, ":"));
                indexes[i]--;
            }

            if (indexes[i] > 0) {
                indexes[i]--;
                this.toExplore.add(StringUtils.join(indexes, ":"));
                indexes[i]++;
            }
        }
    }
}

