# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Random variable."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import tensorflow as tf

from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import ops

# Public API of this module: only the `RandomVariable` class is exported; the
# underscore-prefixed helpers below are implementation details.
__all__ = [
    "RandomVariable",
]


def _operator(attr):
  """Builds a method that applies the operator `attr` to a random variable.

  The returned callable takes the random variable as its first positional
  argument, substitutes its `value` in its place, and forwards every remaining
  positional argument to `attr` unchanged. Name and docstring are copied from
  `attr` via `functools.wraps`.

  Args:
    attr: Operator attribute to use, e.g. `tf.Tensor.__add__`.

  Returns:
    Function calling operator attribute on the random variable's value.
  """
  def forward(rv, *args):
    return attr(rv.value, *args)

  return functools.wraps(attr)(forward)


class RandomVariable(object):
  """Class for random variables.

  `RandomVariable` encapsulates properties of a random variable, namely, its
  distribution, sample shape, and (optionally overridden) value. Its `value`
  property is a `tf.Tensor`, which embeds the `RandomVariable` object into the
  TensorFlow graph. `RandomVariable` also features operator overloading and
  registration to TensorFlow sessions, enabling idiomatic usage as if one were
  operating on `tf.Tensor`s.

  The random variable's shape is given by

  `sample_shape + distribution.batch_shape + distribution.event_shape`,

  where `sample_shape` is an optional argument describing the shape of
  independent, identical draws from the distribution (default is `()`, meaning
  a single draw); `distribution.batch_shape` describes the shape of
  independent-but-not-identical draws (determined by the shape of the
  distribution's parameters); and `distribution.event_shape` describes the
  shape of dependent dimensions (e.g., `Normal` has scalar `event_shape`;
  `Dirichlet` has vector `event_shape`).

  #### Examples

  ```python
  from tensorflow_probability import edward2 as ed
  tfd = tf.contrib.distributions

  z1 = tf.constant([[1.0, -0.8], [0.3, -1.0]])
  z2 = tf.constant([[0.9, 0.2], [2.0, -0.1]])
  x = ed.RandomVariable(tfd.Bernoulli(logits=tf.matmul(z1, z2)))

  loc = ed.RandomVariable(tfd.Normal(0., 1.))
  x = ed.RandomVariable(tfd.Normal(loc, 1.), sample_shape=50)
  assert x.shape.as_list() == [50]
  assert x.sample_shape.as_list() == [50]
  assert x.distribution.batch_shape.as_list() == []
  assert x.distribution.event_shape.as_list() == []
  ```
  """

  def __init__(self,
               distribution,
               sample_shape=(),
               value=None):
    """Create a new random variable.

    Args:
      distribution: tf.Distribution governing the distribution of the random
        variable, such as sampling and log-probabilities.
      sample_shape: Shape of samples to draw from the random variable; any
        value accepted by `tf.TensorShape` (e.g. an int, a tuple, or a
        `tf.TensorShape`). Default is `()` corresponding to a single sample.
      value: Fixed tf.Tensor to associate with random variable. Must have shape
        `sample_shape + distribution.batch_shape + distribution.event_shape`.
        Default is to sample from random variable according to `sample_shape`.

    Raises:
      ValueError: `value` has incompatible shape with
        `sample_shape + distribution.batch_shape + distribution.event_shape`.
      NotImplementedError: `value` is None and `distribution.sample` raises
        `NotImplementedError`.
    """
    self._distribution = distribution
    self._sample_shape = tf.TensorShape(sample_shape)

    if value is not None:
      t_value = tf.convert_to_tensor(value, self.distribution.dtype)
      value_shape = t_value.shape
      expected_shape = self._sample_shape.concatenate(
          self.distribution.batch_shape).concatenate(
              self.distribution.event_shape)
      # `is_compatible_with` (rather than equality) lets partially-known
      # static shapes through.
      if not value_shape.is_compatible_with(expected_shape):
        raise ValueError(
            "Incompatible shape for initialization argument 'value'. "
            "Expected %s, got %s." % (expected_shape, value_shape))
      self._value = t_value
    else:
      try:
        self._value = self.distribution.sample(self._sample_shape)
      except NotImplementedError:
        # Name the distribution's class, not `RandomVariable`: the distribution
        # is what lacks `sample`, so that is what the user must fix.
        raise NotImplementedError(
            "sample is not implemented for {0}. You must either pass in the "
            "value argument or implement sample for {0}."
            .format(self.distribution.__class__.__name__))

  @property
  def distribution(self):
    """Distribution of random variable."""
    return self._distribution

  @property
  def dtype(self):
    """`Dtype` of elements in this random variable."""
    return self.value.dtype

  @property
  def sample_shape(self):
    """Sample shape of random variable."""
    return self._sample_shape

  @property
  def shape(self):
    """Shape of random variable."""
    return self.value.shape

  @property
  def value(self):
    """Get tensor that the random variable corresponds to."""
    return self._value

  def __str__(self):
    # Non-eager tensors have no concrete value to show, so use the
    # distribution's name; eager tensors show their numpy value instead.
    if not isinstance(self.value, ops.EagerTensor):
      name = self.distribution.name
    else:
      name = _numpy_text(self.value)
    return "RandomVariable(\"%s\"%s%s%s)" % (
        name,
        ", shape=%s" % self.shape if self.shape.ndims is not None else "",
        ", dtype=%s" % self.dtype.name if self.dtype else "",
        ", device=%s" % self.value.device if self.value.device else "")

  def __repr__(self):
    string = "ed.RandomVariable '%s' shape=%s dtype=%s" % (
        self.distribution.name, self.shape, self.dtype.name)
    # Only tensors exposing `numpy` (eager tensors) can show a concrete value.
    if hasattr(self.value, "numpy"):
      string += " numpy=%s" % _numpy_text(self.value, is_repr=True)
    return "<%s>" % string

  # Overload operators following tf.Tensor. Each overload applies the
  # corresponding `tf.Tensor` method to `self.value`.
  __add__ = _operator(tf.Tensor.__add__)
  __radd__ = _operator(tf.Tensor.__radd__)
  __sub__ = _operator(tf.Tensor.__sub__)
  __rsub__ = _operator(tf.Tensor.__rsub__)
  __mul__ = _operator(tf.Tensor.__mul__)
  __rmul__ = _operator(tf.Tensor.__rmul__)
  __div__ = _operator(tf.Tensor.__div__)
  __rdiv__ = _operator(tf.Tensor.__rdiv__)
  __truediv__ = _operator(tf.Tensor.__truediv__)
  __rtruediv__ = _operator(tf.Tensor.__rtruediv__)
  __floordiv__ = _operator(tf.Tensor.__floordiv__)
  __rfloordiv__ = _operator(tf.Tensor.__rfloordiv__)
  __mod__ = _operator(tf.Tensor.__mod__)
  __rmod__ = _operator(tf.Tensor.__rmod__)
  __lt__ = _operator(tf.Tensor.__lt__)
  __le__ = _operator(tf.Tensor.__le__)
  __gt__ = _operator(tf.Tensor.__gt__)
  __ge__ = _operator(tf.Tensor.__ge__)
  __and__ = _operator(tf.Tensor.__and__)
  __rand__ = _operator(tf.Tensor.__rand__)
  __or__ = _operator(tf.Tensor.__or__)
  __ror__ = _operator(tf.Tensor.__ror__)
  __xor__ = _operator(tf.Tensor.__xor__)
  __rxor__ = _operator(tf.Tensor.__rxor__)
  __getitem__ = _operator(tf.Tensor.__getitem__)
  __pow__ = _operator(tf.Tensor.__pow__)
  __rpow__ = _operator(tf.Tensor.__rpow__)
  __invert__ = _operator(tf.Tensor.__invert__)
  __neg__ = _operator(tf.Tensor.__neg__)
  __abs__ = _operator(tf.Tensor.__abs__)
  __matmul__ = _operator(tf.Tensor.__matmul__)
  __rmatmul__ = _operator(tf.Tensor.__rmatmul__)
  __iter__ = _operator(tf.Tensor.__iter__)
  __bool__ = _operator(tf.Tensor.__bool__)
  __nonzero__ = _operator(tf.Tensor.__nonzero__)

  # Hashing and equality are identity-based, so each RandomVariable instance is
  # a distinct, hashable key (e.g. usable in dicts and sets).
  def __hash__(self):
    return id(self)

  def __eq__(self, other):
    return id(self) == id(other)

  def __ne__(self, other):
    return not self == other

  def eval(self, session=None, feed_dict=None):
    """In a session, computes and returns the value of this random variable.

    This is not a graph construction method, it does not add ops to the graph.

    This convenience method requires a session where the graph
    containing this variable has been launched. If no session is
    passed, the default session is used.

    Args:
      session: tf.BaseSession.
        The `tf.Session` to use to evaluate this random variable. If
        none, the default session is used.
      feed_dict: dict.
        A dictionary that maps `tf.Tensor` objects to feed values. See
        `tf.Session.run()` for a description of the valid feed values.

    Returns:
      Value of the random variable.

    #### Examples

    ```python
    x = ed.Normal(loc=0.0, scale=1.0)
    with tf.Session() as sess:
      # Usage passing the session explicitly.
      print(x.eval(sess))
      # Usage with the default session.  The 'with' block
      # above makes 'sess' the default session.
      print(x.eval())
    ```
    """
    return self.value.eval(session=session, feed_dict=feed_dict)

  def numpy(self):
    """Value as NumPy array, only available for TF Eager.

    Raises:
      NotImplementedError: `value` is not an `EagerTensor`.
    """
    if not isinstance(self.value, ops.EagerTensor):
      raise NotImplementedError("value argument must be a EagerTensor.")

    return self.value.numpy()

  def get_shape(self):
    """Get shape of random variable (alias of the `shape` property)."""
    return self.shape

  # This enables the RandomVariable's overloaded "right" binary operators to
  # run when the left operand is an ndarray, because it accords the
  # RandomVariable class higher priority than an ndarray, or a numpy matrix.
  __array_priority__ = 100


