Tags: python, tensorflow, matrix-multiplication

Element-wise multiplication of matrices in TensorFlow: how to avoid a for loop


I want to do the following multiplication in TensorFlow (TF 2.10), but I'm not sure how.

I have an image tensor a of shape 224x224x3 and a tensor b of shape 224x224xf. I want to multiply a (element-wise, with broadcasting) by each 224x224 slice of b along its last axis, sum over the 3 channels of a, and collect the f results in a tensor c of shape 224x224xf.

For example, the first multiplication would be:

tf.reduce_sum(a * b[:, :, 0][:, :, None], axis=-1)

(broadcasting, then a sum over the channel axis; the result has shape 224x224)

and so on up to the f-th multiplication. The result should stack these f matrices of shape 224x224 into a tensor c of shape 224x224xf.
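For reference, here is a minimal sketch of the loop version described above, the one to be replaced. It assumes small stand-in shapes (8x8 instead of 224x224, f = 5) and random inputs, and it annotates the shape at each step of the first multiplication:

```python
import tensorflow as tf

# Small stand-in shapes (assumption): 8x8 instead of 224x224, f = 5.
H, W, F = 8, 8, 5
a = tf.random.uniform((H, W, 3))   # "image" tensor
b = tf.random.uniform((H, W, F))   # per-pixel weights, one slice per k

# Shapes in the first multiplication:
#   b[:, :, 0][:, :, None]     -> (H, W, 1)
#   a * b[:, :, 0][:, :, None] -> (H, W, 3)  (broadcast over the channel axis)
#   tf.reduce_sum(..., -1)     -> (H, W)
first = tf.reduce_sum(a * b[:, :, 0][:, :, None], axis=-1)

# Python loop over the f slices of b, stacked along a new last axis,
# giving c with shape (H, W, F).
c_loop = tf.stack(
    [tf.reduce_sum(a * b[:, :, k][:, :, None], axis=-1) for k in range(F)],
    axis=-1,
)
print(first.shape.as_list(), c_loop.shape.as_list())  # [8, 8] [8, 8, 5]
```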

I would greatly appreciate help on how to do this using tensorflow functionality.

EDIT: I realize that what I want to do is equivalent to a Conv2D operation with kernel_size=1 and filters=f. Maybe it can help.
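A note on that equivalence, as a sketch under stated assumptions: a 1x1 `tf.nn.conv2d` applies one (3, f) weight matrix at every pixel, while here b can change from pixel to pixel, so the two coincide only when b is the same at every spatial position. The example below assumes such a spatially constant b (built from a hypothetical vector `v`) and small 8x8 inputs. In that case the kernel entry for input channel c and output channel k is simply `v[k]`:

```python
import tensorflow as tf

H, W, F = 8, 8, 5
a = tf.random.uniform((H, W, 3))

# Assumption for this comparison: b is constant across the spatial axes,
# i.e. b[h, w, k] == v[k] for every pixel (h, w). `v` is hypothetical.
v = tf.random.uniform((F,))
b = tf.broadcast_to(v, (H, W, F))

# The operation from the question, written without a loop.
c = tf.reduce_sum(a, axis=-1, keepdims=True) * b                     # (H, W, F)

# A 1x1 convolution with kernel shape (1, 1, in_channels=3, out_channels=F).
# Every input channel gets the same weight v[k] for output channel k.
kernel = tf.broadcast_to(v, (3, F))[None, None, :, :]                # (1, 1, 3, F)
conv = tf.nn.conv2d(a[None], kernel, strides=1, padding="VALID")[0]  # (H, W, F)

# Maximum difference: expected to be at the level of float32 rounding.
print(float(tf.reduce_max(tf.abs(c - conv))))
```

With a b that varies across pixels, no single 1x1 kernel reproduces c, which is why the broadcasting approaches in the answer are the more direct fit.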


Solution

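  • Since each result multiplies every channel of a by the same slice of b, you can multiply each channel of a by the whole of b and then add the three products: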

    # a[:, :, i:i+1] has shape (224, 224, 1) and broadcasts against b (224, 224, f)
    X = a[:, :, 0:1] * b + a[:, :, 1:2] * b + a[:, :, 2:3] * b
    

    X has shape (224, 224, f), and each slice X[:, :, k] matches your k-th multiplication. For k = 0:

    (X[:, :, 0] == tf.reduce_sum(a * b[:, :, 0][:, :, None], axis=-1)).numpy().all()
    

    Output:

    True
    

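    The following is mathematically equivalent (it sums the channels of a first and then multiplies by b), but it gives slightly different results, presumably because floating-point rounding depends on the order of operations: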

    tf.reduce_sum(a, axis=-1, keepdims=True) * b
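The hardcoded sum in X assumes exactly 3 channels. Below is a minimal sketch, assuming small random stand-in inputs, of a channel-count-agnostic version of X next to the keepdims form, both checked against the loop from the question with a tolerance instead of `==`, since they can differ in the last bits:

```python
import tensorflow as tf

H, W, C, F = 8, 8, 3, 5   # small stand-in shapes (assumption)
a = tf.random.uniform((H, W, C))
b = tf.random.uniform((H, W, F))

# X for any number of channels C:
# (H, W, C, 1) * (H, W, 1, F) -> (H, W, C, F), then sum over the channel axis.
X = tf.reduce_sum(a[:, :, :, None] * b[:, :, None, :], axis=2)

# Sum the channels of a first, then broadcast against b.
# Avoids the (H, W, C, F) intermediate but rounds differently.
Y = tf.reduce_sum(a, axis=-1, keepdims=True) * b

# Loop reference from the question.
ref = tf.stack(
    [tf.reduce_sum(a * b[:, :, k][:, :, None], axis=-1) for k in range(F)],
    axis=-1,
)

# Compare with a tolerance rather than exact equality
# (assert_near raises an error if the check fails).
tf.debugging.assert_near(X, ref, rtol=1e-5, atol=1e-5)
tf.debugging.assert_near(Y, ref, rtol=1e-5, atol=1e-5)
print(X.shape.as_list(), Y.shape.as_list())  # [8, 8, 5] [8, 8, 5]
```

At 224x224 with a large f, the (H, W, C, F) intermediate in the generalized X needs C times the memory of the output, so the keepdims form is the lighter of the two.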