For example,
# Batch = 2, a = 25, b = 2
# tensor t1 shape: (Batch, a, b)
# tensor t2 shape: (Batch, b)
# tensor res shape: (Batch, a)
print(t1)
<tf.Tensor: id=466, shape=(2, 25, 2), dtype=int32, numpy=
array([[[ 1, 26],
        [ 2, 27],
        [ 3, 28],
        [ 4, 29],
        [ 5, 30],
        [ 6, 31],
        [ 7, 32],
        [ 8, 33],
        [ 9, 34],
        [10, 35],
        [11, 36],
        [12, 37],
        [13, 38],
        [14, 39],
        [15, 40],
        [16, 41],
        [17, 42],
        [18, 43],
        [19, 44],
        [20, 45],
        [21, 46],
        [22, 47],
        [23, 48],
        [24, 49],
        [25, 50]],

       [[ 1, 26],
        [ 2, 27],
        [ 3, 28],
        [ 4, 29],
        [ 5, 30],
        [ 6, 31],
        [ 7, 32],
        [ 8, 33],
        [ 9, 34],
        [10, 35],
        [11, 36],
        [12, 37],
        [13, 38],
        [14, 39],
        [15, 40],
        [16, 41],
        [17, 42],
        [18, 43],
        [19, 44],
        [20, 45],
        [21, 46],
        [22, 47],
        [23, 48],
        [24, 49],
        [25, 50]]], dtype=int32)>
print(t2)
<tf.Tensor: id=410, shape=(2, 2), dtype=int32, numpy=
array([[1, 0],
       [1, 0]], dtype=int32)>
# after matrix multiplication
print(res)
<tf.Tensor: id=474, shape=(2, 25), dtype=int32, numpy=
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25]], dtype=int32)>
My idea is to use matrix multiplication to keep only part of t1, as in the example above, but I find it hard to implement.
If you don't mind, could anyone help me?
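To make the target concrete, here is a small sketch that rebuilds the printed t1 and t2 and checks what res should be. It uses NumPy as a stand-in for TensorFlow (an assumption, to keep the check dependency-light); the names column0, column1, expected and weighted are illustrative only. Because each row of t2 is the one-hot vector [1, 0], a weighted sum over the last axis of t1 keeps column 0 and drops column 1, which is exactly the "keep only part" behaviour described above.

```python
import numpy as np

# Rebuild the printed example: Batch = 2, a = 25, b = 2.
column0 = np.arange(1, 26)   # values 1..25
column1 = np.arange(26, 51)  # values 26..50
t1 = np.stack([np.stack([column0, column1], axis=1)] * 2)  # shape (2, 25, 2)

# Each row of t2 is one-hot, so it marks which of the b = 2 columns to keep.
t2 = np.array([[1, 0],
               [1, 0]])  # shape (2, 2)

# Both rows pick column 0, so the desired result is simply t1[:, :, 0].
expected = t1[:, :, 0]

# A weighted sum over the last axis with one-hot weights gives the same thing:
# the weight 1 keeps one column and the weight 0 drops the other.
weighted = (t1 * t2[:, None, :]).sum(axis=2)

print(expected.shape)  # (2, 25)
print(np.array_equal(weighted, expected))  # True
```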
Starting with
import tensorflow as tf # 2.6.0
Batch, a, b = 5, 25, 2
t1 = tf.random.normal((Batch, a, b))
t2 = tf.random.normal((Batch, b))
If I understand correctly, you want t3[b,i] = sum over j of t1[b,i,j] * t2[b,j].
This can be expressed in a straightforward manner using Einstein summation:
t3 = tf.einsum('bij,bj->bi', t1, t2)
assert t3.shape == (Batch, a)
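As a sanity check on what the subscript string means, the sketch below compares the einsum result with explicit loops over batch, i and j. It uses np.einsum, which accepts the same 'bij,bj->bi' subscripts as tf.einsum; swapping NumPy in for TensorFlow is an assumption made only to keep the check lightweight. The loop variable n stands for the batch index so it does not clash with the Python variable b.

```python
import numpy as np

rng = np.random.default_rng(0)
Batch, a, b = 5, 25, 2
t1 = rng.normal(size=(Batch, a, b))
t2 = rng.normal(size=(Batch, b))

# 'bij,bj->bi': j appears in both inputs but not in the output, so it is
# summed over; the batch index and i are kept.
t3 = np.einsum('bij,bj->bi', t1, t2)

# The same contraction written as explicit loops.
t3_loop = np.zeros((Batch, a))
for n in range(Batch):
    for i in range(a):
        for j in range(b):
            t3_loop[n, i] += t1[n, i, j] * t2[n, j]

print(t3.shape)  # (5, 25)
print(np.allclose(t3, t3_loop))  # True
```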
This could also be written using reduce_sum and broadcasting:
t3 = tf.reduce_sum(t1 * t2[:,None,:], axis=2)
assert t3.shape == (Batch, a)
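The reduce_sum version relies on broadcasting, and it helps to see the intermediate shapes. The sketch below traces them with NumPy (an assumed stand-in; TensorFlow's broadcasting follows the same rules for these shapes). Note that this form materializes the full (Batch, a, b) product before summing, which is harmless at these sizes.

```python
import numpy as np

rng = np.random.default_rng(1)
Batch, a, b = 5, 25, 2
t1 = rng.normal(size=(Batch, a, b))
t2 = rng.normal(size=(Batch, b))

# t2[:, None, :] inserts a length-1 axis: (Batch, b) -> (Batch, 1, b).
t2_expanded = t2[:, None, :]
print(t2_expanded.shape)  # (5, 1, 2)

# Broadcasting repeats t2 along the a axis, so the elementwise product
# has shape (Batch, a, b).
product = t1 * t2_expanded
print(product.shape)  # (5, 25, 2)

# Summing over the last axis contracts j, leaving (Batch, a).
t3 = product.sum(axis=2)
print(t3.shape)  # (5, 25)
print(np.allclose(t3, np.einsum('bij,bj->bi', t1, t2)))  # True
```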
Or, if you want to use the matrix multiplication operator, that works as well, but in that case you will have to unsqueeze t2 to (Batch, b, 1) (t2[:,:,None]) and then squeeze the result back from (Batch, a, 1) to (Batch, a) (t3[:,:,0]).
t3 = t1 @ t2[:,:,None]
assert t3[:,:,0].shape == (Batch, a)
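To tie the three forms back to the question, the sketch below applies them to the integer example from the question (Batch = 2, a = 25, b = 2) and prints the intermediate shape of the batched matmul. NumPy is used in place of TensorFlow here (an assumption); the @ operator on rank-3 arrays treats the leading axis as a batch axis, as TensorFlow's does. The names out, res_matmul, res_einsum and res_sum are illustrative.

```python
import numpy as np

# The question's integer example: Batch = 2, a = 25, b = 2.
column0 = np.arange(1, 26)
column1 = np.arange(26, 51)
t1 = np.stack([np.stack([column0, column1], axis=1)] * 2)  # (2, 25, 2)
t2 = np.array([[1, 0],
               [1, 0]])  # (2, 2)

# Batched matmul: (2, 25, 2) @ (2, 2, 1) -> (2, 25, 1)
out = t1 @ t2[:, :, None]
print(out.shape)  # (2, 25, 1)

# Squeeze the trailing length-1 axis to get (Batch, a).
res_matmul = out[:, :, 0]

# The einsum and broadcast-and-sum forms give the same result.
res_einsum = np.einsum('bij,bj->bi', t1, t2)
res_sum = (t1 * t2[:, None, :]).sum(axis=2)

print(res_matmul[0, :5])  # [1 2 3 4 5]
print(np.array_equal(res_matmul, res_einsum) and np.array_equal(res_matmul, res_sum))  # True
```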