There might be an obvious solution, but I haven't found it yet. I want to do a simple multiplication: one tensor gives me a kind of weight vector, and another one holds stacked tensors (as many as there are weights). It seems straightforward using tf.tensordot, but that doesn't work for unknown batch sizes.
import collections
import tensorflow as tf
tf.reset_default_graph()
x = tf.placeholder(shape=(None, 4, 1), dtype=tf.float32, name='x')
y_true = tf.placeholder(shape=(None, 4, 1), dtype=tf.float32, name='y_true')
# These are the models that I want to combine
linear_model0 = tf.layers.Dense(units=1, name='linear_model0')
linear_model1 = tf.layers.Dense(units=1, name='linear_model1')
agents = collections.OrderedDict()
agents[0] = linear_model0(x) # shape (?,4,1)
agents[1] = linear_model1(x) # shape (?,4,1)
stacked = tf.stack(list(agents.values()), axis=1) # shape (?,2,4,1)
# This is the model that produces the weights
x_flat = tf.layers.Flatten()(x)
weight_model = tf.layers.Dense(units=2, name='weight_model')
weights = weight_model(x_flat) # shape: (?,2)
# This is the final output
y_pred = tf.tensordot(weights, stacked, axes = 2, name='y_pred')
# PROBLEM HERE: shape: (4,1) instead of (?,4,1)
# Running the whole thing
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
# Example1 (output of shape (1,4,1) expected, got (4,1))
print('model', sess.run(y_pred,
                        {x: [[[1], [2], [3], [4]]]}).shape)
# Example2 (output of shape (2,4,1) expected, got (4,1))
print('model', sess.run(y_pred,
                        {x: [[[1], [2], [3], [4]], [[1], [2], [3], [4]]]}).shape)
So the multiplication gives the expected values for a single input, but it does not handle a batch of inputs: the batch dimension disappears from the output. Any help?
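To see why the batch dimension disappears, here is a small sketch that uses NumPy as a stand-in for TensorFlow (an assumption for illustration; `np.tensordot` uses the same convention as `tf.tensordot` for a scalar `axes`: sum over the last N axes of the first argument and the first N axes of the second). With `axes=2`, the axes of `weights` (batch, 2) are contracted with the first two axes of `stacked` (batch, 2), so the batch axis is summed over together with the weight axis. The random arrays only mimic the shapes in the question.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = 3
weights = rng.normal(size=(batch, 2))        # like weight_model output, shape (batch, 2)
stacked = rng.normal(size=(batch, 2, 4, 1))  # like the tf.stack(...) output, shape (batch, 2, 4, 1)

# axes=2 contracts weights' axes (0, 1) with stacked's axes (0, 1),
# i.e. it sums over the batch axis as well as over the weight axis.
out = np.tensordot(weights, stacked, axes=2)
print(out.shape)  # (4, 1)

# The same thing written as an explicit double sum over batch b and weight index w.
manual = sum(weights[b, w] * stacked[b, w] for b in range(batch) for w in range(2))
print(np.allclose(out, manual))  # True
```

With a batch of one, that double sum happens to equal the desired per-sample result, which is why the first example looks correct apart from its shape.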
Similar questions that didn't resolve my issue:
tf.tensordot is not suitable in this case. Based on your explanation, you would need to set axes=1, which causes incompatible matrix sizes: one is [batch_size, 2] and the other is [batch_size, 8]. On the other hand, if you set axes=[[1],[1]], the result is not what you expect:
tf.tensordot(weights, stacked, axes=[[1],[1]]) # shape = (?,?,4,1)
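Both failure modes can be reproduced with a minimal NumPy sketch (again assuming `np.tensordot` as a stand-in for `tf.tensordot`, with random arrays shaped like those in the question). A batch size of 3 is chosen so that it differs from the number of weights (2); with a batch size of 2, `axes=1` would accidentally line up and silently compute the wrong thing. The diagonal extraction at the end only illustrates which entries of the `axes=[[1],[1]]` result are wanted; it does batch_size² work and is not a recommended fix.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = 3
weights = rng.normal(size=(batch, 2))        # (batch, 2)
stacked = rng.normal(size=(batch, 2, 4, 1))  # (batch, 2, 4, 1)

# axes=1 pairs the last axis of weights (size 2) with the first axis of
# stacked (size batch); the sizes differ, so the contraction fails.
try:
    np.tensordot(weights, stacked, axes=1)
    axes1_failed = False
except ValueError:
    axes1_failed = True
print(axes1_failed)  # True

# axes=[[1], [1]] contracts only the weight axis, but it pairs every sample of
# weights with every sample of stacked, giving shape (batch, batch, 4, 1).
cross = np.tensordot(weights, stacked, axes=[[1], [1]])
print(cross.shape)  # (3, 3, 4, 1)

# Only the "diagonal" entries (same sample on both sides) are the wanted result.
idx = np.arange(batch)
per_sample = cross[idx, idx]
print(per_sample.shape)  # (3, 4, 1)
```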
How to fix the issue?
Use tf.einsum, which expresses a contraction between tensors of arbitrary dimension:
tf.einsum('ij,ijkl->ikl', weights, stacked)
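In the subscript string 'ij,ijkl->ikl', i is the batch axis (kept because it appears in the output), j indexes the weights/models (summed because it is missing from the output), and k, l are the (4, 1) output dimensions. The sketch below checks this with NumPy's `np.einsum`, assumed here as a stand-in for `tf.einsum` since both accept the same subscript notation; it compares the result against an explicit per-sample weighted sum and against an equivalent broadcasting formulation.

```python
import numpy as np

rng = np.random.default_rng(1)
batch = 2
weights = rng.normal(size=(batch, 2))        # (batch, n_models)
stacked = rng.normal(size=(batch, 2, 4, 1))  # (batch, n_models, 4, 1)

# i = batch (kept), j = model index (summed), k, l = output dims (kept).
y_pred = np.einsum('ij,ijkl->ikl', weights, stacked)
print(y_pred.shape)  # (2, 4, 1)

# Reference: for each sample i, the weighted sum of the stacked model outputs.
expected = np.stack([
    sum(weights[i, j] * stacked[i, j] for j in range(2)) for i in range(batch)
])
print(np.allclose(y_pred, expected))  # True

# Same result via broadcasting: expand weights to (batch, 2, 1, 1),
# multiply elementwise, then sum over the model axis.
y_bcast = (weights[:, :, None, None] * stacked).sum(axis=1)
print(np.allclose(y_pred, y_bcast))  # True
```

In the question's TensorFlow graph the einsum string carries over unchanged; the broadcasting variant would correspond to tf.reduce_sum(weights[:, :, None, None] * stacked, axis=1). Both keep the batch axis i in the output instead of summing over it.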