/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.algebra.functions.combination;

import cz.cvut.fel.ida.algebra.functions.ActivationFcn;
import cz.cvut.fel.ida.algebra.functions.Combination;
import cz.cvut.fel.ida.algebra.functions.Transformation;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.VectorValue;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;
import org.jetbrains.annotations.NotNull;

/**
 * Combination function computing the cosine similarity of exactly two inputs.
 *
 * <p>For inputs {@code x} and {@code y} the result is {@code (x . y) / (||x|| * ||y||)}. Elements
 * are paired by iterating both inputs in parallel, so if their lengths differ the surplus
 * elements of the longer input are ignored. If the norm product is zero (e.g. one input is the
 * zero vector) the similarity is undefined; it is then reported as {@code 0.0} with zero
 * gradients rather than propagating {@code NaN}/{@code Infinity}.
 */
public class CosineSim implements Combination {
    private static final Logger LOG = Logger.getLogger(CosineSim.class.getName());

    @Override
    public Combination replaceWithSingleton() {
        return Combination.Singletons.cosineSim;
    }

    /**
     * Evaluates the cosine similarity of two inputs.
     *
     * @param inputs exactly two values, expected to be {@link VectorValue}s
     * @return a {@link ScalarValue} holding the similarity, or {@code null} (after logging an
     *     error) when {@code inputs} does not hold exactly two values
     */
    @Override
    public Value evaluate(List<Value> inputs) {
        if (inputs.size() != 2) {
            LOG.severe("Cannot calculate cosine similarity from other than 2 inputs (got " + inputs.size() + ")");
            return null;
        }
        Value x = inputs.get(0);
        Value y = inputs.get(1);
        if (!(x instanceof VectorValue) || !(y instanceof VectorValue)) {
            // Logged only: the element-wise computation below still runs on whatever was given.
            LOG.severe("Can only calculate cosine similarity between vectors!");
        }
        return new ScalarValue(this.cosine(x, y));
    }

    /**
     * Computes {@code (x . y) / (||x|| * ||y||)} over the element pairs the two inputs share.
     *
     * @param x first operand
     * @param y second operand
     * @return the cosine similarity, or {@code 0.0} when the norm product is zero
     */
    public double cosine(Value x, Value y) {
        return PairSums.of(x, y).cosine();
    }

    /**
     * NOTE(review): cosine similarity is symmetric in its two arguments, yet this reports
     * {@code false}; confirm whether that is intentional before changing it.
     */
    @Override
    public boolean isPermutationInvariant() {
        return false;
    }

    @Override
    public boolean isComplex() {
        return true;
    }

    /**
     * Cosine similarity needs two inputs, so a single-input version does not exist; this logs an
     * error and falls back to the identity transformation.
     */
    @Override
    public Transformation singleInputVersion() {
        LOG.severe("Trying to setup CosineSimilarity for a single input.");
        return Transformation.Singletons.identity;
    }

    /** Returns a fresh {@link State} bound to the shared singleton; {@code singleInput} is ignored. */
    @Override
    public ActivationFcn.State getState(boolean singleInput) {
        return new State(Combination.Singletons.cosineSim);
    }

    /**
     * Dot product and squared norms accumulated over the element pairs of two values. Shared by
     * the forward pass ({@link CosineSim#cosine(Value, Value)}) and the backward pass
     * ({@link State#differentiate(List)}) so the two cannot drift apart.
     */
    private static final class PairSums {
        double dot;
        double sqNormX;
        double sqNormY;
        /** Number of element pairs visited, i.e. the length of the shorter input. */
        int count;

        /** Iterates both values in lockstep, stopping when either runs out of elements. */
        static PairSums of(Value x, Value y) {
            PairSums sums = new PairSums();
            Iterator<?> itX = x.iterator();
            Iterator<?> itY = y.iterator();
            while (itX.hasNext() && itY.hasNext()) {
                double a = (Double) itX.next();
                double b = (Double) itY.next();
                sums.dot += a * b;
                sums.sqNormX += a * a;
                sums.sqNormY += b * b;
                sums.count++;
            }
            return sums;
        }

