/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Combination;
import cz.cvut.fel.ida.algebra.functions.ElementWise;
import cz.cvut.fel.ida.algebra.functions.combination.Softmax;
import cz.cvut.fel.ida.algebra.functions.combination.Sparsemax;
import cz.cvut.fel.ida.algebra.functions.transformation.joint.ConstantOne;
import cz.cvut.fel.ida.algebra.functions.transformation.joint.Identity;
import cz.cvut.fel.ida.algebra.functions.transformation.joint.Normalization;
import cz.cvut.fel.ida.algebra.functions.transformation.joint.Transposition;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.exporting.Exportable;
import java.util.List;
import java.util.logging.Logger;

/**
 * A function applied to a single {@link Value}, usable as an {@link ActivationFcn}.
 * <p>
 * Implementations supply the forward evaluation ({@link #evaluate(Value)}) and its derivative
 * ({@link #differentiate(Value)}). Shared instances of the transformations that are not handled
 * by {@link ElementWise} are kept in {@link Singletons} and handed out by
 * {@link #getFunction(Settings.TransformationFcn)}.
 */
public interface Transformation
extends ActivationFcn,
Exportable {
    public static final Logger LOG = Logger.getLogger(Transformation.class.getName());

    /**
     * Applies this transformation to the given value.
     *
     * @param input the value to transform
     * @return the transformed value
     */
    public Value evaluate(Value input);

    /**
     * Computes the derivative of this transformation at the given input.
     * <p>
     * NOTE(review): whether the result is an element-wise derivative or a full Jacobian is
     * implementation-specific — confirm per implementation. {@link State#ingestTopGradient(Value)}
     * combines it with the upstream gradient via {@code times}.
     *
     * @param input the value at which to differentiate
     * @return the derivative of this transformation at {@code input}
     */
    public Value differentiate(Value input);

    /**
     * Tells whether this transformation may produce an output whose shape differs from its input.
     *
     * @return {@code false} by default; implementations that alter shape are expected to override
     */
    default public boolean changesShape() {
        return false;
    }

    /**
     * Resolves a configured transformation kind to a concrete {@link Transformation}.
     * <p>
     * Element-wise functions are looked up first via {@link ElementWise#getFunction}. Otherwise the
     * shared instances from {@link Singletons} are returned for {@code IDENTITY}, {@code TRANSP},
     * {@code NORM}, {@code SOFTMAX} and {@code SPARSEMAX}.
     *
     * @param transformation the configured transformation kind
     * @return the matching transformation, or {@code null} (after logging a severe message) if the
     *         kind is not supported here
     */
    public static Transformation getFunction(Settings.TransformationFcn transformation) {
        ElementWise function = ElementWise.getFunction(transformation);
        if (function != null) {
            return function;
        }
        switch (transformation) {
            case IDENTITY: {
                return Singletons.identity;
            }
            case TRANSP: {
                return Singletons.transposition;
            }
            case NORM: {
                return Singletons.normalization;
            }
            case SOFTMAX: {
                return Singletons.softmax;
            }
            case SPARSEMAX: {
                return Singletons.sparsemax;
            }
        }
        // Include the offending kind so that misconfigurations can be diagnosed from the log.
        LOG.severe("Unimplemented Transformation function: " + transformation);
        return null;
    }

    /**
     * Evaluation/backpropagation state for a {@link Transformation} applied to exactly one input.
     * <p>
     * The state holds the single input value and the gradient computed for it by the most recent
     * call to {@link #ingestTopGradient(Value)}.
     */
    public static abstract class State
    implements ActivationFcn.State {
        /** The transformation this state evaluates. */
        protected Transformation transformation;
        /** The single input value; {@code null} until set or after {@link #invalidate()}. */
        protected Value input;
        /** Gradient w.r.t. the input; {@code null} until {@link #ingestTopGradient(Value)} runs. */
        protected Value processedGradient;

        /**
         * @param transformation the transformation to evaluate
         */
        public State(Transformation transformation) {
            this.transformation = transformation;
        }

        /**
         * Stores {@code value} as the (single) input.
         * <p>
         * A transformation takes exactly one input, so an existing input is overwritten. This is
         * logged, because multiple inputs suggest that a Combination state was intended.
         */
        @Override
        public void cumulate(Value value) {
            if (this.input != null) {
                LOG.severe("Resetting input in Transformation.State (this should probably be Combination.State instead!)");
            }
            this.input = value;
        }

        /** Clears the stored input and gradient so that the state can be reused. */
        @Override
        public void invalidate() {
            this.input = null;
            this.processedGradient = null;
        }

        /** @return the transformation applied to the stored input */
        @Override
        public Value evaluate() {
            return this.transformation.evaluate(this.input);
        }

        /**
         * Sets the input from {@code inputValues} and evaluates the transformation.
         * <p>
         * Exactly one value is expected. If more are given, a severe message is logged and only the
         * first is used.
         *
         * @param inputValues the input values; must not be empty
         * @return the transformation applied to the first input value
         * @throws IndexOutOfBoundsException if {@code inputValues} is empty
         */
        @Override
        public Value initEval(List<Value> inputValues) {
            if (inputValues.isEmpty()) {
                // The previous code logged "more than one Value" here and then failed on get(0).
                // Report the real problem, keeping the same exception type.
                LOG.severe("Setting up Transformation.State with no input Value.");
                throw new IndexOutOfBoundsException("Transformation.State requires exactly one input Value, but got none.");
            }
            if (inputValues.size() > 1) {
                LOG.severe("Setting up Transformation.State with more than one Value.");
            }
            this.input = inputValues.get(0);
            return this.evaluate();
        }

        /** @return the transformation's derivative at the stored input */
        public Value gradient() {
            return this.transformation.differentiate(this.input);
        }

        /**
         * Combines the upstream gradient with this transformation's derivative. The result is
         * stored for {@link #nextInputGradient()}.
         *
         * @param topGradient the gradient arriving from the consumer of this state's output
         */
        @Override
        public void ingestTopGradient(Value topGradient) {
            Value inputFcnDerivative = this.gradient();
            this.processedGradient = inputFcnDerivative.times(topGradient);
        }

        /** @return the last gradient computed by {@link #ingestTopGradient(Value)}, or {@code null} */
        @Override
        public Value nextInputGradient() {
            return this.processedGradient;
        }

        @Override
        public Transformation getTransformation() {
            return this.transformation;
        }

        /**
         * Returns a state for {@code transformation}.
         * <p>
         * If {@code transformation} has exactly the same runtime class as the current
         * transformation, this state is reused as-is; otherwise a new state is obtained from
         * {@code transformation.getState(true)}.
         * NOTE(review): the meaning of the {@code true} argument is defined elsewhere — confirm.
         */
        @Override
        public ActivationFcn.State changeTransformationState(Transformation transformation) {
            if (transformation.getClass().equals(this.transformation.getClass())) {
                return this;
            }
            return transformation.getState(true);
        }

        /** @return always {@code null}: a plain transformation state has no combination */
        @Override
        public Combination getCombination() {
            return null;
        }
    }

    /**
     * Shared instances of the transformations that {@link Transformation#getFunction} returns for
     * the non-element-wise kinds.
     * <p>
     * NOTE(review): these fields are public and non-final, so they can be reassigned globally.
     * {@code constantOne} is not returned by {@code getFunction}. Confirm whether both are
     * intentional.
     */
    public static class Singletons {
        public static Softmax softmax = new Softmax();
        public static Sparsemax sparsemax = new Sparsemax();
        public static Normalization normalization = new Normalization();
        public static Transposition transposition = new Transposition();
        public static Identity identity = new Identity();
        public static ConstantOne constantOne = new ConstantOne();
    }
}

