Tags: python, tensorflow, tensorflow-datasets, eager-execution

Dataset.from_generator fails with "Only integers are valid indices" in eager mode only; no error in graph mode


I am creating a TensorFlow Dataset using the from_generator function. In graph/session mode, it works fine:

import tensorflow as tf

x = {str(i): i for i in range(10)}

def gen():
  for i in x:
    yield x[i]

ds = tf.data.Dataset.from_generator(gen, tf.int32)
batch = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  while True:
    try:
      print(sess.run(batch), end=' ')
    except tf.errors.OutOfRangeError:
      break
# 0 1 2 3 4 5 6 7 8 9

Surprisingly, however, it fails under eager execution:

import tensorflow as tf
tf.enable_eager_execution()

x = {str(i): i for i in range(10)}

def gen():
  for i in x:
    yield x[i]

ds = tf.data.Dataset.from_generator(gen, tf.int32)

for x in ds:
  print(x, end=' ')
# TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got '1'

I was assuming that, since the body of the generator is pure Python that does not get serialized, TensorFlow would not look into (indeed, would not care about) what is in the generator. But this is apparently not the case. So why does TensorFlow care about what's inside the generator? Assuming the generator cannot be changed, is there a way to work around this problem?


Solution

  • tl;dr: The issue is unrelated to TensorFlow. Your loop variable x shadows the previously defined dict x.
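    A minimal way to check that TensorFlow is not involved, assuming a plain Python generator is an acceptable stand-in for the Dataset: the same loop over gen() fails on its second step. The error message differs from the one in the question because x gets rebound to a Python int rather than to a Tensor, but the cause is the same.

```python
# Same shape as the question's eager example, but iterating the plain
# generator directly instead of tf.data.Dataset.from_generator(gen, ...).
x = {str(i): i for i in range(10)}

def gen():
    for i in x:
        yield x[i]

error = None
try:
    for x in gen():          # rebinds the global x on every iteration
        print(x, end=' ')    # prints "0 " before the failure
except TypeError as exc:     # second step evaluates 0['1'] inside gen()
    error = exc
print()
print(repr(error))
```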

    Fact 1: A for loop in Python does not create its own scope. The loop variable is bound in the enclosing namespace (globals() in your example) and keeps its last value after the loop.
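    A small illustration of Fact 1 (the values are arbitrary):

```python
x = {"a": 1}
for x in [10, 20, 30]:  # binds the module-level name x on each iteration
    pass
print(x)  # 30: the loop replaced the dict that x used to refer to
```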

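    Fact 2: Name lookup inside the generator is dynamic. gen does not capture the dict itself; it only knows that it should look up the global name "x" to evaluate x[i]. The actual value of x is resolved each time the generator body runs, i.e. while the generator is being iterated over, not when gen is defined.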
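    Fact 2 can be observed directly by rebinding x between two calls to next(). The replacement dict below is a made-up value used only for this demonstration:

```python
x = {str(i): i for i in range(3)}

def gen():
    for i in x:          # x is looked up once, when the generator starts
        yield x[i]       # x is looked up again every time this line runs

g = gen()
first = next(g)          # x is still the original dict -> 0
x = {"0": 100, "1": 200, "2": 300}   # rebind the global name x
second = next(g)         # keys still come from the OLD dict ('1'),
                         # but x[i] reads from the NEW dict -> 200
print(first, second)     # 0 200
```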

    Putting these two together and unrolling the first two iterations of the for loop, we get the following execution sequence:

    ds = tf.data.Dataset.from_generator(gen, tf.int32)
    it = iter(ds)
    # First call into the generator: x is still the dict, so x[i] yields 0.
    x = next(it)  # Rebinds the global x to a Tensor.
    print(x, end=' ')
    # Second call into the generator: gen resumes and evaluates x[i] again,
    # but x is no longer a dict, so x['1'] indexes into a Tensor and raises.
    x = next(it)
    

    The fix is simple: use a different loop variable name.

    for item in ds:
      print(item, end=' ')
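    If you would rather keep the name x for the loop, another option (a sketch, using a hypothetical helper consume()) is to run the loop inside a function. Per Fact 1, the loop then binds x in the function's local namespace, so the global dict that gen reads is never replaced. The example uses a plain generator; passing ds instead should behave the same, since the fix concerns only Python scoping:

```python
x = {str(i): i for i in range(10)}

def gen():
    for i in x:
        yield x[i]

def consume(iterable):
    # Inside a function, the loop variable x is a local name, so it does
    # not rebind the global dict that gen() reads from.
    out = []
    for x in iterable:
        out.append(x)
    return out

values = consume(gen())
print(values)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```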