/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.h2o;

import com.google.common.collect.Iterables;
import hex.genmodel.algos.tree.NaSplitDir;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.utils.ByteBufferWrapper;
import hex.genmodel.utils.GenmodelBitSet;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.tree.CountingBranchNode;
import org.dmg.pmml.tree.CountingLeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.h2o.Converter;
import org.jpmml.h2o.SharedTree;

/**
 * Base converter that turns the compressed, byte-encoded trees of an H2O
 * {@link SharedTreeMojoModel} into PMML {@link TreeModel} elements.
 *
 * <p>The tree arrays and group counts are read reflectively from fields declared on
 * {@link SharedTreeMojoModel}; see the {@code FIELD_*} constants.
 *
 * @param <M> the concrete MOJO model type handled by the subclass
 */
public abstract class SharedTreeMojoModelConverter<M extends SharedTreeMojoModel> extends Converter<M> {
    private static final Field FIELD_COMPRESSEDTREES;
    private static final Field FIELD_COMPRESSEDTREESAUX;
    private static final Field FIELD_NTREEGROUPS;
    private static final Field FIELD_NTREESPERGROUP;

    /**
     * @param model the MOJO model to convert
     */
    public SharedTreeMojoModelConverter(M model) {
        super(model);
    }

    /**
     * Converts every compressed tree of the model into a PMML regression {@link TreeModel}.
     *
     * <p>A single {@link PredicateManager} is shared by all trees, while node ids are numbered
     * independently per tree.
     *
     * @param schema the schema whose features are looked up by the column ids stored in the trees
     * @return one tree model per compressed tree, in array order
     * @throws IllegalArgumentException if the MOJO version is older than 1.2, or if the number of
     *         compressed trees differs from the number of auxiliary tree blobs
     */
    public List<TreeModel> encodeTreeModels(Schema schema) {
        SharedTreeMojoModel model = getModel();
        if (model._mojo_version < 1.2) {
            throw new IllegalArgumentException("Version " + model._mojo_version + " is not supported");
        }
        byte[][] compressedTrees = getCompressedTrees(model);
        byte[][] compressedTreesAux = getCompressedTreesAux(model);
        // Both arrays are indexed in lockstep below; report a mismatch up front instead of
        // failing with an ArrayIndexOutOfBoundsException half-way through the conversion.
        if (compressedTrees.length != compressedTreesAux.length) {
            throw new IllegalArgumentException("Expected " + compressedTrees.length + " auxiliary tree(s), got " + compressedTreesAux.length);
        }
        PredicateManager predicateManager = new PredicateManager();
        List<TreeModel> result = new ArrayList<>(compressedTrees.length);
        for (int i = 0; i < compressedTrees.length; i++) {
            final byte[] compressedTree = compressedTrees[i];
            final byte[] compressedTreeAux = compressedTreesAux[i];
            final Map<Integer, SharedTreeMojoModel.AuxInfo> auxInfos = SharedTreeMojoModel.readAuxInfos(compressedTreeAux);
            SharedTree sharedTree = new SharedTree() {
                // Fallback id source, used by encodeNode only when a node has no aux info.
                private final AtomicInteger idSequence = new AtomicInteger(0);

                @Override
                public byte[] getCompressedTree() {
                    return compressedTree;
                }

                @Override
                public byte[] getCompressedTreeAux() {
                    return compressedTreeAux;
                }

                @Override
                public Integer nextId() {
                    return this.idSequence.getAndIncrement();
                }

                @Override
                public SharedTreeMojoModel.AuxInfo getAuxInfo(int id) {
                    return auxInfos.get(id);
                }

                @Override
                public void encodeAuxInfo(Node node, double score, double recordCount) {
                    SharedTreeMojoModelConverter.this.ensureScore(node, score);
                    SharedTreeMojoModelConverter.this.ensureRecordCount(node, recordCount);
                }
            };
            result.add(encodeTreeModel(sharedTree, predicateManager, schema));
        }
        return result;
    }

