/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.learning.results;

import cz.cvut.fel.ida.algebra.functions.ElementWise;
import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.utils.MathUtils;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.VectorValue;
import cz.cvut.fel.ida.learning.crossvalidation.MeanStdResults;
import cz.cvut.fel.ida.learning.results.RegressionResults;
import cz.cvut.fel.ida.learning.results.Result;
import cz.cvut.fel.ida.setup.Settings;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * Evaluation summary for classification tasks.
 *
 * <p>Extends the regression summary (aggregated error) with accuracy, majority-class accuracy and
 * a "dispersion" score describing how well the outputs separate the target classes.
 *
 * <p>Scalar targets are treated as binary labels: a target greater than 0.5 is class 1, anything
 * else is class 0. Vector targets are treated as multi-class labels, where the true class is the
 * argmax of the target vector and the predicted class is the argmax of the output vector.
 */
public class ClassificationResults extends RegressionResults {
    private static final Logger LOG = Logger.getLogger(ClassificationResults.class.getName());

    /** Threshold used both to binarize scalar targets and to binarize scalar predictions. */
    static Value oneHalf = new ScalarValue(0.5);

    /** Fraction of evaluations classified correctly (set by {@link #recalculate()}). */
    public Double accuracy;
    /** Accuracy obtained by always predicting the most frequent target class. */
    public Double majorityAcc;
    /** Class-separation score; see {@code loadBinaryMetrics} / {@code loadMulticlassMetrics}. */
    public Double dispersion;
    /** Number of correctly classified evaluations found by the last recalculation. */
    private int goodCount;
    /** Number of scalar targets that are not greater than 0.5 (binary class 0). */
    protected int zeroCount;
    /** Number of scalar targets greater than 0.5 (binary class 1). */
    protected int oneCount;

    /**
     * Creates the results for the given evaluated samples.
     *
     * @param outputs  evaluated samples (each carrying a target and an output)
     * @param settings settings; {@code squishLastLayer} is read when the metrics are computed
     */
    public ClassificationResults(List<Result> outputs, Settings settings) {
        super(outputs, settings);
    }

    /**
     * Creates a result holder directly from precomputed metric values, as done by
     * {@link #aggregateClassifications(List)} for the mean and standard-deviation summaries.
     */
    protected ClassificationResults(Value error, Double accuracy, Double majorityAcc, Double dispersion) {
        super(error);
        this.accuracy = accuracy;
        this.majorityAcc = majorityAcc;
        this.dispersion = dispersion;
    }

    /**
     * Recomputes the aggregated error and all classification metrics from the current evaluations.
     *
     * <p>If there are no evaluations, only the error is recomputed; the class-based metrics are left
     * unchanged because class counts and ratios are undefined for an empty sample.
     *
     * <p>NOTE(review): when {@code settings.squishLastLayer} is set, the metric computation writes
     * the sigmoid/softmax-squashed output back into each evaluation via {@code setOutput}, so calling
     * this method more than once squashes the outputs again — confirm callers invoke it only once
     * per evaluation set.
     *
     * @return always {@code true}
     */
    @Override
    public boolean recalculate() {
        this.error = this.calculateErrorValue();
        if (this.evaluations.isEmpty()) {
            LOG.warning("No evaluations to compute classification metrics from.");
            return true;
        }
        this.loadBasicCounts(this.evaluations);
        this.loadBinaryMetrics(this.evaluations);
        return true;
    }

    /**
     * Aggregates the per-sample error values with the configured aggregation function.
     *
     * @return the aggregated error value
     */
    public Value calculateErrorValue() {
        List<Value> errors = new ArrayList<>(this.evaluations.size());
        for (Result evaluation : this.evaluations) {
            errors.add(evaluation.errorValue());
        }
        return this.aggregationFcn.evaluate(errors);
    }

