I want to create a reusable random tensor `x` and assign the same tensor to a variable `y`, so that both have the same value whenever I call `Session.run()`.
That turns out not to be the case. Why does `y` not equal `x`?
Update:
After calling `sess.run(x)` and `sess.run(y)` several times in a row, I confirmed that `x` changes on every call while `y` stays the same. Why?
```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API

x = tf.random_normal([3], seed=1)
y = tf.Variable(initial_value=x)  # expect y to get the same random tensor as x
diff = tf.subtract(x, y)
avg = tf.reduce_mean(diff)

sess = tf.InteractiveSession()
sess.run(y.initializer)  # initialize y, the only variable in the graph
print('x0:', sess.run(x))
print('y0:', sess.run(y))
print('x1:', sess.run(x))
print('y1:', sess.run(y))
print('x2:', sess.run(x))
print('y2:', sess.run(y))
print('diff:', sess.run(diff))
print('avg:', sess.run(avg))  # expected 0.0
sess.close()
```
Output (the tensor `x` changes on every `sess.run(x)`):
```
x0: [ 0.55171245 -0.13107552 -0.04481386]
y0: [-0.8113182 1.4845988 0.06532937]
x1: [-0.67590594 0.28665832 0.3215887 ]
y1: [-0.8113182 1.4845988 0.06532937]
x2: [1.2409041 0.44875884 0.33140722]
y2: [-0.8113182 1.4845988 0.06532937]
diff: [ 1.2404865 -1.4525002 0.05412297]
avg: -0.04116
```
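The real cause: `x = tf.random_normal(..., seed=initial_seed)` is an op that produces a new sample every time `sess.run()` evaluates it. The seed does not freeze the value; it only makes the sequence x0, x1, x2, ... the same each time the script is restarted.
`y`, by contrast, is a variable: `y.initializer` evaluated `x` once and stored that sample, and every later `sess.run(y)` simply returns the stored value. So `y` holds the first sample of the sequence (drawn when `y.initializer` ran), while x0, x1 and x2 are later samples. `diff` and `avg` each trigger yet another draw of `x`, which is why the printed `avg` is not even the mean of the printed `diff`.
There is some explanation of random seeds here.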
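The difference between a seeded random op and a variable can be modelled without TensorFlow. The sketch below is a plain-Python analogy, not TensorFlow's implementation: `RandomNormalOp` and `Variable` are hypothetical stand-ins, and `random.Random.gauss` replaces TensorFlow's generator, so the numbers will not match the output above. It only illustrates the mechanism: the op advances its generator on every `run()`, the variable snapshots one draw at initialization, and a fresh op with the same seed replays the same sequence.

```python
import random


class RandomNormalOp:
    """Hypothetical stand-in for a seeded random-normal op (not TensorFlow code)."""

    def __init__(self, size, seed):
        self.size = size
        self._rng = random.Random(seed)  # generator state lives inside the op

    def run(self):
        # Each call advances the generator, so each call returns a new sample.
        return [self._rng.gauss(0.0, 1.0) for _ in range(self.size)]


class Variable:
    """Hypothetical stand-in for a variable: holds whatever its initializer stored."""

    def __init__(self, initial_value):
        self._initial_value = initial_value
        self._value = None

    def initializer(self):
        # Evaluates the initial-value op exactly once and keeps that sample.
        self._value = self._initial_value.run()

    def read(self):
        # Reading never touches the random op again.
        return self._value


x = RandomNormalOp(3, seed=1)
y = Variable(initial_value=x)
y.initializer()                         # consumes the 1st sample of x
x0, x1, x2 = x.run(), x.run(), x.run()  # 2nd, 3rd and 4th samples
print("y :", y.read())
print("x0:", x0)
print("x1:", x1)
print("x2:", x2)

# "Restarting the script": a new op with the same seed replays the sequence.
replay = RandomNormalOp(3, seed=1)
print("replayed 1st sample equals y:", replay.run() == y.read())  # True
```

This mirrors the output in the question: `y` is fixed at the first draw, `x` keeps moving, and only a restart (a fresh generator with the same seed) brings the same sequence back.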
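To get the same `x` every time it is run, it would have to be reinitialized each time, and I'm not sure there is a clean way to do that in my case. Instead, we can make `x` a variable too and initialize it from a fixed-seed initializer, so its random value is drawn once at initialization and then reused. Either `tf.get_variable` or `tf.Variable` works. I found this answer fits my question.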
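The fix can be modelled the same way, again as a plain-Python analogy with hypothetical names (`seeded_normal_initializer`, `Variable`, `init`) rather than TensorFlow's API. Once `x` is itself a variable, its random value is drawn a single time at initialization, `y` copies that value, and every later read of either returns the same numbers, so `diff` is all zeros.

```python
import random


def seeded_normal_initializer(size, seed):
    """Hypothetical stand-in for a seeded random-normal initializer."""
    rng = random.Random(seed)
    return lambda: [rng.gauss(0.0, 1.0) for _ in range(size)]


class Variable:
    """Hypothetical stand-in for a variable whose value is set once by init()."""

    def __init__(self, init_fn):
        self._init_fn = init_fn
        self._value = None

    def init(self):
        # Copy the result so this variable owns an independent snapshot.
        self._value = list(self._init_fn())

    def read(self):
        if self._value is None:
            raise RuntimeError("attempting to use an uninitialized variable")
        return self._value


x = Variable(seeded_normal_initializer(3, seed=1))  # random value drawn at init only
y = Variable(x.read)                                # y copies x's initialized value

x.init()  # x first, because y's init function reads x
y.init()

for step in range(3):
    diff = [a - b for a, b in zip(x.read(), y.read())]
    print(f"x{step}:", x.read(), f"y{step}:", y.read(), "diff:", diff)
```

In this sketch `x` has to be initialized before `y` reads it. In TensorFlow 1.x, `Variable.initialized_value()` is the documented way to make that ordering explicit when one variable's initial value depends on another.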
Here is my final code. It works.
```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API

# x is now a variable: its random value is drawn once, when it is initialized.
initializer = tf.random_normal_initializer(seed=1)
x = tf.get_variable(name='x', shape=[3], dtype=tf.float32, initializer=initializer)
y = tf.Variable(initial_value=x)  # y copies x's initial value
diff = tf.subtract(x, y)
avg = tf.reduce_mean(diff)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())  # initialize both x and y
print('x0:', sess.run(x))
print('y0:', sess.run(y))
print('x1:', sess.run(x))
print('y1:', sess.run(y))
print('x2:', sess.run(x))
print('y2:', sess.run(y))
print('diff:', sess.run(diff))
print('avg:', sess.run(avg))
sess.close()
```
Output:
```
x0: [-0.8113182 1.4845988 0.06532937]
y0: [-0.8113182 1.4845988 0.06532937]
x1: [-0.8113182 1.4845988 0.06532937]
y1: [-0.8113182 1.4845988 0.06532937]
x2: [-0.8113182 1.4845988 0.06532937]
y2: [-0.8113182 1.4845988 0.06532937]
diff: [0. 0. 0.]
avg: 0.0
```