/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions.error;

import cz.cvut.fel.ida.algebra.functions.ErrorFcn;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.VectorValue;
import java.util.logging.Logger;

/**
 * Cross-entropy error function with a per-element binary interpretation of the target.
 *
 * <p>For each output element {@code o} with target {@code t}, the target is thresholded at
 * {@code 0.5}:
 * <ul>
 *   <li>if {@code t > 0.5}, the error is {@code -log(o)};</li>
 *   <li>otherwise, the error is {@code -log(1 - o)}.</li>
 * </ul>
 *
 * <p>When the logarithm's argument would not be strictly positive (including NaN outputs), the
 * error is clamped to {@link #MAXENTVALUE} and the gradient magnitude to
 * {@link #MAXENTGRADIENT}, instead of producing infinities.
 *
 * <p>{@link #differentiate} returns the <em>negative</em> of the derivative of that error with
 * respect to the output. For example, it returns {@code 1/o} for {@code -log(o)}.
 */
public class Crossentropy
implements ErrorFcn {
    private static final Logger LOG = Logger.getLogger(Crossentropy.class.getName());
    /** Threshold (0.5) separating "positive" from "negative" scalar targets. */
    static ScalarValue oneHalf = new ScalarValue(0.5);
    static ScalarValue one = new ScalarValue(1.0);
    static ScalarValue minusOne = new ScalarValue(-1.0);
    /** Error reported when the logarithm's argument is not strictly positive. */
    static double MAXENTVALUE = 100.0;
    /** Gradient magnitude reported when the derivative's denominator is not strictly positive. */
    static double MAXENTGRADIENT = 1.0E10;
    public static Crossentropy singleton = new Crossentropy();

    /**
     * Computes the cross-entropy error of {@code output} against {@code target}.
     *
     * <p>For a {@link ScalarValue} target, the per-element error for that target class is
     * applied element-wise to {@code output}. The class is chosen via
     * {@code target.greaterThan(oneHalf)}.
     *
     * <p>For a {@link VectorValue} target, the result is a {@link ScalarValue} holding the sum
     * of the per-element errors. Each element uses the same {@code > 0.5} threshold as
     * {@link #differentiate}, so the reported error is exactly the function whose negated
     * gradient {@code differentiate} returns.
     *
     * @param output the model output; must be a {@link VectorValue} when {@code target} is one
     * @param target the expected value
     * @return the error value
     * @throws IllegalArgumentException if vector output and target lengths differ
     */
    @Override
    public Value evaluate(Value output, Value target) {
        if (target instanceof ScalarValue) {
            if (target.greaterThan((Value)oneHalf)) {
                return output.apply(x -> positiveError(x));
            }
            return output.apply(x -> negativeError(x));
        }
        VectorValue outputV = (VectorValue)output;
        VectorValue targetV = (VectorValue)target;
        requireSameLength(outputV, targetV);
        double err = 0.0;
        for (int i = 0; i < outputV.values.length; ++i) {
            double o = outputV.values[i];
            // Same per-element loss that differentiate() differentiates; previously this used
            // categorical -t*log(o), which ignored the -log(1-o) terms of negative targets.
            err += targetV.values[i] > 0.5 ? positiveError(o) : negativeError(o);
        }
        return new ScalarValue(err);
    }

    /**
     * Returns the negated derivative of {@link #evaluate}'s error with respect to {@code output}.
     *
     * <p>For a {@link ScalarValue} target, the function is applied element-wise to
     * {@code output}. For a {@link VectorValue} target, the result is a {@link VectorValue} of
     * per-element values.
     *
     * <p>For positive targets the value is {@code 1/o}, clamped to {@link #MAXENTGRADIENT} when
     * {@code o <= 0}. For the others it is {@code -1/(1 - o)}, clamped to
     * {@code -MAXENTGRADIENT} when {@code o >= 1}.
     *
     * @param output the model output; must be a {@link VectorValue} when {@code target} is one
     * @param target the expected value
     * @return the negated gradient
     * @throws IllegalArgumentException if vector output and target lengths differ
     */
    @Override
    public Value differentiate(Value output, Value target) {
        if (target instanceof ScalarValue) {
            if (target.greaterThan((Value)oneHalf)) {
                return output.apply(x -> positiveGradient(x));
            }
            return output.apply(x -> negativeGradient(x));
        }
        VectorValue outputV = (VectorValue)output;
        VectorValue targetV = (VectorValue)target;
        requireSameLength(outputV, targetV);
        double[] grad = new double[outputV.values.length];
        for (int i = 0; i < grad.length; ++i) {
            double o = outputV.values[i];
            grad[i] = targetV.values[i] > 0.5 ? positiveGradient(o) : negativeGradient(o);
        }
        return new VectorValue(grad);
    }

    @Override
    public ErrorFcn getSingleton() {
        return singleton;
    }

    /**
     * Error {@code -log(x)} for a positive target.
     *
     * <p>Returns {@link #MAXENTVALUE} when {@code x} is not strictly positive (including NaN).
     * NOTE(review): for tiny positive {@code x} (below about {@code e^-100}), {@code -log(x)}
     * exceeds MAXENTVALUE, so the clamp is not a true upper bound.
     */
    private static double positiveError(double x) {
        return x > 0.0 ? -Math.log(x) : MAXENTVALUE;
    }

    /**
     * Error {@code -log(1 - x)} for a negative target.
     *
     * <p>Returns {@link #MAXENTVALUE} when {@code x >= 1} (or NaN).
     */
    private static double negativeError(double x) {
        return x < 1.0 ? -Math.log(1.0 - x) : MAXENTVALUE;
    }

    /**
     * Negated derivative of {@link #positiveError}, i.e. {@code 1/x}.
     *
     * <p>Clamped to {@link #MAXENTGRADIENT} when {@code x} is not strictly positive.
     */
    private static double positiveGradient(double x) {
        return x > 0.0 ? 1.0 / x : MAXENTGRADIENT;
    }

    /**
     * Negated derivative of {@link #negativeError}, i.e. {@code -1/(1 - x)}.
     *
     * <p>Clamped to {@code -MAXENTGRADIENT} when {@code x >= 1}.
     */
    private static double negativeGradient(double x) {
        return x < 1.0 ? -1.0 / (1.0 - x) : -MAXENTGRADIENT;
    }

    /** Fails fast on mismatched vector lengths instead of indexing out of bounds or silently truncating. */
    private static void requireSameLength(VectorValue output, VectorValue target) {
        if (output.values.length != target.values.length) {
            throw new IllegalArgumentException("Output and target vector lengths differ: "
                    + output.values.length + " vs " + target.values.length);
        }
    }
}