    /**
     * Summarizes several classification results (e.g. one per cross-validation fold) into their
     * mean and standard deviation for error, accuracy, majority accuracy and dispersion.
     *
     * @param resultsList the results to aggregate
     * @return the mean/std pair, or {@code null} if the list is empty or the first result has no error
     */
    public static MeanStdResults aggregateClassifications(List<ClassificationResults> resultsList) {
        List<Value> errors = resultsList.stream().map(res -> res.error).collect(Collectors.toList());
        if (errors.isEmpty() || errors.get(0) == null) {
            return null;
        }
        Value meanError = MathUtils.getMeanValue(errors);
        Value stdError = MathUtils.getStd(errors, meanError);

        List<Double> accuracies = resultsList.stream().map(res -> res.accuracy).collect(Collectors.toList());
        Double meanAcc = MathUtils.getMean(accuracies);
        Double stdAcc = MathUtils.getStd(accuracies, meanAcc);

        List<Double> dispersions = resultsList.stream().map(res -> res.dispersion).collect(Collectors.toList());
        Double meanDisp = MathUtils.getMean(dispersions);
        Double stdDisp = MathUtils.getStd(dispersions, meanDisp);

        List<Double> majorityAccs = resultsList.stream().map(res -> res.majorityAcc).collect(Collectors.toList());
        Double meanMajAcc = MathUtils.getMean(majorityAccs);
        Double stdMajAcc = MathUtils.getStd(majorityAccs, meanMajAcc);

        ClassificationResults mean = new ClassificationResults(meanError, meanAcc, meanMajAcc, meanDisp);
        ClassificationResults std = new ClassificationResults(stdError, stdAcc, stdMajAcc, stdDisp);
        return new MeanStdResults(mean, std);
    }

    /**
     * Counts binary class occurrences among scalar targets and derives the majority-class accuracy.
     * Does nothing for non-scalar targets; the multi-class path computes its own majority accuracy.
     *
     * @param evaluations non-empty list of evaluated samples
     */
    private void loadBasicCounts(List<Result> evaluations) {
        if (!(evaluations.get(0).getTarget() instanceof ScalarValue)) {
            return;
        }
        this.zeroCount = 0;
        this.oneCount = 0;
        for (Result evaluation : evaluations) {
            if (evaluation.getTarget().greaterThan(oneHalf)) {
                ++this.oneCount;
            } else {
                ++this.zeroCount;
            }
        }
        this.majorityAcc = (double) Math.max(this.zeroCount, this.oneCount) / evaluations.size();
    }

    /**
     * Computes accuracy and dispersion for binary (scalar-target) classification, delegating to the
     * multi-class computation when targets are vectors.
     *
     * <p>Dispersion here is the mean output over class-1 samples minus the mean output over class-0
     * samples. It is {@code NaN} when only one class occurs. Relies on {@code oneCount} and
     * {@code zeroCount} having been set by {@code loadBasicCounts} for the same evaluations.
     *
     * @param evaluations non-empty list of evaluated samples
     */
    private void loadBinaryMetrics(List<Result> evaluations) {
        Result first = evaluations.get(0);
        if (first.getTarget() instanceof VectorValue) {
            this.loadMulticlassMetrics(evaluations);
            return;
        }
        if (!(first.getTarget() instanceof ScalarValue)) {
            // Neither scalar nor vector: the class counts were not computed and the metrics are undefined.
            LOG.severe("Unsupported target class dimensionality (only scalars or vectors are assumed)");
            return;
        }
        if (this.settings.squishLastLayer) {
            for (Result evaluation : evaluations) {
                evaluation.setOutput(ElementWise.Singletons.sigmoid.evaluate(evaluation.getOutput()));
            }
        }
        this.goodCount = 0;
        // Declared as Value so the Value-typed incrementBy/elementTimes overloads are used.
        Value zeroSum = new ScalarValue(0.0);
        Value oneSum = new ScalarValue(0.0);
        for (Result evaluation : evaluations) {
            if (evaluation.getTarget().greaterThan(oneHalf)) {
                oneSum.incrementBy(evaluation.getOutput());
                if (evaluation.getOutput().greaterThan(oneHalf)) {
                    ++this.goodCount;
                }
            } else {
                zeroSum.incrementBy(evaluation.getOutput());
                // Strict comparison on both sides: an output of exactly 0.5 is correct for neither class.
                if (oneHalf.greaterThan(evaluation.getOutput())) {
                    ++this.goodCount;
                }
            }
        }
        if (this.oneCount == 0 || this.zeroCount == 0) {
            // A class mean over zero samples is undefined; report NaN explicitly instead of dividing by zero.
            this.dispersion = Double.NaN;
        } else {
            Value oneMean = oneSum.elementTimes((Value) new ScalarValue(1.0 / this.oneCount));
            Value zeroMean = zeroSum.elementTimes((Value) new ScalarValue(1.0 / this.zeroCount));
            this.dispersion = ((ScalarValue) oneMean.minus(zeroMean)).value;
        }
        this.accuracy = (double) this.goodCount / evaluations.size();
    }