    /**
     * Sets the score of a node, or verifies it if the node already carries one.
     *
     * @param node the node to update or check
     * @param score the score taken from the auxiliary tree info
     * @throws IllegalArgumentException if the node already has a score that is not equal to
     *         {@code score} (compared as a boxed {@code Double})
     */
    protected void ensureScore(Node node, double score) {
        if (!node.hasScore()) {
            node.setScore(score);
            return;
        }
        if (!Objects.equals(node.getScore(), score)) {
            throw new IllegalArgumentException("Expected score " + node.getScore() + ", got " + score);
        }
    }

    /**
     * Sets the record count of a node that does not have one yet.
     *
     * @param node the node to update
     * @param recordCount the record count taken from the auxiliary tree info
     * @throws IllegalArgumentException if the node already has a record count
     */
    protected void ensureRecordCount(Node node, double recordCount) {
        if (node.getRecordCount() != null) {
            throw new IllegalArgumentException("Node already has record count " + node.getRecordCount());
        }
        node.setRecordCount(ValueUtil.narrow(recordCount));
    }

    /**
     * Encodes a single compressed tree as a PMML regression tree with a continuous double label
     * and the {@code defaultChild} missing value strategy.
     *
     * @param sharedTree accessor for the compressed tree bytes, aux infos and node id sequence
     * @param predicateManager factory for the split predicates
     * @param schema the schema whose features are looked up by column id
     * @return the encoded tree model
     */
    public static TreeModel encodeTreeModel(SharedTree sharedTree, PredicateManager predicateManager, Schema schema) {
        Label label = new ContinuousLabel(DataType.DOUBLE);
        Node root = encodeNode(sharedTree, null, sharedTree.nextId(), True.INSTANCE, new CategoryManager(), predicateManager, schema);
        return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(label), root)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
    }

    /**
     * Recursively decodes the node at the current position of {@code byteBuffer} together with
     * its whole subtree.
     *
     * <p>The node header is read as: one unsigned byte of node type flags, a two-byte column id,
     * and, unless the column id is {@code 0xFFFF} (in which case the node itself is a leaf), one
     * unsigned byte holding the {@link NaSplitDir} value. NOTE(review): the bit masks below mirror
     * what H2O's own tree scoring code appears to use; confirm against {@code SharedTreeMojoModel}
     * when upgrading h2o-genmodel.
     *
     * @param sharedTree accessor for the compressed tree bytes, aux infos and node id sequence
     * @param byteBuffer reader positioned at the node, or {@code null} to start at the tree root
     * @param id the id of the node being decoded
     * @param predicate the predicate that selects this node within its parent
     * @param categoryManager the categorical values still reachable at this node
     * @param predicateManager factory for the split predicates
     * @param schema the schema whose features are looked up by column id
     * @return the decoded node
     * @throws IllegalArgumentException if the node type flags describe an unsupported encoding
     */
    public static Node encodeNode(SharedTree sharedTree, ByteBufferWrapper byteBuffer, Integer id, Predicate predicate, CategoryManager categoryManager, PredicateManager predicateManager, Schema schema) {
        byte[] compressedTree = sharedTree.getCompressedTree();
        if (byteBuffer == null) {
            byteBuffer = new ByteBufferWrapper(compressedTree);
        }
        SharedTreeMojoModel.AuxInfo auxInfo = sharedTree.getAuxInfo(id);
        boolean hasAux = auxInfo != null;

        int nodeType = byteBuffer.get1U();
        int lmask = nodeType & 0x33;
        int lmask2 = (nodeType & 0xC0) >> 2;
        int equal = nodeType & 0xC;
        char colId = byteBuffer.get2();
        if (colId == '\uffff') {
            // No split column: the node itself is a leaf carrying a float score.
            return createLeafNode(byteBuffer.get4f(), predicate, toNodeId(hasAux, id));
        }

        int naSplitDir = byteBuffer.get1U();
        boolean naVsRest = naSplitDir == NaSplitDir.NAvsREST.value();
        boolean leftward = naSplitDir == NaSplitDir.NALeft.value() || naSplitDir == NaSplitDir.Left.value();
        Feature feature = schema.getFeature((int)colId);

        Predicate leftPredicate;
        Predicate rightPredicate;
        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;
        if (naVsRest) {
            // Split on missingness only; no split value follows in the byte stream.
            leftPredicate = predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.IS_NOT_MISSING, null);
            rightPredicate = predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);
        } else if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
            String name = categoricalFeature.getName();
            List<?> values = categoricalFeature.getValues();
            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
            List<Object> leftValues = new ArrayList<>();
            List<Object> rightValues = new ArrayList<>();
            if (equal != 0) {
                // Bit set split: the set is read from the byte stream in one of two layouts.
                GenmodelBitSet bitSet = new GenmodelBitSet(0);
                if (equal == 8) {
                    bitSet.fill2(compressedTree, byteBuffer);
                } else if (equal == 12) {
                    bitSet.fill3(compressedTree, byteBuffer);
                } else {
                    throw new IllegalArgumentException("Node type " + equal + " is not supported");
                }
                for (int i = 0; i < values.size(); i++) {
                    Object value = values.get(i);
                    if (!valueFilter.test(value)) {
                        continue;
                    }
                    // Indices outside the bit set's range follow the NA split direction.
                    boolean goesLeft = bitSet.isInRange(i) ? !bitSet.contains(i) : leftward;
                    (goesLeft ? leftValues : rightValues).add(value);
                }
            } else {
                // Primitive double on purpose: a float cannot be assigned to a Double directly.
                double splitVal = byteBuffer.get4f();
                for (int i = 0; i < values.size(); i++) {
                    Object value = values.get(i);
                    if (!valueFilter.test(value)) {
                        continue;
                    }
                    // The category index is compared against the split value.
                    (i < splitVal ? leftValues : rightValues).add(value);
                }
            }
            leftCategoryManager = categoryManager.fork(name, leftValues);
            rightCategoryManager = categoryManager.fork(name, rightValues);
            leftPredicate = predicateManager.createPredicate(categoricalFeature, leftValues);
            rightPredicate = predicateManager.createPredicate(categoricalFeature, rightValues);
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();
            // Widened from float; boxed to Double when handed to the predicate factory.
            double splitVal = byteBuffer.get4f();
            leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, splitVal);
            rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, splitVal);
        }

        // Left child. Ids are allocated in pre-order, so the left id (and the ids of the whole
        // left subtree) must be taken before the right id.
        Integer leftId = hasAux ? auxInfo.nidL : sharedTree.nextId();
        ByteBufferWrapper leftByteBuffer = new ByteBufferWrapper(compressedTree);
        leftByteBuffer.skip(byteBuffer.position());
        if (lmask <= 3) {
            // Step over the (lmask + 1)-byte field that precedes the left subtree.
            leftByteBuffer.skip(lmask + 1);
        }
        Node leftChild;
        if ((lmask & 0x10) != 0) {
            leftChild = createLeafNode(leftByteBuffer.get4f(), leftPredicate, toNodeId(hasAux, leftId));
        } else {
            leftChild = encodeNode(sharedTree, leftByteBuffer, leftId, leftPredicate, leftCategoryManager, predicateManager, schema);
        }

        // Right child.
        Integer rightId = hasAux ? auxInfo.nidR : sharedTree.nextId();
        ByteBufferWrapper rightByteBuffer = new ByteBufferWrapper(compressedTree);
        rightByteBuffer.skip(byteBuffer.position());
        skipLeftSubtree(rightByteBuffer, lmask);
        Node rightChild;
        if ((lmask2 & 0x10) != 0) {
            rightChild = createLeafNode(rightByteBuffer.get4f(), rightPredicate, toNodeId(hasAux, rightId));
        } else {
            rightChild = encodeNode(sharedTree, rightByteBuffer, rightId, rightPredicate, rightCategoryManager, predicateManager, schema);
        }

        if (hasAux) {
            sharedTree.encodeAuxInfo(leftChild, auxInfo.predL, auxInfo.weightL);
            sharedTree.encodeAuxInfo(rightChild, auxInfo.predR, auxInfo.weightR);
        }
        Node result = new CountingBranchNode(null, predicate)
            .setId(toNodeId(hasAux, id))
            .setDefaultChild(leftward ? leftChild.getId() : rightChild.getId())
            .addNodes(leftChild, rightChild);
        if (hasAux && id == 0) {
            // Node 0 has no parent to supply its statistics, so derive them from its children:
            // the summed weight and the weight-averaged prediction.
            float weight = auxInfo.weightL + auxInfo.weightR;
            sharedTree.encodeAuxInfo(result, (auxInfo.predL * auxInfo.weightL + auxInfo.predR * auxInfo.weightR) / weight, weight);
        }
        return result;
    }

    /**
     * Returns the only tree model as-is, or combines the tree models using the given function.
     *
     * @param treeModels the member tree models
     * @param ensembleFunction builds the ensemble; invoked unless there is exactly one tree model
     * @return the single tree model, or the result of {@code ensembleFunction}
     */
    public static Model encodeTreeEnsemble(List<TreeModel> treeModels, Function<List<TreeModel>, MiningModel> ensembleFunction) {
        if (treeModels.size() == 1) {
            return Iterables.getOnlyElement(treeModels);
        }
        return ensembleFunction.apply(treeModels);
    }

    /**
     * @param model the model to inspect
     * @return the value of the model's {@code _compressed_trees} field
     */
    public static byte[][] getCompressedTrees(SharedTreeMojoModel model) {
        return (byte[][])getFieldValue(FIELD_COMPRESSEDTREES, model);
    }

    /**
     * @param model the model to inspect
     * @return the value of the model's {@code _compressed_trees_aux} field
     */
    public static byte[][] getCompressedTreesAux(SharedTreeMojoModel model) {
        return (byte[][])getFieldValue(FIELD_COMPRESSEDTREESAUX, model);
    }

    /**
     * @param model the model to inspect
     * @return the value of the model's {@code _ntree_groups} field
     */
    public static int getNTreeGroups(SharedTreeMojoModel model) {
        return (Integer)getFieldValue(FIELD_NTREEGROUPS, model);
    }

    /**
     * @param model the model to inspect
     * @return the value of the model's {@code _ntrees_per_group} field
     */
    public static int getNTreesPerGroup(SharedTreeMojoModel model) {
        return (Integer)getFieldValue(FIELD_NTREESPERGROUP, model);
    }

    /** Creates a leaf node with the given score (boxed as {@code Double}), predicate and id. */
    private static Node createLeafNode(double score, Predicate predicate, Integer nodeId) {
        return new CountingLeafNode(score, predicate).setId(nodeId);
    }

    /**
     * Advances a reader positioned just after a split's header past the left child, so that it
     * ends up at the right child.
     *
     * <p>For {@code lmask} 0..3 a 1..4 byte value is read and that many bytes are skipped; for
     * {@code lmask} 48 exactly four bytes (the left leaf's float score) are skipped.
     *
     * @throws IllegalArgumentException for any other {@code lmask} value
     */
    private static void skipLeftSubtree(ByteBufferWrapper byteBuffer, int lmask) {
        switch (lmask) {
            case 0:
                byteBuffer.skip(byteBuffer.get1U());
                break;
            case 1:
                byteBuffer.skip((int)byteBuffer.get2());
                break;
            case 2:
                byteBuffer.skip(byteBuffer.get3());
                break;
            case 3:
                byteBuffer.skip(byteBuffer.get4());
                break;
            case 48:
                byteBuffer.skip(4);
                break;
            default:
                throw new IllegalArgumentException("Node type " + lmask + " is not supported");
        }
    }

    /**
     * Maps an internal node id to the id written to PMML: ids that come from aux info are used
     * verbatim, whereas ids from the zero-based fallback sequence are shifted to start at one.
     */
    private static Integer toNodeId(boolean hasAux, Integer id) {
        return hasAux ? id : id + 1;
    }

    static {
        try {
            FIELD_COMPRESSEDTREES = SharedTreeMojoModel.class.getDeclaredField("_compressed_trees");
            FIELD_COMPRESSEDTREESAUX = SharedTreeMojoModel.class.getDeclaredField("_compressed_trees_aux");
            FIELD_NTREEGROUPS = SharedTreeMojoModel.class.getDeclaredField("_ntree_groups");
            FIELD_NTREESPERGROUP = SharedTreeMojoModel.class.getDeclaredField("_ntrees_per_group");
        }
        catch (ReflectiveOperationException roe) {
            // A missing field means an incompatible h2o-genmodel version; fail class loading.
            throw new RuntimeException(roe);
        }
    }
}

