/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions;

import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.functions.transformation.elementwise.Exponentiation;
import cz.cvut.fel.ida.algebra.functions.transformation.elementwise.Inverse;
import cz.cvut.fel.ida.algebra.functions.transformation.elementwise.LeakyReLu;
import cz.cvut.fel.ida.algebra.functions.transformation.elementwise.Logarithm;
import cz.cvut.fel.ida.algebra.functions.transformation.elementwise.LukasiewiczSigmoid;
import cz.cvut.fel.ida.algebra.functions.transformation.elementwise.ReLu;
import cz.cvut.fel.ida.algebra.functions.transformation.elementwise.Reverse;
import cz.cvut.fel.ida.algebra.functions.transformation.elementwise.Sigmoid;
import cz.cvut.fel.ida.algebra.functions.transformation.elementwise.Signum;
import cz.cvut.fel.ida.algebra.functions.transformation.elementwise.SquareRoot;
import cz.cvut.fel.ida.algebra.functions.transformation.elementwise.Tanh;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.setup.Settings;
import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;

/**
 * Base class for transformations that apply a scalar function independently to every element
 * of a {@link Value}.
 *
 * <p>Each subclass supplies two scalar operators: {@code evaluation}, which is applied to every
 * element in {@link #evaluate(Value)}, and {@code gradient}, which is applied to every element in
 * {@link #differentiate(Value)}. Both are delegated to {@link Value#apply(DoubleUnaryOperator)}.
 *
 * <p>NOTE(review): the operator fields are {@code transient}, so after Java deserialization they
 * are null unless a subclass restores them; confirm subclasses handle this.
 */
public abstract class ElementWise
implements Transformation {
    private static final Logger LOG = Logger.getLogger(ElementWise.class.getName());

    /** Scalar function applied to each element when evaluating. */
    transient DoubleUnaryOperator evaluation;

    /** Scalar derivative applied to each element when differentiating. */
    transient DoubleUnaryOperator gradient;

    /**
     * Creates an element-wise transformation from its scalar forward function and derivative.
     *
     * @param evaluation scalar function applied per element by {@link #evaluate(Value)}
     * @param gradient   scalar derivative applied per element by {@link #differentiate(Value)}
     */
    protected ElementWise(DoubleUnaryOperator evaluation, DoubleUnaryOperator gradient) {
        this.evaluation = evaluation;
        this.gradient = gradient;
    }

    /**
     * Applies the scalar evaluation function to every element of the input.
     *
     * @param combinedInputs the value to transform
     * @return the result of {@code combinedInputs.apply(evaluation)}
     */
    @Override
    public Value evaluate(Value combinedInputs) {
        return combinedInputs.apply(this.evaluation);
    }

    /**
     * Applies the scalar derivative to every element of the input.
     *
     * @param combinedInputs the value at which the derivative is taken
     * @return the result of {@code combinedInputs.apply(gradient)}
     */
    @Override
    public Value differentiate(Value combinedInputs) {
        return combinedInputs.apply(this.gradient);
    }

    /**
     * Looks up the shared element-wise instance for a configured transformation function.
     *
     * @param activationFcn the configured transformation function; must not be null
     * @return the shared instance from {@link Singletons}, or {@code null} if
     *         {@code activationFcn} has no element-wise implementation
     * @throws NullPointerException if {@code activationFcn} is null
     */
    public static ElementWise getFunction(Settings.TransformationFcn activationFcn) {
        if (activationFcn == null) {
            // A switch on a null enum throws NPE anyway; throw it here with a useful message.
            throw new NullPointerException("activationFcn must not be null");
        }
        switch (activationFcn) {
            case SIGMOID:
                return Singletons.sigmoid;
            case TANH:
                return Singletons.tanh;
            case SIGNUM:
                return Singletons.signum;
            case RELU:
                return Singletons.relu;
            case LEAKYRELU:
                return Singletons.leakyRelu;
            case LUKASIEWICZ:
                return Singletons.lukasiewiczSigmoid;
            case EXP:
                return Singletons.exponentiation;
            case SQRT:
                return Singletons.sqrt;
            case INVERSE:
                return Singletons.inverse;
            case REVERSE:
                return Singletons.reverse;
            case LOGARITHM:
                return Singletons.logarithm;
            default:
                // Not an element-wise function; callers are expected to handle null. Logged at
                // FINE (not WARNING) so callers that probe every enum value do not flood the log.
                LOG.fine("No element-wise implementation for transformation " + activationFcn
                        + "; returning null.");
                return null;
        }
    }

    /**
     * Creates the computation state for this transformation.
     *
     * @param singleInput not used by element-wise transformations
     * @return a new {@link State} bound to this transformation
     */
    @Override
    public Transformation.State getState(boolean singleInput) {
        return new State(this);
    }

    /**
     * Computation state for an element-wise transformation.
     */
    public static class State
    extends Transformation.State {
        /**
         * @param elementWise the transformation this state belongs to
         */
        public State(ElementWise elementWise) {
            super(elementWise);
        }

        /**
         * Combines the incoming gradient with the local derivative by element-wise product.
         *
         * <p>The result is stored in {@code processedGradient}.
         *
         * @param topGradient gradient arriving from the consumer of this transformation's output
         */
        @Override
        public void ingestTopGradient(Value topGradient) {
            // Presumably gradient() yields this transformation's derivative at the stored
            // inputs (defined in Transformation.State) -- TODO confirm.
            Value inputFcnDerivative = this.gradient();
            this.processedGradient = topGradient.elementTimes(inputFcnDerivative);
        }
    }

    /**
     * Shared instances of each built-in element-wise transformation.
     *
     * <p>NOTE(review): these fields are public and non-final, so any caller can replace them.
     * They were kept as-is for compatibility.
     */
    public static class Singletons {
        public static LukasiewiczSigmoid lukasiewiczSigmoid = new LukasiewiczSigmoid();
        public static Sigmoid sigmoid = new Sigmoid();
        public static Signum signum = new Signum();
        public static ReLu relu = new ReLu();
        public static LeakyReLu leakyRelu = new LeakyReLu();
        public static Tanh tanh = new Tanh();
        public static Exponentiation exponentiation = new Exponentiation();
        public static SquareRoot sqrt = new SquareRoot();
        public static Inverse inverse = new Inverse();
        public static Reverse reverse = new Reverse();
        public static Logarithm logarithm = new Logarithm();
    }
}