def _numpy_text(tensor, is_repr=False):
  """Formats a tensor's NumPy value as human-readable text.

  Args:
    tensor: Tensor exposing `dtype.is_numpy_compatible` and `numpy()`.
    is_repr: If True, format with `repr`; otherwise with `str`.

  Returns:
    The formatted value, or "<unprintable>" when the dtype is not NumPy
    compatible. Multi-line text gets a leading newline so it starts on its
    own line when embedded in a larger string.
  """
  if not tensor.dtype.is_numpy_compatible:
    text = "<unprintable>"
  else:
    formatter = repr if is_repr else str
    text = formatter(tensor.numpy())
  return "\n" + text if "\n" in text else text


def _session_run_conversion_fetch_function(tensor):
  """Fetch conversion letting `Session.run` fetch a `RandomVariable`.

  Args:
    tensor: `RandomVariable` being fetched.

  Returns:
    A pair `(fetches, contraction_fn)`: the list of underlying tensors to
    fetch, and a function mapping the fetched results back to one value.
  """
  def _first(fetched):
    return fetched[0]

  return [tensor.value], _first


def _session_run_conversion_feed_function(feed, feed_val):
  """Feed conversion letting a `RandomVariable` be a key in `feed_dict`.

  Args:
    feed: `RandomVariable` being fed.
    feed_val: Value to feed for it.

  Returns:
    A one-element list pairing `feed.value` with `feed_val`.
  """
  pair = (feed.value, feed_val)
  return [pair]


