"""batched_map_fn_gradient_problem.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1IEsQAM_AU2H0bfiOfxdphk5dyi9kmS8g

# Gradient computation for `vectorized_map` nested inside `while_loop`

__Aim__: we wish to compute the gradients of a function mapped with `tf.vectorized_map` that is called inside a `tf.while_loop`.

__Issue__: under XLA compilation (`jit_compile=True`), computing the gradients raises an `InvalidArgumentError`.

NB: the following code is a deliberately trivial example that serves as a minimal reproducible example (MRE).
"""

import tensorflow as tf


def vecadd(beta, gamma):
    """Reduce the output of a ``tf.vectorized_map`` over 42 elements.

    The mapped function ignores the element it receives and returns
    ``beta + gamma`` for each of the 42 entries of ``tf.range(42)``. The
    stacked results are then summed over all axes, so the value is
    mathematically ``42 * reduce_sum(beta + gamma)``.

    Args:
        beta: first addend (tensor or tensor-convertible value).
        gamma: second addend, added to ``beta``.

    Returns:
        A scalar tensor (``tf.reduce_sum`` with no ``axis`` reduces everything).
    """

    def per_element(_element):
        # The element is intentionally unused: every lane yields the same sum.
        return beta + gamma

    stacked = tf.vectorized_map(per_element, tf.range(42))
    return tf.reduce_sum(stacked)


def mapit(beta, gamma):
    """Accumulate ``vecadd(beta, gamma)`` across a ``tf.while_loop``.

    The loop runs two iterations. Each one adds ``vecadd(beta, gamma)`` to an
    accumulator that starts at ``0.0``, which nests the vectorized map inside
    the while loop (the structure under investigation).

    Args:
        beta: first operand forwarded to ``vecadd``.
        gamma: second operand forwarded to ``vecadd``.

    Returns:
        The final accumulator value (mathematically ``2 * vecadd(beta, gamma)``).
    """
    n_iters = 2

    def keep_going(step, _total):
        return step < n_iters

    def advance(step, total):
        return step + 1, total + vecadd(beta, gamma)

    # NOTE(review): the Python 0.0 initial value makes the accumulator
    # float32, so vecadd's result presumably needs to be float32 too — confirm
    # before passing float64 inputs.
    # `maximum_iterations` duplicates the bound enforced by `keep_going`;
    # presumably it is there so XLA has a static trip count when building the
    # gradient — TODO confirm.
    _, total = tf.while_loop(
        cond=keep_going,
        body=advance,
        loop_vars=(0, 0.0),
        maximum_iterations=n_iters,
    )
    return total


def value_and_grads(beta, gamma):
    """Evaluate ``mapit(beta, gamma)`` together with its input gradients.

    Args:
        beta: value accepted by ``tf.convert_to_tensor``.
        gamma: value accepted by ``tf.convert_to_tensor``.

    Returns:
        A pair ``(value, [d value / d beta, d value / d gamma])``.
    """
    beta, gamma = tf.convert_to_tensor(beta), tf.convert_to_tensor(gamma)

    with tf.GradientTape() as tape:
        # Plain tensors (unlike trainable variables) are not tracked
        # automatically, so both inputs are watched explicitly.
        tape.watch([beta, gamma])
        value = mapit(beta, gamma)

    return value, tape.gradient(value, [beta, gamma])


# Run the same computation in eager, graph and XLA-compiled modes, in that
# order. Per this notebook, eager and graph modes succeed while the
# jit_compile=True run fails.
for _label, _runner in (
    ("Eager mode:", lambda: value_and_grads(0.1, 0.4)),
    ("Graph mode:", tf.function(lambda: value_and_grads(0.1, 0.4))),
    (
        "XLA mode:",
        tf.function(lambda: value_and_grads(0.1, 0.4), jit_compile=True),
    ),
):
    print(_label, _runner())
