How can I get the following tf.scan example to work? I plan to use it to test some code inside the function f.
import tensorflow as tf

def f(prev_y, curr_y):
    fval = tf.nn.softmax(curr_y)
    return fval

a = tf.constant([[.1, .25, .3, .2, .15],
                 [.07, .35, .27, .17, .14]])
c = tf.scan(f, a, initializer=0)
with tf.Session() as sess:
    print(sess.run(c))
Your initializer=0 is not valid. As documented:
"If an initializer is provided, then the output of fn must have the same structure as initializer; and the first argument of fn must match this structure."
The output of your f has the same type and shape as curr_y, the second argument (a float vector of length 5), which won't match the scalar integer 0. In this case, you would need:
init = tf.constant([0., 0., 0., 0., 0.])
c = tf.scan(f, a, initializer=init)
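To see why the initializer has to look like the output of f, here is a minimal pure-Python sketch of what a scan does. It is only a toy model of the semantics, not TensorFlow's implementation; the scan and softmax helpers below are hypothetical stand-ins written for illustration.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats (stand-in for tf.nn.softmax)."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def scan(fn, elems, initializer):
    """Toy model of a scan: thread an accumulator through elems, collecting every output."""
    acc = initializer
    outputs = []
    for elem in elems:
        # The output of this call becomes prev_y of the next call,
        # so it has to look like the initializer.
        acc = fn(acc, elem)
        outputs.append(acc)
    return outputs

def f(prev_y, curr_y):
    return softmax(curr_y)

a = [[.1, .25, .3, .2, .15],
     [.07, .35, .27, .17, .14]]
init = [0.0] * 5          # same length and element type as each output of f
c = scan(f, a, init)
print(c)
```

On the second step, prev_y is the length-5 result of the first step, so an initializer of 0 would be the only call where prev_y has a different structure. That mismatch is what tf.scan rejects.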
For this specific f, you don't need (and probably shouldn't use) tf.scan, because f uses only its second argument, curr_y, and ignores prev_y. tf.map_fn would do the job:
c = tf.map_fn(tf.nn.softmax, a)
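The difference between the two can be sketched in plain Python, using itertools.accumulate (Python 3.8+) as a rough analogue of a scan and a list comprehension as a rough analogue of a map. The functions running_sum and double are hypothetical examples, not part of the question.

```python
from itertools import accumulate

def running_sum(prev, curr):
    """Uses the accumulator: each output depends on the previous output."""
    return [p + c for p, c in zip(prev, curr)]

def double(prev, curr):
    """Ignores the accumulator: each output depends only on the current row."""
    return [2 * x for x in curr]

rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
init = [0.0, 0.0]

# accumulate() yields the initial value first, so drop it to keep one output per row.
summed = list(accumulate(rows, running_sum, initial=init))[1:]
doubled_scan = list(accumulate(rows, double, initial=init))[1:]

# When the function ignores prev, a plain map gives the same result.
doubled_map = [double(None, row) for row in rows]

print(summed)
print(doubled_scan == doubled_map)
```

A scan is only needed when each step depends on the previous result, as in running_sum. When the function ignores its first argument, as double does here and as f does in the question, a map is equivalent and simpler.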