def _session_run_conversion_feed_function_for_partial_run(feed):
  """Feed conversion used for partial runs of a `RandomVariable`.

  Args:
    feed: `RandomVariable` being declared as a feed.

  Returns:
    A one-element list holding `feed.value`.
  """
  feed_tensors = [feed.value]
  return feed_tensors


def _tensor_conversion_function(v, dtype=None, name=None, as_ref=False):
  """Converts a `RandomVariable` into its underlying tensor.

  Args:
    v: `RandomVariable` to convert.
    dtype: Optional requested dtype; when truthy it must be compatible with
      `v.dtype`.
    name: Unused.
    as_ref: Unused.

  Returns:
    `v.value`.

  Raises:
    ValueError: `dtype` is given and is not compatible with `v.dtype`.
  """
  del name, as_ref  # unused
  if not dtype or dtype.is_compatible_with(v.dtype):
    return v.value
  raise ValueError(
      "Incompatible type conversion requested to type '%s' for variable "
      "of type '%s'" % (dtype.name, v.dtype.name))


# Register fetch/feed conversions so `tf.Session.run` can fetch and feed
# `RandomVariable`s by delegating to their underlying `value` tensors.
tf_session.register_session_run_conversion_functions(  # enable sess.run, eval
    RandomVariable,
    _session_run_conversion_fetch_function,
    _session_run_conversion_feed_function,
    _session_run_conversion_feed_function_for_partial_run)

# Register a tensor conversion so `tf.convert_to_tensor` (e.g. when a
# `RandomVariable` is passed as an op input) yields the variable's `value`.
tf.register_tensor_conversion_function(  # enable tf.convert_to_tensor
    RandomVariable, _tensor_conversion_function)
