/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions.transformation.joint;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.values.MatrixValue;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.VectorValue;
import java.util.Arrays;
import java.util.logging.Logger;

/**
 * Transformation that reshapes its (combined) input {@link Value} into a fixed target shape.
 *
 * <p>The target shape is an {@code int[]} of at most two elements. A {@code null} shape passed to
 * the constructor is replaced by {@code {0}}, which is the same shape this class uses when it
 * reshapes a gradient back to a {@link ScalarValue} input.
 *
 * <p>The exact interpretation of the shape array is defined by {@link Value#reshape(int[])} of the
 * concrete value type, not by this class.
 */
public class Reshape
implements Transformation {
    private static final Logger LOG = Logger.getLogger(Reshape.class.getName());

    /** Target shape; a private copy so that callers cannot alter it after construction. */
    private final int[] shape;

    /**
     * Creates a reshape transformation for the given target shape.
     *
     * @param shape target shape with at most two elements; {@code null} is treated as {@code {0}}
     * @throws ArithmeticException if {@code shape} has more than two elements
     */
    public Reshape(int[] shape) {
        final int[] requested = shape == null ? new int[]{0} : shape;
        if (requested.length > 2) {
            String err = "Unsupported shape: " + Arrays.toString(requested) + ". Expected max two elements.";
            LOG.severe(err);
            throw new ArithmeticException(err);
        }
        // Defensive copy: the field is final, but the array contents would otherwise stay
        // mutable through the caller's reference and could change after validation.
        this.shape = requested.clone();
    }

    /**
     * No singleton replacement is provided for this transformation.
     *
     * @return always {@code null}
     */
    @Override
    public ActivationFcn replaceWithSingleton() {
        return null;
    }

    /**
     * Reshapes the given input to the configured target shape.
     *
     * @param combinedInputs value to reshape
     * @return result of {@code combinedInputs.reshape(shape)}
     */
    @Override
    public Value evaluate(Value combinedInputs) {
        return combinedInputs.reshape(this.shape);
    }

    /**
     * Not implemented for this transformation; the backward pass is handled by
     * {@link State#ingestTopGradient(Value)} instead.
     *
     * @param combinedInputs ignored
     * @return always {@code null}
     */
    @Override
    public Value differentiate(Value combinedInputs) {
        return null;
    }

    /**
     * Creates a fresh computation state bound to this transformation.
     *
     * @param singleInput ignored by this implementation
     * @return a new {@link State}
     */
    @Override
    public ActivationFcn.State getState(boolean singleInput) {
        return new State(this);
    }

    /**
     * @return always {@code true}, as reshaping may alter the shape of the value
     */
    @Override
    public boolean changesShape() {
        return true;
    }

    /**
     * Computation state for {@link Reshape}: evaluates the stored input through the owning
     * transformation and maps an incoming gradient back to the shape of that input.
     */
    public static class State
    extends Transformation.State {
        /** The owning transformation, kept with its concrete type for direct evaluation. */
        private final Reshape reshape;

        /**
         * @param transformation the reshape transformation this state belongs to
         */
        public State(Reshape transformation) {
            super(transformation);
            this.reshape = transformation;
        }

        /** Delegates to the superclass; this state adds nothing of its own to reset. */
        @Override
        public void invalidate() {
            super.invalidate();
        }

        /**
         * @return the stored input reshaped to the transformation's target shape
         */
        @Override
        public Value evaluate() {
            return this.reshape.evaluate(this.input);
        }

        /**
         * Reshapes the gradient coming from above back to the shape of the stored input and
         * stores it as the processed gradient.
         *
         * <p>The restored shape depends on the runtime type of the input: {@code {0}} for a
         * {@link ScalarValue}, {@code {len, 0}} or {@code {0, len}} for a {@link VectorValue}
         * depending on its {@code rowOrientation} flag, and {@code {rows, cols}} for a
         * {@link MatrixValue}. For any other input (including {@code null}) the processed
         * gradient is left untouched.
         *
         * @param topGradient gradient with respect to the output of this transformation
         */
        @Override
        public void ingestTopGradient(Value topGradient) {
            if (this.input instanceof ScalarValue) {
                this.processedGradient = topGradient.reshape(new int[]{0});
            } else if (this.input instanceof VectorValue) {
                VectorValue v = (VectorValue)this.input;
                int len = v.values.length;
                // The non-zero position in the shape encodes the vector's orientation.
                this.processedGradient = topGradient.reshape(new int[]{v.rowOrientation ? len : 0, v.rowOrientation ? 0 : len});
            } else if (this.input instanceof MatrixValue) {
                MatrixValue m = (MatrixValue)this.input;
                this.processedGradient = topGradient.reshape(new int[]{m.rows, m.cols});
            }
            // NOTE(review): no branch for other Value types - the previous processedGradient
            // (if any) is kept as is; confirm that this is intended.
        }
    }
}

