Tags: python, tensorflow, scanning

TensorFlow scan on a matrix


How can I get the following tf.scan example to work? I plan to use it to test some code inside the function f.

import tensorflow as tf

def f(prev_y, curr_y):
  fval = tf.nn.softmax(curr_y)
  return fval

a = tf.constant([[.1, .25, .3, .2, .15],
                 [.07, .35, .27, .17, .14]])
c = tf.scan(f, a, initializer=0)
with tf.Session() as sess:
  print(sess.run(c))

Solution

  • Your initializer=0 is not valid. As documented:

    If an initializer is provided, then the output of fn must have the same structure as initializer; and the first argument of fn must match this structure.

    The output of your f has the same dtype and shape as curr_y, the second argument: a float32 vector of length 5. That does not match the scalar 0, so you need an initializer with the same structure as one row of a:

    init = tf.constant([0., 0., 0., 0., 0.])
    c = tf.scan(f, a, initializer=init)
    

    For this specific f, you don't need tf.scan (and probably shouldn't use it), because f only uses its second argument and ignores the accumulated value prev_y. tf.map_fn does the job:

    c = tf.map_fn(tf.nn.softmax, a)
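
To see why the two calls are interchangeable here, the following is a plain-Python sketch of the scan semantics, with no TensorFlow involved. The helpers `scan` and `softmax` are hypothetical stand-ins written for illustration, not TensorFlow's implementation. The sketch assumes only what the answer states: the output of each step is fed back in as the first argument of the next step.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats (illustrative helper)."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def scan(fn, elems, initializer):
    """Toy model of a scan: thread an accumulator through elems,
    collecting the result of every step."""
    acc = initializer
    outputs = []
    for elem in elems:
        acc = fn(acc, elem)  # this step's output becomes the next prev value
        outputs.append(acc)
    return outputs

def f(prev_y, curr_y):
    # prev_y is ignored, as in the question's f
    return softmax(curr_y)

a = [[.1, .25, .3, .2, .15],
     [.07, .35, .27, .17, .14]]

# Same structure as one row of a, mirroring the corrected initializer
init = [0.0] * 5

via_scan = scan(f, a, init)
via_map = [softmax(row) for row in a]

# f never reads prev_y, so scanning and mapping give identical rows
print(via_scan == via_map)
for row in via_scan:
    print([round(p, 4) for p in row])
```

Because the accumulator takes on the value of each step's output, the initializer has to look like one output row. That is the structure rule quoted from the documentation above. A scan only earns its cost when f actually reads prev_y, for example to compute a running sum.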