    /**
     * Computes accuracy, dispersion and majority accuracy for multi-class (vector-target)
     * classification.
     *
     * <p>A sample is counted as correct when the argmax of its output equals the argmax of its
     * target and the output's maximal component is not exactly 0. Dispersion is the average, over
     * the distinct target vectors, of the mean output component at that target's argmax index.
     *
     * @param evaluations non-empty list of evaluated samples
     */
    private void loadMulticlassMetrics(List<Result> evaluations) {
        this.goodCount = 0;
        Map<VectorValue, VectorValue> classAcums = new HashMap<>();
        Map<VectorValue, Integer> classCounts = new HashMap<>();
        if (this.settings.squishLastLayer) {
            for (Result result : evaluations) {
                result.setOutput(Transformation.Singletons.softmax.evaluate(result.getOutput()));
            }
        }
        for (Result result : evaluations) {
            if (!(result.getTarget() instanceof VectorValue)) {
                LOG.severe("Unsupported target class dimensionality (only scalars or vectors are assumed)");
                return;
            }
            VectorValue target = (VectorValue) result.getTarget();

            // Per-class running sum of outputs, seeded with a copy of the first output of that class.
            Value classAcum = classAcums.get(target);
            if (classAcum == null) {
                classAcums.put(target, (VectorValue) result.getOutput().clone());
            } else {
                classAcum.incrementBy(result.getOutput());
            }
            classCounts.merge(target, 1, Integer::sum);

            int maxInd = result.getOutput().getMaxInd();
            // An output whose maximal component is exactly 0 (e.g. all zeros) is not counted as a hit.
            if (target.getMaxInd() == maxInd && ((VectorValue) result.getOutput()).values[maxInd] != 0.0) {
                ++this.goodCount;
            }
        }

        double dispersionSum = 0.0;
        for (Map.Entry<VectorValue, VectorValue> entry : classAcums.entrySet()) {
            int classInd = entry.getKey().getMaxInd();
            double meanClassOutput = entry.getValue().values[classInd] / classCounts.get(entry.getKey());
            dispersionSum += meanClassOutput;
        }
        this.dispersion = dispersionSum / classCounts.size();

        int maxCount = classCounts.values().stream().mapToInt(Integer::intValue).max().getAsInt();
        this.majorityAcc = (double) maxCount / evaluations.size();
        this.accuracy = (double) this.goodCount / evaluations.size();
    }

    @Override
    public String toString() {
        return this.toString(null);
    }

    /**
     * Renders accuracy, dispersion and error. When settings are given, the error is labeled with
     * the aggregation and error function names.
     *
     * @param settings optional settings used to describe the error computation; may be {@code null}
     * @return human-readable summary
     */
    @Override
    public String toString(Settings settings) {
        StringBuilder sb = new StringBuilder();
        if (this.accuracy != null) {
            sb.append("accuracy: " + Settings.shortNumberFormat.format(this.accuracy * 100.0) + "%");
        }
        if (this.dispersion != null) {
            sb.append(", disp: " + this.dispersion.toString());
        }
        if (this.error != null) {
            if (settings == null) {
                sb.append(", error: ").append(this.error.toDetailedString());
            } else {
                String errAggfcn = settings.errorAggregationFcn.toString();
                String errFcn = settings.errorFunction.toString();
                String errString = errAggfcn + "(" + errFcn + ")";
                sb.append(", error: ").append(errString).append(" = ").append(this.error.toDetailedString());
            }
        }
        return sb.toString();
    }
}

