/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.learning.results.metrics;

import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.learning.results.Result;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.generic.tuples.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class HITS {
    private static final Logger LOG = Logger.getLogger(HITS.class.getName());
    private final Random random;
    boolean hitsReifyPredicate;
    Settings.HitsCorruption corruption;
    Settings.HitsPreservation hitsPreservation;
    Settings.HitsClashes hitsClashes;
    LinkedHashSet<String> validSamples;
    LinkedHashSet<String> corruptedSamples;
    Map<String, String[]> samples2terms;
    int maxArity = -1;
    int keepFixedIndex = -1;
    boolean storeCorruptions;
    Map<String, List<String>[]> storedCorruptions;
    Map<String, LinkedHashSet<String>>[] sameTermSamples;
    List<String[]> corruptedTerms;
    Map<String[], String> terms2sample;

    public HITS(List<Result> results, Settings settings) {
        this.random = settings.random;
        this.hitsReifyPredicate = settings.hitsReifyPredicate;
        this.corruption = settings.hitsCorruption;
        this.hitsPreservation = settings.hitsPreservation;
        this.hitsClashes = settings.hitsClashes;
        this.storeCorruptions = settings.storeHitsCorruptions;
        if (this.storeCorruptions) {
            this.storedCorruptions = new HashMap<String, List<String>[]>();
        }
        switch (this.hitsPreservation) {
            case FIRST_STAYS: {
                this.keepFixedIndex = 0;
                break;
            }
            case MIDDLE_STAYS: {
                this.keepFixedIndex = 1;
                break;
            }
            case NONE: {
                this.keepFixedIndex = -1;
            }
        }
        this.samples2terms = this.samples2terms(results);
        Pair<LinkedHashSet<String>, LinkedHashSet<String>> validAndCorrupted = this.validAndCorrupted(results);
        this.validSamples = (LinkedHashSet)validAndCorrupted.r;
        this.corruptedSamples = (LinkedHashSet)validAndCorrupted.s;
        if (this.corruption == Settings.HitsCorruption.ONE_SAME || this.corruption == Settings.HitsCorruption.ALL_DIFF && this.keepFixedIndex >= 0) {
            this.sameTermSamples = this.getDatabase(this.corruptedSamples);
        } else if (this.corruption == Settings.HitsCorruption.ONE_DIFF) {
            this.corruptedTerms = this.corruptedSamples.stream().map(this.samples2terms::get).collect(Collectors.toList());
            this.terms2sample = this.samples2terms.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
        }
    }

    public void mergeWith(HITS other) {
        this.samples2terms.putAll(other.samples2terms);
        this.validSamples.addAll(other.validSamples);
        this.corruptedSamples.addAll(other.corruptedSamples);
        if (this.sameTermSamples != null) {
            for (int i = 0; i < this.sameTermSamples.length; ++i) {
                this.sameTermSamples[i].putAll(other.sameTermSamples[i]);
            }
        } else if (this.corruptedTerms != null) {
            this.corruptedTerms.addAll(other.corruptedTerms);
        }
    }

    public Stats getStats(List<Result> results) {
        Map<String, Value> predictions = this.getPredictions(results);
        Stats stats = new Stats();
        for (String validSample : this.validSamples) {
            List<String>[] corruptions;
            Value predictedValue = predictions.get(validSample);
            for (List<String> corruptionsAtIndex : corruptions = this.getCorruptions(validSample)) {
                if (corruptionsAtIndex == null) continue;
                double rank = 1.0;
                if (!corruptionsAtIndex.isEmpty()) {
                    List<Value> corruptedValues = corruptionsAtIndex.stream().map(predictions::get).collect(Collectors.toList());
                    rank = this.getRank(predictedValue, corruptedValues);
                }
                stats.consume(rank);
            }
        }
        return stats.finish();
    }

    private Map<String, Value> getPredictions(List<Result> results) {
        return results.stream().collect(Collectors.toMap(result -> result.sampleId, Result::getOutput));
    }

    private List<String>[] getCorruptions(String etalon) {
        List<String>[] corruptions;
        if (this.storedCorruptions != null && (corruptions = this.storedCorruptions.get(etalon)) != null) {
            return corruptions;
        }
        switch (this.corruption) {
            case ONE_SAME: {
                return this.corruptAtLeastOneSame(etalon);
            }
            case ONE_DIFF: {
                return this.corruptExactlyOneDifferent(etalon);
            }
            case ALL_DIFF: {
                return this.corruptAnything(etalon);
            }
        }
        throw new RuntimeException("Unknown HITS corruption definition");
    }

    public double getRank(Value predictedValue, List<Value> corruptionsAtIndex) {
        int same = 0;
        double rank = 1.0;
        if (corruptionsAtIndex != null) {
            for (Value corrupt : corruptionsAtIndex) {
                if (corrupt.greaterThan(predictedValue)) {
                    rank += 1.0;
                    continue;
                }
                if (predictedValue.greaterThan(corrupt)) continue;
                ++same;
            }
        }
        switch (this.hitsClashes) {
            case AVG: {
                return rank + (double)same / 2.0;
            }
            case NONE: {
                return rank;
            }
            case RANDOM: {
                return rank + (double)this.random.nextInt(same + 1);
            }
        }
        return rank;
    }

    private Map<String, String[]> samples2terms(List<Result> results) {
        HashMap<String, String[]> samples2terms = new HashMap<String, String[]>();
        for (Result result : results) {
            int separator = -1;
            if (this.hitsReifyPredicate) {
                if (this.hitsPreservation == Settings.HitsPreservation.MIDDLE_STAYS) {
                    LOG.warning("Including predicate, but also probably assuming it is reified in the middle of the terms.");
                }
                separator = result.sampleId.indexOf(":") + 1;
            } else {
                separator = result.sampleId.indexOf("(") + 1;
            }
            String[] terms = result.sampleId.substring(separator > 0 ? separator : 0).replace("(", ",").replace(")", "").replace(" ", "").split(",");
            if (this.hitsPreservation == Settings.HitsPreservation.MIDDLE_STAYS && terms.length != 3) {
                LOG.severe("trying to calculate predicate-in-the-middle HITs from literal of length: " + terms.length);
            }
            samples2terms.put(result.sampleId, terms);
            if (terms.length <= this.maxArity) continue;
            this.maxArity = terms.length;
        }
        return samples2terms;
    }

    public Pair<LinkedHashSet<String>, LinkedHashSet<String>> validAndCorrupted(List<Result> results) {
        LinkedHashSet<String> corrupted = new LinkedHashSet<String>();
        LinkedHashSet<String> valid = new LinkedHashSet<String>();
        for (Result result : results) {
            if (result.getTarget().greaterThan(Value.ZERO)) {
                valid.add(result.sampleId);
                continue;
            }
            corrupted.add(result.sampleId);
        }
        return new Pair<LinkedHashSet<String>, LinkedHashSet<String>>(valid, corrupted);
    }

    private List<String>[] corruptExactlyOneDifferent(String etalon) {
        String[] etalonTerms = this.samples2terms.get(etalon);
        ArrayList[] corruptions = new ArrayList[etalonTerms.length];
        List<String>[] leftSeries = this.precalculateIntersections(etalonTerms, true);
        List<String>[] rightSeries = this.precalculateIntersections(etalonTerms, false);
        for (int i = 0; i < etalonTerms.length; ++i) {
            if (i == this.keepFixedIndex) continue;
            if (i - 1 >= 0) {
                corruptions[i] = new ArrayList<String>(leftSeries[i - 1]);
            }
            if (i + 1 >= etalonTerms.length) continue;
            if (corruptions[i] != null) {
                corruptions[i].retainAll(rightSeries[i + 1]);
                continue;
            }
            corruptions[i] = new ArrayList<String>(rightSeries[i + 1]);
        }
        if (this.storeCorruptions) {
            this.storedCorruptions.put(etalon, corruptions);
        }
        return corruptions;
    }

    private List<String>[] precalculateIntersections(String[] etalon, boolean left) {
        IntStream intStream = left ? IntStream.rangeClosed(0, etalon.length - 1) : IntStream.rangeClosed(0, etalon.length - 1).map(i -> etalon.length - 1 - i);
        List[] series = new ArrayList[etalon.length];
        List[] previous = new ArrayList[]{this.corruptedTerms};
        intStream.forEach(i -> {
            if (previous[0] != null) {
                series[i] = previous[0].stream().filter(terms -> etalon[i].equals(terms[i])).collect(Collectors.toList());
            }
            previous[0] = series[i];
        });
        ArrayList[] corruptions = new ArrayList[etalon.length];
        for (int i2 = 0; i2 < corruptions.length; ++i2) {
            corruptions[i2] = series[i2].stream().map(this.terms2sample::get).collect(Collectors.toList());
        }
        return corruptions;
    }

    private List<String>[] corruptAtLeastOneSame(String etalon) {
        String[] etalonTerms = this.samples2terms.get(etalon);
        ArrayList[] corruptions = new ArrayList[etalonTerms.length];
        for (int i = 0; i < etalonTerms.length; ++i) {
            if (i == this.keepFixedIndex) continue;
            LinkedHashSet<String> corruptsWithSameTerm = this.sameTermSamples[i].get(etalonTerms[i]);
            if (this.keepFixedIndex >= 0) {
                LinkedHashSet<String> fixedSubset = this.sameTermSamples[this.keepFixedIndex].get(etalonTerms[this.keepFixedIndex]);
                if (fixedSubset == null) {
                    corruptions[i] = new ArrayList();
                    continue;
                }
                LinkedHashSet<String> corruptedSubSelection = new LinkedHashSet<String>(fixedSubset);
                corruptedSubSelection.retainAll(corruptsWithSameTerm);
                corruptions[i] = new ArrayList<String>(corruptedSubSelection);
                continue;
            }
            corruptions[i] = corruptsWithSameTerm != null ? new ArrayList<String>(corruptsWithSameTerm) : new ArrayList();
        }
        if (this.storeCorruptions) {
            this.storedCorruptions.put(etalon, corruptions);
        }
        return corruptions;
    }

    private Map<String, LinkedHashSet<String>>[] getDatabase(LinkedHashSet<String> corrupted) {
        HashMap[] corruptedDatabase = new HashMap[this.maxArity];
        for (int i = 0; i < corruptedDatabase.length; ++i) {
            corruptedDatabase[i] = new HashMap();
        }
        for (String sample : corrupted) {
            String[] terms = this.samples2terms.get(sample);
            for (int i = 0; i < terms.length; ++i) {
                LinkedHashSet<String> otherSamplesWithTermAtIndex = (LinkedHashSet<String>)corruptedDatabase[i].get(terms[i]);
                if (otherSamplesWithTermAtIndex != null) {
                    otherSamplesWithTermAtIndex.add(sample);
                    continue;
                }
                otherSamplesWithTermAtIndex = new LinkedHashSet<String>();
                otherSamplesWithTermAtIndex.add(sample);
                corruptedDatabase[i].put(terms[i], otherSamplesWithTermAtIndex);
            }
        }
        return corruptedDatabase;
    }

    private List<String>[] corruptAnything(String etalon) {
        if (this.storedCorruptions != null && !this.storedCorruptions.isEmpty()) {
            return this.storedCorruptions.get("");
        }
        if (this.keepFixedIndex >= 0) {
            String[] etalonTerms = this.samples2terms.get(etalon);
            LinkedHashSet<String> sameFixedCorruptions = this.sameTermSamples[this.keepFixedIndex].get(etalonTerms[this.keepFixedIndex]);
            if (sameFixedCorruptions != null) {
                return new List[]{new ArrayList<String>(sameFixedCorruptions)};
            }
            return new List[]{new ArrayList()};
        }
        ArrayList[] corruptions = new ArrayList[]{new ArrayList<String>(this.corruptedSamples)};
        if (this.storeCorruptions) {
            this.storedCorruptions.put("", corruptions);
        }
        return corruptions;
    }

    public static class Stats {
        double MRR;
        double AVGrank;
        int[] HITSindices = new int[]{1, 3, 5, 10};
        double[] HITSresults = new double[this.HITSindices.length];
        private int counter = 0;
        private boolean finalized = false;

        public Stats() {
        }

        public Stats(int[] HITSindices) {
            this.HITSindices = HITSindices;
            this.HITSresults = new double[HITSindices.length];
        }

        public void consume(double rank) {
            this.AVGrank += rank;
            this.MRR += 1.0 / rank;
            for (int i = 0; i < this.HITSindices.length; ++i) {
                if (!(rank <= (double)this.HITSindices[i])) continue;
                int n = i;
                this.HITSresults[n] = this.HITSresults[n] + 1.0;
            }
            ++this.counter;
        }

        public Stats finish() {
            if (this.finalized) {
                return this;
            }
            this.AVGrank /= (double)this.counter;
            this.MRR /= (double)this.counter;
            int i = 0;
            while (i < this.HITSindices.length) {
                int n = i++;
                this.HITSresults[n] = this.HITSresults[n] / (double)this.counter;
            }
            this.finalized = true;
            return this;
        }

        public String toString() {
            if (!this.finalized) {
                return "stats are not finalized yet!";
            }
            StringBuilder sb = new StringBuilder();
            sb.append("MRR=").append(Settings.shortNumberFormat.format(this.MRR));
            sb.append(", MeanRank=").append(Settings.shortNumberFormat.format(this.AVGrank));
            sb.append(", HITS at ").append(Arrays.toString(this.HITSindices)).append("=(");
            for (double hitSresult : this.HITSresults) {
                sb.append(Settings.shortNumberFormat.format(hitSresult)).append(",");
            }
            sb.replace(sb.length() - 1, sb.length(), ")");
            return sb.toString();
        }
    }
}

