I want to be able to call TensorFlow's tf.math.unsorted_segment_max on a data tensor of size [N, s, K]. N is the number of channels and K is the number of filters/feature maps; s is the size of a one-channel data sample. I have segment_ids of size s. For example, say my sample size is s = 6 and I want to take a max over every two elements (as in ordinary max pooling, along the second, s-sized dimension of the whole data tensor). Then my segment_ids equal [0, 0, 1, 1, 2, 2].
I tried running
tf.math.unsorted_segment_max(data, segment_ids, num_segments)
with the segment_ids tiled along dimensions 0 and 2 to match the data, but since the segment ids are then repeated across those dimensions, the result is of course of size [3] instead of [N, 3, K] as I would like.
So my question is: how can I construct a proper segment_ids tensor to achieve what I want? That is, to have the segment max applied based on the original s-sized segment_ids, but within each of the other two dimensions separately?
Basically, going back to the example, given the 1D segment id list seg_id=[0,0,1,1,2,2], I would like to construct something like a segment_ids tensor for which:
segment_ids[i,:,j] = seg_id + num_segments*(i*K + j)
so that calling tf.math.(unsorted_)segment_max with this tensor as the segment ids yields a result of size [N, 3, K], with the same effect as running the segment max on each slice data[x, :, y] separately and stacking the results appropriately.
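To make the target concrete, here is that tensor spelled out with NumPy broadcasting for the example above (illustration only; N, K, and seg_id are the example values from this question, not part of any solution):

import numpy as np
N, K = 2, 3                                    # example sizes
seg_id = np.array([0, 0, 1, 1, 2, 2])          # shape [s]
num_segments = seg_id.max() + 1                # 3
# offsets[i, j] = num_segments * (i*K + j)
offsets = num_segments * (np.arange(N)[:, None] * K + np.arange(K))
segment_ids = seg_id[None, :, None] + offsets[:, None, :]  # shape [N, s, K]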
Any way of doing this is okay, as long as it works with TensorFlow. I would guess a combination of tf.tile, tf.reshape, or tf.concat should do the trick, but I can't figure out how or in what order. Also, is there a more straightforward way to do it, without having to adjust the segment_ids at each "pooling" step?
I think you can achieve what you want with tf.nn.pool:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
    data = tf.constant([
        [
            [ 1, 12, 13],
            [ 2, 11, 14],
            [ 3, 10, 15],
            [ 4,  9, 16],
            [ 5,  8, 17],
            [ 6,  7, 18],
        ],
        [
            [19, 30, 31],
            [20, 29, 32],
            [21, 28, 33],
            [22, 27, 34],
            [23, 26, 35],
            [24, 25, 36],
        ]], dtype=tf.int32)
    segments = tf.constant([0, 0, 1, 1, 2, 2], dtype=tf.int32)
    pool = tf.nn.pool(data, [2], 'MAX', 'VALID', strides=[2])
    print(sess.run(pool))
Output:
[[[ 2 12 14]
  [ 4 10 16]
  [ 6  8 18]]

 [[20 30 32]
  [22 28 34]
  [24 26 36]]]
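Note that the window size and stride of 2 encode the same grouping as the segment ids [0, 0, 1, 1, 2, 2]; to group w consecutive elements you would pass [w] and strides=[w] instead.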
If you really want to use tf.unsorted_segment_max, you can do it as you suggest in your own answer. Here is an equivalent formulation that avoids transposing and includes the final reshaping:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
    data = ...      # same data tensor as above
    segments = ...  # same segment ids as above
    shape = tf.shape(data)
    n, k = shape[0], shape[2]
    m = tf.reduce_max(segments) + 1
    # Give every (batch, segment, feature) combination a unique segment id
    grid = tf.meshgrid(tf.range(n) * m * k,
                       segments * k,
                       tf.range(k), indexing='ij')
    segment_nd = tf.add_n(grid)
    segmented = tf.unsorted_segment_max(data, segment_nd, n * m * k)
    result = tf.reshape(segmented, [n, m, k])
    print(sess.run(result))
    # Same output
Both methods should work fine in a neural network in terms of back-propagation.
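As a quick sanity check (a sketch, assuming the same TF1 graph-mode setup as above), you can verify that gradients flow through the pooling path:

import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
    data = tf.random_uniform([2, 6, 3])
    pooled = tf.nn.pool(data, [2], 'MAX', 'VALID', strides=[2])
    # The gradient has the same shape as data; within each window only
    # the max element receives a nonzero gradient.
    grad, = tf.gradients(tf.reduce_sum(pooled), data)
    print(sess.run(grad).shape)  # (2, 6, 3)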
EDIT: In terms of performance, pooling seems to be more scalable than the segmented max (as one would expect):
import tensorflow as tf
import numpy as np

def method_pool(data, window):
    return tf.nn.pool(data, [window], 'MAX', 'VALID', strides=[window])

def method_segment(data, window):
    shape = tf.shape(data)
    n, s, k = shape[0], shape[1], shape[2]
    segments = tf.range(s) // window
    m = tf.reduce_max(segments) + 1
    grid = tf.meshgrid(tf.range(n) * m * k,
                       segments * k,
                       tf.range(k), indexing='ij')
    segment_nd = tf.add_n(grid)
    segmented = tf.unsorted_segment_max(data, segment_nd, n * m * k)
    return tf.reshape(segmented, [n, m, k])

np.random.seed(100)
rand_data = np.random.rand(300, 500, 100)
window = 10
with tf.Graph().as_default(), tf.Session() as sess:
    data = tf.constant(rand_data, dtype=tf.float32)
    res_pool = method_pool(data, window)
    res_segment = method_segment(data, window)
    print(np.allclose(*sess.run([res_pool, res_segment])))
    # True
    %timeit sess.run(res_pool)
    # 2.56 ms ± 80.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    %timeit sess.run(res_segment)
    # 514 ms ± 6.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
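The gap is presumably because tf.nn.pool dispatches to a specialized kernel that slides over contiguous windows, whereas the segment-based method first has to materialize the full [n, s, k] tensor of segment ids and then scatter every element into the output.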