
How to do TensorFlow segment_max in high dimensions


I want to be able to call TensorFlow's tf.math.unsorted_segment_max on a data tensor of size [N, s, K], where N is the number of channels, K is the number of filters/feature maps, and s is the size of a one-channel data sample. I have a segment_ids tensor of size s. For example, say my sample size is s=6 and I want to take a max over every two elements (as in the usual max pooling, so along the second, s-sized dimension of the whole data tensor). Then my segment_ids equals [0,0,1,1,2,2].

I tried running

tf.math.unsorted_segment_max(data, segment_ids, num_segments)

with segment_ids tiled along dimensions 0 and 2, but since the segment ids then repeat across channels and filters, the result is of course of size [3] instead of [N, 3, K] as I would like.

So my question is: how do I construct a proper segment_ids tensor that achieves what I want? That is, one that takes the segment max based on the original s-sized segment_ids tensor, but in each channel and filter dimension separately?

Basically, going back to the example, given the 1D segment id list seg_id=[0,0,1,1,2,2], I would like to construct something like a segment_ids tensor for which:

segment_ids[i,:,j] = seg_id + num_segments*(i*K + j) 

So that when calling tf.math.(unsorted_)segment_max with this tensor as segment ids, I get a result of size [N, 3, K], with the same effect as running the segment max on each data[x,:,y] separately and stacking the results appropriately.

Any way of doing this is okay, as long as it works with TensorFlow. I would guess some combination of tf.tile, tf.reshape, or tf.concat should do the trick, but I can't figure out how or in what order. Also, is there a more straightforward way to do it, one that avoids adjusting segment_ids at each "pooling" step?
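For concreteness, here is a rough NumPy sketch of the tensor I have in mind (toy sizes N=2, K=3 assumed); the offsets just give every (channel, filter) pair its own disjoint block of segment ids:

import numpy as np

# Toy sizes for illustration: N channels, sample size s, K filters.
N, s, K = 2, 6, 3
seg_id = np.array([0, 0, 1, 1, 2, 2])
num_segments = seg_id.max() + 1

# Broadcast seg_id against a per-(i, j) offset, per the formula above.
i = np.arange(N)[:, None, None]   # shape [N, 1, 1]
j = np.arange(K)[None, None, :]   # shape [1, 1, K]
segment_ids = seg_id[None, :, None] + num_segments * (i * K + j)
print(segment_ids.shape)          # (2, 6, 3)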


Solution

  • I think you can achieve what you want with tf.nn.pool, since your segment ids [0, 0, 1, 1, 2, 2] amount to max pooling with a window of 2 and a stride of 2 along the second dimension:

    import tensorflow as tf
    
    with tf.Graph().as_default(), tf.Session() as sess:
        data = tf.constant([
            [
                [ 1, 12, 13],
                [ 2, 11, 14],
                [ 3, 10, 15],
                [ 4,  9, 16],
                [ 5,  8, 17],
                [ 6,  7, 18],
            ],
            [
                [19, 30, 31],
                [20, 29, 32],
                [21, 28, 33],
                [22, 27, 34],
                [23, 26, 35],
                [24, 25, 36],
            ]], dtype=tf.int32)
        segments = tf.constant([0, 0, 1, 1, 2, 2], dtype=tf.int32)  # defined for reference; tf.nn.pool does not use it
        pool = tf.nn.pool(data, [2], 'MAX', 'VALID', strides=[2])
        print(sess.run(pool))
    

    Output:

    [[[ 2 12 14]
      [ 4 10 16]
      [ 6  8 18]]
    
     [[20 30 32]
      [22 28 34]
      [24 26 36]]]
    

    If you really want to use tf.unsorted_segment_max, you can do it as you suggest in your own answer. Here is an equivalent formulation that avoids transposing and includes the final reshaping:

    import tensorflow as tf
    
    with tf.Graph().as_default(), tf.Session() as sess:
        data = ...      # same data as in the snippet above
        segments = ...  # same segments as above
        shape = tf.shape(data)
        n, k = shape[0], shape[2]
        m = tf.reduce_max(segments) + 1
        grid = tf.meshgrid(tf.range(n) * m * k,
                           segments * k,
                           tf.range(k), indexing='ij')
        segment_nd = tf.add_n(grid)
        segmented = tf.unsorted_segment_max(data, segment_nd, n * m * k)
        result = tf.reshape(segmented, [n, m, k])
        print(sess.run(result))
        # Same output
    
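    To see why the final reshape works, note that the three meshgrid addends assign element data[i, t, j] the flat id i*m*k + segments[t]*k + j, so the ids are unique per (channel, segment, filter) triple and laid out in exactly [n, m, k] order. A quick NumPy check of the same construction (toy sizes assumed):

    import numpy as np

    n, s, k = 2, 6, 3
    segments = np.array([0, 0, 1, 1, 2, 2])
    m = segments.max() + 1

    # Same three addends as the tf.meshgrid call, via broadcasting.
    ids = (np.arange(n)[:, None, None] * m * k
           + segments[None, :, None] * k
           + np.arange(k)[None, None, :])

    assert ids.shape == (n, s, k)
    # n*m*k distinct ids in total, each repeated once per window element,
    # so the [n*m*k] segment_max output reshapes directly to [n, m, k].
    assert len(np.unique(ids)) == n * m * k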

    Both methods should work fine in a neural network in terms of back-propagation.
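    If you want to convince yourself of that, here is a minimal gradient-flow sanity check (assuming the same TF 1.x graph-mode setup as above; tf.unsorted_segment_max has a registered gradient as well):

    import tensorflow as tf

    with tf.Graph().as_default(), tf.Session() as sess:
        data = tf.random_normal([2, 6, 3])
        pooled = tf.nn.pool(data, [2], 'MAX', 'VALID', strides=[2])
        # Gradients propagate back to the input, as with regular max pooling.
        grad, = tf.gradients(tf.reduce_sum(pooled), data)
        print(sess.run(grad).shape)  # (2, 6, 3)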

    EDIT: In terms of performance, pooling seems to be more scalable than the segment max (as one would expect):

    import tensorflow as tf
    import numpy as np
    
    def method_pool(data, window):
        return tf.nn.pool(data, [window], 'MAX', 'VALID', strides=[window])
    
    def method_segment(data, window):
        shape = tf.shape(data)
        n, s, k = shape[0], shape[1], shape[2]
        segments = tf.range(s) // window
        m = tf.reduce_max(segments) + 1
        grid = tf.meshgrid(tf.range(n) * m * k,
                           segments * k,
                           tf.range(k), indexing='ij')
        segment_nd = tf.add_n(grid)
        segmented = tf.unsorted_segment_max(data, segment_nd, n * m * k)
        return tf.reshape(segmented, [n, m, k])
    
    np.random.seed(100)
    rand_data = np.random.rand(300, 500, 100)
    window = 10
    with tf.Graph().as_default(), tf.Session() as sess:
        data = tf.constant(rand_data, dtype=tf.float32)
        res_pool = method_pool(data, window)
        res_segment = method_segment(data, window)
        print(np.allclose(*sess.run([res_pool, res_segment])))
        # True
        %timeit sess.run(res_pool)
        # 2.56 ms ± 80.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
        %timeit sess.run(res_segment)
        # 514 ms ± 6.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)