tensorflow tensor array-broadcasting

TF broadcast along first axis


Say I have two tensors, one with shape (10, 1) and another with shape (10, 12, 1). I want to multiply them, broadcasting along the first axis rather than the last one, as in

tf.zeros([10,1]) * tf.ones([10,12,1])

However, this raises an incompatible-shapes error. Is there a way to do it without transposing the tensors using perm?
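To see why the product fails, here is a minimal pure-Python sketch of the NumPy-style broadcasting rules, which TensorFlow's elementwise ops also follow: shapes are compared from the trailing axis, and each pair of sizes must be equal or one of them must be 1. The function name broadcast_shape is hypothetical and only computes shapes; it is not a TensorFlow API.

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast shape of two shapes under NumPy-style rules.

    Raises ValueError if the shapes are incompatible.
    """
    ndim = max(len(shape_a), len(shape_b))
    result = []
    # Compare sizes pairwise, starting from the trailing (rightmost) axis.
    for i in range(1, ndim + 1):
        # A shape with fewer axes is treated as if padded with 1s on the left.
        da = shape_a[-i] if i <= len(shape_a) else 1
        db = shape_b[-i] if i <= len(shape_b) else 1
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"incompatible sizes {da} and {db} at axis {-i}")
    return tuple(reversed(result))


# The question's shapes: axis -2 pairs 10 with 12, which is not allowed.
try:
    broadcast_shape((10, 1), (10, 12, 1))
except ValueError as err:
    print(err)  # incompatible sizes 10 and 12 at axis -2

# With a trailing axis added to the first tensor, the shapes are compatible.
print(broadcast_shape((10, 1, 1), (10, 12, 1)))  # (10, 12, 1)
```

Because missing leading axes are padded with 1s, a (10, 1) tensor is aligned as (1, 10, 1), which is why its 10 ends up facing the 12 instead of the leading 10.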


Solution

  • For the example above, add a trailing axis to the first tensor, tf.zeros([10,1])[..., None] * tf.ones([10,12,1]), so that its shape becomes (10, 1, 1) and the two shapes satisfy the broadcasting rules: https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules

    If you want this to work for arbitrary shapes, you can multiply by the transposed tensor instead. tf.transpose without perm reverses the order of the axes, so the first axis of b becomes its last axis and lines up with the last axis of a, as the broadcasting rules require. Transposing the product again restores the original axis order:

    tf.transpose(a*tf.transpose(b))
    

    Example:

    import tensorflow as tf

    a = tf.ones([10])
    b = tf.ones([10, 11, 12, 13, 1])

    tf.transpose(b).shape                    # [1, 13, 12, 11, 10]
    (a * tf.transpose(b)).shape              # [1, 13, 12, 11, 10]
    tf.transpose(a * tf.transpose(b)).shape  # [10, 11, 12, 13, 1]
    # Note that a has shape [10], not [10, 1]; a [10, 1] tensor would
    # also need to be transposed.
    

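    Another approach is to expand the dimensions of a, appending size-1 axes until it has the same rank as b:

    a = tf.ones([10])[(...,) + (b.shape.rank - 1) * (tf.newaxis,)]  # shape [10, 1, 1, 1, 1]
    a * b  # shape [10, 11, 12, 13, 1]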
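The same idea can be wrapped in a small helper that works for any pair of tensors whose leading axes match, without tf.transpose or perm. This is a sketch that assumes both ranks are known statically; the name multiply_along_leading_axes is hypothetical, and it uses tf.reshape in place of the tf.newaxis indexing above.

```python
import tensorflow as tf


def multiply_along_leading_axes(a, b):
    """Hypothetical helper: multiply a by b, aligning a's axes with b's
    leading axes and broadcasting over b's remaining trailing axes."""
    # Number of size-1 axes to append to a so that both ranks match.
    extra = b.shape.rank - a.shape.rank
    if extra < 0:
        raise ValueError("a must not have a higher rank than b")
    # New shape for a: its own (dynamic) shape followed by `extra` ones.
    new_shape = tf.concat([tf.shape(a), tf.ones([extra], dtype=tf.int32)], axis=0)
    return tf.reshape(a, new_shape) * b


a = tf.ones([10, 1])
b = tf.ones([10, 12, 1])
print(multiply_along_leading_axes(a, b).shape.as_list())  # [10, 12, 1]
```

Because the helper only appends size-1 axes to a, b is never transposed, and the same call works whether a has shape [10] or [10, 1].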