        /** Returns {@code ||x|| * ||y||}. */
        double normProduct() {
            return Math.sqrt(sqNormX) * Math.sqrt(sqNormY);
        }

        /** Returns the cosine similarity, or {@code 0.0} when the norm product is zero. */
        double cosine() {
            double lens = normProduct();
            return lens == 0.0 ? 0.0 : dot / lens;
        }
    }

    /** Evaluation/back-propagation state for a {@link CosineSim} combination. */
    public static class State extends Combination.InputArrayState {
        /**
         * Local gradients {@code d(cos)/dx} and {@code d(cos)/dy}, in input order. Initialized
         * eagerly so {@link #invalidate()} is safe before {@link #initEval(List)} runs.
         */
        List<Value> inputGradients = new ArrayList<Value>(2);

        public State(Combination combination) {
            super(combination);
        }

        @Override
        public void invalidate() {
            super.invalidate();
            // Guard in case a subclass or external code nulled the (package-visible) field.
            if (this.inputGradients != null) {
                this.inputGradients.clear();
            }
        }

        @Override
        public Value evaluate() {
            return Combination.Singletons.cosineSim.evaluate(this.accumulatedInputs);
        }

        @Override
        public Value initEval(List<Value> inputValues) {
            if (inputValues.size() != 2) {
                LOG.severe("Trying to evaluate CosineSim with other than 2 input values.");
            }
            this.inputGradients = new ArrayList<Value>(2);
            return super.initEval(inputValues);
        }

        /**
         * Stores the incoming gradient and precomputes the local gradients for both inputs.
         * A failed differentiation (already logged) leaves the gradient list empty instead of
         * {@code null}, so later calls fail loudly rather than with a NullPointerException
         * inside {@link #invalidate()}.
         */
        @Override
        public void ingestTopGradient(Value topGradient) {
            this.processedGradient = topGradient;
            List<Value> gradients = this.differentiate(this.accumulatedInputs);
            if (gradients != null) {
                this.inputGradients = gradients;
            } else {
                this.inputGradients.clear();
            }
        }

        @Override
        public Value nextInputGradient() {
            return this.processedGradient.elementTimes(this.inputGradients.get(this.i++));
        }

        /**
         * Computes the partial derivatives of the cosine similarity with respect to each input.
         *
         * <p>For element {@code i}: {@code d cos / dx_i = y_i / (|x||y|) - cos * x_i / |x|^2},
         * and symmetrically for {@code y}. When the norm product is zero the derivative is
         * undefined and zeros are written instead.
         *
         * @param inputs exactly two values
         * @return this state's gradient list (refilled in place) holding {@code [dX, dY]}, or
         *     {@code null} (after logging an error) when {@code inputs} does not hold exactly
         *     two values
         */
        public List<Value> differentiate(List<Value> inputs) {
            if (inputs.size() != 2) {
                LOG.severe("Cannot calculate cosine similarity from other than 2 inputs (got " + inputs.size() + ")");
                return null;
            }
            Value x = inputs.get(0);
            Value y = inputs.get(1);
            PairSums sums = PairSums.of(x, y);
            double lens = sums.normProduct();
            // A zero norm product would make every term below NaN/Infinity.
            boolean degenerate = lens == 0.0;
            double cosine = degenerate ? 0.0 : sums.dot / lens;
            Value diffX = x.getForm();
            Value diffY = y.getForm();
            for (int i = 0; i < sums.count; ++i) {
                double a = degenerate ? 0.0 : y.get(i) / lens - cosine * x.get(i) / sums.sqNormX;
                double b = degenerate ? 0.0 : x.get(i) / lens - cosine * y.get(i) / sums.sqNormY;
                diffX.set(i, a);
                diffY.set(i, b);
            }
            this.inputGradients.clear();
            this.inputGradients.add(diffX);
            this.inputGradients.add(diffY);
            return this.inputGradients;
        }
    }
}

