/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.model.rule.arithmetic;

import com.healthmarketscience.sqlbuilder.BinaryCondition;
import com.healthmarketscience.sqlbuilder.CustomSql;
import com.healthmarketscience.sqlbuilder.SelectQuery;
import com.healthmarketscience.sqlbuilder.SetOperationQuery;
import com.healthmarketscience.sqlbuilder.Subquery;
import com.healthmarketscience.sqlbuilder.UnionQuery;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import org.linqs.psl.application.groundrulestore.GroundRuleStore;
import org.linqs.psl.database.DatabaseQuery;
import org.linqs.psl.database.ResultList;
import org.linqs.psl.database.atom.AtomManager;
import org.linqs.psl.database.rdbms.Formula2SQL;
import org.linqs.psl.database.rdbms.RDBMSDataStore;
import org.linqs.psl.database.rdbms.RDBMSDatabase;
import org.linqs.psl.database.rdbms.driver.DatabaseDriver;
import org.linqs.psl.model.atom.Atom;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.QueryAtom;
import org.linqs.psl.model.formula.Conjunction;
import org.linqs.psl.model.formula.Disjunction;
import org.linqs.psl.model.formula.Formula;
import org.linqs.psl.model.formula.Negation;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.arithmetic.AbstractGroundArithmeticRule;
import org.linqs.psl.model.rule.arithmetic.expression.ArithmeticRuleExpression;
import org.linqs.psl.model.rule.arithmetic.expression.SummationAtom;
import org.linqs.psl.model.rule.arithmetic.expression.SummationAtomOrAtom;
import org.linqs.psl.model.rule.arithmetic.expression.SummationVariable;
import org.linqs.psl.model.rule.arithmetic.expression.SummationVariableOrTerm;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.model.term.ConstantType;
import org.linqs.psl.model.term.StringAttribute;
import org.linqs.psl.model.term.Variable;
import org.linqs.psl.model.term.VariableTypeMap;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for arithmetic rules: rules built from an {@link ArithmeticRuleExpression} made of
 * coefficient-weighted atoms related by a {@link FunctionComparator} to a final coefficient.
 * The expression may contain summation variables, whose substitutions can be restricted by
 * filter formulas.
 *
 * <p>Rules without summation variables are grounded from a plain {@link DatabaseQuery}. Rules with
 * summation variables can only be grounded against an {@link RDBMSDatabase}. There, a SQL query
 * aggregates every substitution of each summation variable into one delimited string (see
 * {@link #setDelim(String)}) per grounding of the remaining variables.
 *
 * <p>For a weighted rule whose comparator is {@code FunctionComparator.Equality}, each grounding
 * produces two ground rules, one {@code LargerThan} and one {@code SmallerThan}.
 */
public abstract class AbstractArithmeticRule implements Rule {
    private static final Logger log = LoggerFactory.getLogger(AbstractArithmeticRule.class);

    /**
     * Separator used by the SQL string aggregate of summation-variable substitutions, and when
     * splitting that aggregate back apart. It is matched literally, not as a regular expression.
     * Shared by all rules; see {@link #setDelim(String)}.
     */
    private static String DELIM = ";";

    protected final ArithmeticRuleExpression expression;

    /** Filter formula, converted to DNF, for each filtered summation variable. */
    protected final Map<SummationVariable, Formula> filters;

    /**
     * Creates a rule and validates its filters.
     *
     * @param expression the arithmetic expression of the rule
     * @param filterClauses filter formula for each filtered summation variable. The map is copied,
     *     and each copied formula is converted to DNF; the caller's map is not modified.
     * @throws IllegalArgumentException if a filter argument is not a summation variable of the
     *     expression, or a filter uses a variable that is neither its argument nor an expression
     *     variable
     */
    public AbstractArithmeticRule(ArithmeticRuleExpression expression, Map<SummationVariable, Formula> filterClauses) {
        this.expression = expression;

        // Copy before converting to DNF so the caller's map is neither mutated nor required to be
        // modifiable. LinkedHashMap keeps the caller's iteration order, which fixes the order of
        // the generated filter queries.
        Map<SummationVariable, Formula> dnfFilters = new LinkedHashMap<>(filterClauses);
        for (Map.Entry<SummationVariable, Formula> entry : dnfFilters.entrySet()) {
            entry.setValue(entry.getValue().getDNF());
        }
        this.filters = dnfFilters;

        validateRule();
    }

    /** Returns true if the expression contains at least one summation variable. */
    public boolean hasSummation() {
        return expression.getSummationVariables().size() > 0;
    }

    public ArithmeticRuleExpression getExpression() {
        return expression;
    }

    /**
     * Grounds every instance of this rule into {@code groundRuleStore}.
     *
     * @return the number of ground rules added
     * @throws IllegalArgumentException if a filter uses an open predicate, or if summation
     *     grounding is needed but the database is not relational
     */
    @Override
    public int groundAll(AtomManager atomManager, GroundRuleStore groundRuleStore) {
        validateGroundRule(atomManager);

        int groundCount;
        if (expression.getSummationVariables().size() == 0) {
            groundCount = groundNonSummationRule(atomManager, groundRuleStore);
        } else {
            groundCount = groundSummationRule(atomManager, groundRuleStore);
        }

        log.debug("Grounded {} instances of rule {}", groundCount, this);
        return groundCount;
    }

    /** Returns the predicates of every (summation or plain) atom in the expression. */
    public Set<Predicate> getBodyPredicates() {
        Set<Predicate> predicates = new HashSet<>();
        for (SummationAtomOrAtom atom : expression.getAtoms()) {
            if (atom instanceof SummationAtom) {
                predicates.add(((SummationAtom) atom).getPredicate());
            } else {
                predicates.add(((Atom) atom).getPredicate());
            }
        }
        return predicates;
    }

    /** Runs the expression's query formula and grounds one rule instance per result row. */
    private int groundNonSummationRule(AtomManager atomManager, GroundRuleStore groundRuleStore) {
        ResultList groundVariables = atomManager.executeQuery(new DatabaseQuery(expression.getQueryFormula(), false));
        return groundNonSummationRule(groundVariables, atomManager, groundRuleStore);
    }

    /**
     * Grounds a rule without summation variables, once per row of {@code groundVariables}.
     *
     * @return the number of ground rules added to {@code groundRuleStore}
     */
    public int groundNonSummationRule(ResultList groundVariables, AtomManager atomManager, GroundRuleStore groundRuleStore) {
        // Every atom of a non-summation expression is expected to be a QueryAtom.
        List<QueryAtom> queryAtoms = new ArrayList<>();
        for (SummationAtomOrAtom atom : expression.getAtoms()) {
            queryAtoms.add((QueryAtom) atom);
        }

        // No summation variables, so no substitution counts (null) are passed to the coefficients.
        double[] coefficients = new double[queryAtoms.size()];
        for (int i = 0; i < coefficients.length; i++) {
            coefficients[i] = expression.getAtomCoefficients().get(i).getValue(null);
        }
        double finalCoefficient = expression.getFinalCoefficient().getValue(null);

        // Reused (overwritten) for every grounding.
        GroundAtom[] groundAtoms = new GroundAtom[queryAtoms.size()];
        int groundCount = 0;
        for (int groundingIndex = 0; groundingIndex < groundVariables.size(); groundingIndex++) {
            for (int atomIndex = 0; atomIndex < groundAtoms.length; atomIndex++) {
                groundAtoms[atomIndex] = queryAtoms.get(atomIndex).ground(atomManager, groundVariables, groundingIndex);
            }

            if (isWeighted() && FunctionComparator.Equality.equals(expression.getComparator())) {
                // A weighted equality is split into two inequalities.
                groundRuleStore.addGroundRule(makeGroundRule(coefficients, groundAtoms, FunctionComparator.LargerThan, finalCoefficient));
                groundRuleStore.addGroundRule(makeGroundRule(coefficients, groundAtoms, FunctionComparator.SmallerThan, finalCoefficient));
                groundCount += 2;
            } else {
                groundRuleStore.addGroundRule(makeGroundRule(coefficients, groundAtoms, expression.getComparator(), finalCoefficient));
                groundCount++;
            }
        }
        return groundCount;
    }

    /**
     * Grounds a rule with summation variables by running an aggregating SQL query.
     *
     * @throws IllegalArgumentException if the database is not an {@link RDBMSDatabase}
     */
    private int groundSummationRule(AtomManager atomManager, GroundRuleStore groundRuleStore) {
        if (!(atomManager.getDatabase() instanceof RDBMSDatabase)) {
            throw new IllegalArgumentException("Can only ground summation arithmetic rules with a relational database.");
        }
        RDBMSDatabase relationalDB = (RDBMSDatabase) atomManager.getDatabase();

        Map<Variable, Integer> projectionMap = new HashMap<>();
        VariableTypeMap varTypes = new VariableTypeMap();
        UnionQuery subquery = buildCoreSummationQuery(relationalDB, projectionMap, varTypes);
        SelectQuery query = buildAggregateSummationQuery(projectionMap, subquery, ((RDBMSDataStore) relationalDB.getDataStore()).getDriver());

        // Summation variables come back as aggregated strings, so read them as String columns.
        // The real types (varTypes) are used later to convert the individual substitutions.
        VariableTypeMap fakeTypes = new VariableTypeMap();
        fakeTypes.addAll(varTypes);
        for (SummationVariable summationVar : expression.getSummationVariables()) {
            fakeTypes.addVariable(summationVar.getVariable(), ConstantType.String, true);
        }

        ResultList groundingResults = relationalDB.executeQuery(projectionMap, fakeTypes, query.validate().toString());
        return instantiateSummationGroundRules(groundingResults, varTypes, atomManager, groundRuleStore);
    }

    /**
     * Creates the ground rules for each row of the aggregated summation query.
     *
     * @param groundingResults one row per grounding; summation-variable columns hold the
     *     substitutions joined by {@link #DELIM}
     * @param varTypes real types of the variables, used to convert each split substitution
     * @return the number of ground rules added
     */
    private int instantiateSummationGroundRules(ResultList groundingResults, VariableTypeMap varTypes, AtomManager atomManager, GroundRuleStore groundRuleStore) {
        // Match the delimiter literally: String.split would treat it as a regex, so a delimiter
        // such as "|" or "." would split incorrectly. Compiled once for all rows.
        Pattern delimiter = Pattern.compile(DELIM, Pattern.LITERAL);

        int groundCount = 0;
        // Reused (cleared) for every grounding.
        List<GroundAtom> groundAtoms = new ArrayList<>();
        List<Double> coefficients = new ArrayList<>();
        for (int groundingIndex = 0; groundingIndex < groundingResults.size(); groundingIndex++) {
            groundAtoms.clear();
            coefficients.clear();

            // Decode the substitutions (and their counts) of every summation variable for this row.
            Map<SummationVariable, Constant[]> subs = new HashMap<>();
            Map<SummationVariable, Integer> subCounts = new HashMap<>();
            for (SummationVariable summationVar : expression.getSummationVariables()) {
                Variable var = summationVar.getVariable();
                String aggregated = ((StringAttribute) groundingResults.get(groundingIndex, var)).getValue();
                String[] stringSubs = delimiter.split(aggregated);

                Constant[] constantSubs = new Constant[stringSubs.length];
                for (int i = 0; i < stringSubs.length; i++) {
                    constantSubs[i] = ConstantType.getConstant(stringSubs[i], varTypes.getType(var));
                }
                subs.put(summationVar, constantSubs);
                subCounts.put(summationVar, constantSubs.length);
            }

            for (int i = 0; i < expression.getAtoms().size(); i++) {
                SummationAtomOrAtom atom = expression.getAtoms().get(i);
                double coefficientValue = expression.getAtomCoefficients().get(i).getValue(subCounts);
                if (atom instanceof SummationAtom) {
                    Constant[] args = new Constant[((SummationAtom) atom).getArity()];
                    instantiateSummationVariables((SummationAtom) atom, args, 0, coefficientValue, subs, atomManager, groundingResults, groundingIndex, groundAtoms, coefficients);
                } else {
                    groundAtoms.add(((QueryAtom) atom).ground(atomManager, groundingResults, groundingIndex));
                    coefficients.add(coefficientValue);
                }
            }

            double finalCoefficient = expression.getFinalCoefficient().getValue(subCounts);
            if (isWeighted() && FunctionComparator.Equality.equals(expression.getComparator())) {
                // A weighted equality is split into two inequalities.
                groundRuleStore.addGroundRule(makeGroundRule(coefficients, groundAtoms, FunctionComparator.LargerThan, finalCoefficient));
                groundRuleStore.addGroundRule(makeGroundRule(coefficients, groundAtoms, FunctionComparator.SmallerThan, finalCoefficient));
                groundCount += 2;
            } else {
                groundRuleStore.addGroundRule(makeGroundRule(coefficients, groundAtoms, expression.getComparator(), finalCoefficient));
                groundCount++;
            }
        }
        return groundCount;
    }

    /**
     * Recursively fills {@code args} from position {@code argIndex} onward. It expands every
     * substitution of each summation variable, and appends each resulting atom that exists in
     * the database (with {@code coefficientValue}) to {@code groundAtoms}/{@code coefficients}.
     */
    private void instantiateSummationVariables(SummationAtom atom, Constant[] args, int argIndex, double coefficientValue, Map<SummationVariable, Constant[]> subs, AtomManager atomManager, ResultList groundingResults, int groundingIndex, List<GroundAtom> groundAtoms, List<Double> coefficients) {
        if (argIndex == args.length) {
            // args is a shared buffer that the recursion keeps overwriting, so hand out a copy.
            Constant[] groundArgs = args.clone();
            if (atomManager.getDatabase().hasAtom((StandardPredicate) atom.getPredicate(), groundArgs)) {
                groundAtoms.add(atomManager.getAtom(atom.getPredicate(), groundArgs));
                coefficients.add(coefficientValue);
            }
            return;
        }

        SummationVariableOrTerm arg = atom.getArguments()[argIndex];
        if (arg instanceof Variable) {
            args[argIndex] = groundingResults.get(groundingIndex, (Variable) arg);
            instantiateSummationVariables(atom, args, argIndex + 1, coefficientValue, subs, atomManager, groundingResults, groundingIndex, groundAtoms, coefficients);
        } else if (arg instanceof Constant) {
            args[argIndex] = (Constant) arg;
            instantiateSummationVariables(atom, args, argIndex + 1, coefficientValue, subs, atomManager, groundingResults, groundingIndex, groundAtoms, coefficients);
        } else {
            // Otherwise the argument is a summation variable: try each of its substitutions.
            for (Constant sub : subs.get((SummationVariable) arg)) {
                args[argIndex] = sub;
                instantiateSummationVariables(atom, args, argIndex + 1, coefficientValue, subs, atomManager, groundingResults, groundingIndex, groundAtoms, coefficients);
            }
        }
    }

    /**
     * Wraps {@code subquery} in a query that groups by the expression's variables. Each summation
     * variable is string-aggregated, joined by {@link #DELIM}. Columns are placed at the indices
     * given by {@code projectionMap}.
     */
    private SelectQuery buildAggregateSummationQuery(Map<Variable, Integer> projectionMap, UnionQuery subquery, DatabaseDriver driver) {
        String[] columns = new String[projectionMap.size()];
        for (Variable var : expression.getVariables()) {
            columns[projectionMap.get(var)] = var.getName();
        }
        for (SummationVariable summationVar : expression.getSummationVariables()) {
            Variable var = summationVar.getVariable();
            String aggregate = driver.getStringAggregate(var.getName(), DELIM, true);
            columns[projectionMap.get(var)] = aggregate + " AS " + var.getName();
        }

        SelectQuery query = new SelectQuery();
        for (String column : columns) {
            query.addCustomColumns(new CustomSql(column));
        }
        query.addCustomFromTable(new Subquery(subquery).toString() + " X");
        for (Variable var : expression.getVariables()) {
            query.addCustomGroupings(var.getName());
        }
        return query;
    }

    /**
     * Builds the UNION of one body query per combination of filter disjuncts.
     *
     * @param outProjectionMap receives the variable-to-column-index map of the body query
     * @param outVarTypes receives the types of the body formula's variables
     */
    private UnionQuery buildCoreSummationQuery(RDBMSDatabase relationalDB, Map<Variable, Integer> outProjectionMap, VariableTypeMap outVarTypes) {
        Formula bodyFormula = expression.getQueryFormula();
        Set<Variable> projectionSet = bodyFormula.collectVariables(new VariableTypeMap()).keySet();

        // The plain body query is only built to obtain its projection map. getQuery() is called
        // before getProjectionMap(), presumably to populate it, so keep the call.
        Formula2SQL sqler = new Formula2SQL(projectionSet, relationalDB, false);
        sqler.getQuery(bodyFormula);
        Map<Variable, Integer> projectionMap = sqler.getProjectionMap();

        List<SelectQuery> queries = new ArrayList<>();
        collectFilterQueries(queries, projectionSet, relationalDB, bodyFormula, filters.values().toArray(new Formula[0]), 0, null);

        bodyFormula.collectVariables(outVarTypes);
        outProjectionMap.putAll(projectionMap);
        return new UnionQuery(SetOperationQuery.Type.UNION, queries.toArray(new SelectQuery[0]));
    }

    /**
     * Recursively expands the (DNF) filters into one query per combination of disjuncts.
     *
     * @param baseFormula body formula conjoined with the filter literals chosen so far (all positive)
     * @param appendedFormulas the chosen filter literals themselves, negations kept, or null if none
     *     yet; turned into value conditions by {@link #addFilterConditions}
     */
    private void collectFilterQueries(List<SelectQuery> queries, Set<Variable> projectionSet, RDBMSDatabase relationalDB, Formula baseFormula, Formula[] filterFormulas, int formulaIndex, Formula appendedFormulas) {
        if (formulaIndex == filterFormulas.length) {
            Formula2SQL sqler = new Formula2SQL(projectionSet, relationalDB, false);
            SelectQuery query = sqler.getQuery(baseFormula);
            if (appendedFormulas != null) {
                Map<Atom, String> tableAliases = sqler.getTableAliases();
                addFilterConditions(appendedFormulas, query, tableAliases);
            }
            queries.add(query);
            return;
        }

        Formula filter = filterFormulas[formulaIndex];
        if (filter instanceof Disjunction) {
            // One branch per disjunct; the final UNION realizes the disjunction.
            Disjunction disjunction = (Disjunction) filter;
            for (int i = 0; i < disjunction.length(); i++) {
                Formula disjunct = disjunction.get(i);
                Formula newBase = stableQueryConjunction(baseFormula, disjunct);
                Formula newAppended = (appendedFormulas == null) ? disjunct : filterConjunction(appendedFormulas, disjunct);
                collectFilterQueries(queries, projectionSet, relationalDB, newBase, filterFormulas, formulaIndex + 1, newAppended);
            }
        } else {
            Formula newBase = stableQueryConjunction(baseFormula, filter);
            Formula newAppended = (appendedFormulas == null) ? filter : filterConjunction(appendedFormulas, filter);
            collectFilterQueries(queries, projectionSet, relationalDB, newBase, filterFormulas, formulaIndex + 1, newAppended);
        }
    }

    /**
     * Conjoins two formulas for use as a SQL body query. Top-level conjunctions are flattened
     * with their components kept in order. Every negated conjunct is replaced by its positive
     * formula: the query must join each filter atom's table, and the negation itself is enforced
     * by {@link #addFilterConditions} as a {@code value = 0} condition.
     */
    private Formula stableQueryConjunction(Formula a, Formula b) {
        List<Formula> components = new ArrayList<>();
        appendConjuncts(a, components, true);
        appendConjuncts(b, components, true);
        return new Conjunction(components.toArray(new Formula[0]));
    }

    /**
     * Conjoins two filter formulas for {@link #addFilterConditions}. Top-level conjunctions are
     * flattened, but negations are kept so they still become {@code value = 0} conditions.
     */
    private Formula filterConjunction(Formula a, Formula b) {
        List<Formula> components = new ArrayList<>();
        appendConjuncts(a, components, false);
        appendConjuncts(b, components, false);
        return new Conjunction(components.toArray(new Formula[0]));
    }

    /**
     * Appends the conjuncts of {@code formula} (the formula itself if it is not a conjunction) to
     * {@code out}. If {@code positive} is true, negated conjuncts are replaced by their inner formula.
     */
    private static void appendConjuncts(Formula formula, List<Formula> out, boolean positive) {
        if (formula instanceof Conjunction) {
            Conjunction conjunction = (Conjunction) formula;
            for (int i = 0; i < conjunction.length(); i++) {
                out.add(positive ? positiveLiteral(conjunction.get(i)) : conjunction.get(i));
            }
        } else {
            out.add(positive ? positiveLiteral(formula) : formula);
        }
    }

    /** Returns the inner formula of a negation, or the formula itself otherwise. */
    private static Formula positiveLiteral(Formula formula) {
        return (formula instanceof Negation) ? ((Negation) formula).getFormula() : formula;
    }

    /**
     * Adds WHERE conditions for a filter built from atoms, negated atoms and conjunctions.
     * An atom requires {@code value > 0}; a negated atom requires {@code value = 0}.
     *
     * @throws IllegalStateException on any other formula type
     */
    private void addFilterConditions(Formula filterFormula, SelectQuery query, Map<Atom, String> tableAliases) {
        if (filterFormula instanceof Atom) {
            CustomSql valueColumn = new CustomSql(tableAliases.get(filterFormula) + "." + "value");
            query.addCondition(BinaryCondition.greaterThan(valueColumn, 0.0));
        } else if (filterFormula instanceof Negation) {
            Atom atom = (Atom) ((Negation) filterFormula).getFormula();
            CustomSql valueColumn = new CustomSql(tableAliases.get(atom) + "." + "value");
            query.addCondition(BinaryCondition.equalTo(valueColumn, 0.0));
        } else if (filterFormula instanceof Conjunction) {
            Conjunction conjunction = (Conjunction) filterFormula;
            for (int i = 0; i < conjunction.length(); i++) {
                addFilterConditions(conjunction.get(i), query, tableAliases);
            }
        } else {
            throw new IllegalStateException("Unexpected formula type: " + filterFormula.getClass().getName());
        }
    }

    /**
     * Checks that every filter argument is a summation variable of the expression. Also checks
     * that every variable inside a filter is either that filter's argument or an expression
     * variable.
     *
     * @throws IllegalArgumentException if either check fails
     */
    private void validateRule() {
        for (SummationVariable filterArg : filters.keySet()) {
            if (!expression.getSummationVariables().contains(filterArg)) {
                throw new IllegalArgumentException(String.format("Unknown variable (%s) used as filter argument. All filter arguments must appear as summation variables in associated arithmetic expression.", filterArg.getVariable().getName()));
            }
        }

        Set<String> expressionVariableNames = new HashSet<>();
        for (Variable variable : expression.getVariables()) {
            expressionVariableNames.add(variable.getName());
        }

        for (Map.Entry<SummationVariable, Formula> entry : filters.entrySet()) {
            String filterArgName = entry.getKey().getVariable().getName();
            VariableTypeMap filterVars = new VariableTypeMap();
            entry.getValue().collectVariables(filterVars);
            for (Variable var : filterVars.keySet()) {
                if (!filterArgName.equals(var.getName()) && !expressionVariableNames.contains(var.getName())) {
                    throw new IllegalArgumentException(String.format("Unknown variable (%s) used in filter. All filter variables must either be the filter argument or appear in the associated arithmetic expression.", var.getName()));
                }
            }
        }
    }

    /**
     * Checks that no filter uses an open (non-closed) standard predicate.
     *
     * @throws IllegalArgumentException if one does
     */
    public void validateGroundRule(AtomManager atomManager) {
        Set<Atom> filterAtoms = new HashSet<>();
        for (Formula filter : filters.values()) {
            filter.getAtoms(filterAtoms);
        }

        for (Atom filterAtom : filterAtoms) {
            Predicate predicate = filterAtom.getPredicate();
            if (predicate instanceof StandardPredicate && !atomManager.isClosed((StandardPredicate) predicate)) {
                throw new IllegalArgumentException(String.format("Open predicate (%s) not allowed in filter. Only closed predicates may appear in filters.", predicate.getName()));
            }
        }
    }

    /**
     * Creates one ground rule. Callers reuse {@code atoms} across groundings, so implementations
     * must copy the array if they retain it.
     */
    protected abstract AbstractGroundArithmeticRule makeGroundRule(double[] coefficients, GroundAtom[] atoms, FunctionComparator comparator, double constant);

    /**
     * Creates one ground rule. Callers clear and refill both lists across groundings, so
     * implementations must copy them if they retain them.
     */
    protected abstract AbstractGroundArithmeticRule makeGroundRule(List<Double> coefficients, List<GroundAtom> atoms, FunctionComparator comparator, double constant);

    /**
     * Sets the delimiter used to join summation-variable substitutions in the SQL aggregate and to
     * split them again. It is matched literally. This is static: it affects every arithmetic rule.
     */
    public static void setDelim(String delim) {
        DELIM = delim;
    }

    @Override
    public int hashCode() {
        return expression.hashCode();
    }

    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof AbstractArithmeticRule)) {
            return false;
        }
        AbstractArithmeticRule otherRule = (AbstractArithmeticRule) other;
        return filters.equals(otherRule.filters) && expression.equals(otherRule.expression);
    }
}

