
How to implement tf.gather_nd in Pytorch with the argument batch_dims?

I have been working on an image matching project, so I need to find correspondences between 2 images. To get descriptors, I need an interpolate function. However, when I read an equivalent function written in TensorFlow, I still don't get how to implement tf.gather_nd(params, indices, batch_dims) in PyTorch, especially the batch_dims argument. I have gone through Stack Overflow and there is no exact equivalent yet.

The referenced interpolate function in TensorFlow is below; I have been trying to implement it in PyTorch. Information about its arguments:

inputs is a dense feature map, i.e. feature_map[i] taken inside a for loop over the batch, so it is 3D: [H, W, C] (in PyTorch it would be [C, H, W]).

pos is a set of random point coordinates shaped like [[i, j], [i, j], ..., [i, j]], so it is 2D when it goes into the interpolate function (in PyTorch it would be [[i, i, ..., i], [j, j, ..., j]]).

Both of them then get a leading dimension added (tf.expand_dims) when they enter the function.
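The layouts described above can be converted between TensorFlow and PyTorch conventions with permute and t(). A minimal sketch, assuming a single [H, W, C] feature map and a [num_points, 2] coordinate tensor (the sizes and variable names are illustrative only):

```python
import torch

# TensorFlow-style layouts (example sizes): feature map [H, W, C], points [num_points, 2]
feature_map_hwc = torch.randn(4, 4, 128)
pos_tf_style = torch.randint(4, (12, 2))  # each row is [i, j]

# PyTorch-style layouts: feature map [C, H, W], points [2, num_points]
feature_map_chw = feature_map_hwc.permute(2, 0, 1)
pos_torch_style = pos_tf_style.t()  # row 0 holds all i, row 1 holds all j

# Adding a leading batch dimension, as tf.expand_dims(x, 0) does in interpolate
feature_map_batched = feature_map_chw.unsqueeze(0)  # [1, C, H, W]
pos_batched = pos_tf_style.unsqueeze(0)              # [1, num_points, 2]
```

permute only changes the view of the data, so feature_map_chw[:, h, w] is the same vector as feature_map_hwc[h, w].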

I just want an exact implementation of tf.gather_nd with the batch_dims argument. Thank you! Here's a simple example of using it:

import tensorflow as tf

pos = tf.ones((12, 2)) ## stands for a set of coordinates [[i, j], [i, j], ..., [i, j]]
inputs = tf.ones((4, 4, 128)) ## stands for [H, W, C] of dense feature map
outputs = interpolate(pos, inputs, nd=True)
print(outputs.get_shape()) # We get (12, 128) here

interpolate function (tf version):

def interpolate(pos, inputs, nd=True):

    pos = tf.expand_dims(pos, 0)
    inputs = tf.expand_dims(inputs, 0)

    h = tf.shape(inputs)[1]
    w = tf.shape(inputs)[2]

    i = pos[:, :, 0]
    j = pos[:, :, 1]

    i_top_left = tf.clip_by_value(tf.cast(tf.math.floor(i), tf.int32), 0, h - 1)
    j_top_left = tf.clip_by_value(tf.cast(tf.math.floor(j), tf.int32), 0, w - 1)

    i_top_right = tf.clip_by_value(tf.cast(tf.math.floor(i), tf.int32), 0, h - 1)
    j_top_right = tf.clip_by_value(tf.cast(tf.math.ceil(j), tf.int32), 0, w - 1)

    i_bottom_left = tf.clip_by_value(tf.cast(tf.math.ceil(i), tf.int32), 0, h - 1)
    j_bottom_left = tf.clip_by_value(tf.cast(tf.math.floor(j), tf.int32), 0, w - 1)

    i_bottom_right = tf.clip_by_value(tf.cast(tf.math.ceil(i), tf.int32), 0, h - 1)
    j_bottom_right = tf.clip_by_value(tf.cast(tf.math.ceil(j), tf.int32), 0, w - 1)

    dist_i_top_left = i - tf.cast(i_top_left, tf.float32)
    dist_j_top_left = j - tf.cast(j_top_left, tf.float32)
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left

    if nd:
        w_top_left = w_top_left[..., None]
        w_top_right = w_top_right[..., None]
        w_bottom_left = w_bottom_left[..., None]
        w_bottom_right = w_bottom_right[..., None]

    interpolated_val = (
        w_top_left * tf.gather_nd(inputs, tf.stack([i_top_left, j_top_left], axis=-1), batch_dims=1) +
        w_top_right * tf.gather_nd(inputs, tf.stack([i_top_right, j_top_right], axis=-1), batch_dims=1) +
        w_bottom_left * tf.gather_nd(inputs, tf.stack([i_bottom_left, j_bottom_left], axis=-1), batch_dims=1) +
        w_bottom_right * tf.gather_nd(inputs, tf.stack([i_bottom_right, j_bottom_right], axis=-1), batch_dims=1)
    )

    interpolated_val = tf.squeeze(interpolated_val, axis=0)
    return interpolated_val


  • As far as I'm aware there is no direct equivalent of tf.gather_nd in PyTorch, and implementing a generic version with batch_dims is not that simple. However, you likely don't need a generic version: given the context of your interpolate function, a version for [C, H, W] would suffice.
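For reference, a generic version can still be sketched by merging all batch dimensions into one and applying the batch-enumeration indexing trick. This is a minimal sketch of the shape rules of tf.gather_nd, not an official PyTorch API; the name gather_nd is hypothetical, and it does not reproduce TensorFlow's error handling for out-of-range indices:

```python
import math

import torch


def gather_nd(params: torch.Tensor, indices: torch.Tensor, batch_dims: int = 0) -> torch.Tensor:
    """Hypothetical helper mimicking tf.gather_nd(params, indices, batch_dims).

    Assumes params has shape batch_shape + inner_shape and indices has shape
    batch_shape + index_shape + [k] with k <= len(inner_shape).
    Returns a tensor of shape batch_shape + index_shape + inner_shape[k:].
    No bounds checking beyond what PyTorch indexing itself does.
    """
    indices = indices.long()
    batch_shape = tuple(params.shape[:batch_dims])
    batch_size = math.prod(batch_shape)  # 1 when batch_dims == 0
    index_depth = indices.shape[-1]
    index_shape = tuple(indices.shape[batch_dims:-1])
    # Merge all batch dimensions into one leading dimension of size batch_size
    flat_params = params.reshape((batch_size,) + tuple(params.shape[batch_dims:]))
    flat_indices = indices.reshape(batch_size, -1, index_depth)  # [B, M, k]
    # Batch index for every gathered element; [B, 1] broadcasts against [B, M]
    batch_idx = torch.arange(batch_size, device=params.device).unsqueeze(1)
    # One index tensor of shape [B, M] per indexed dimension
    per_dim = tuple(flat_indices[..., d] for d in range(index_depth))
    gathered = flat_params[(batch_idx,) + per_dim]  # [B, M, *remaining dims]
    return gathered.reshape(batch_shape + index_shape + tuple(gathered.shape[2:]))
```

With this helper, the call from the question would be gather_nd(inputs, pos, batch_dims=1) for inputs of shape [N, H, W, C] and integer pos of shape [N, num_pos, 2].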

    At the beginning of interpolate you add a singleton dimension to the front, which is the batch dimension. Setting batch_dims=1 in tf.gather_nd means there is one batch dimension at the beginning, so the gather is applied per batch element, i.e. it indexes inputs[0] with pos[0], inputs[1] with pos[1], etc. There is no benefit to adding a singleton batch dimension, because you could have just called tf.gather_nd on the unbatched tensors directly.

    import tensorflow as tf

    # Example data: a [H, W, C] feature map and 12 integer positions in 0-3
    inputs = tf.random.normal((4, 4, 128))
    pos = tf.random.uniform((12, 2), maxval=4, dtype=tf.int32)
    # Adding singleton batch dimension
    # Shape: [1, num_pos, 2]
    pos = tf.expand_dims(pos, 0)
    # Shape: [1, H, W, C]
    inputs = tf.expand_dims(inputs, 0)
    batched_result = tf.gather_nd(inputs, pos, batch_dims=1)
    single_result = tf.gather_nd(inputs[0], pos[0])
    # The first element in the batched result is the same as the single result,
    # hence there is no benefit to adding a singleton batch dimension.
    tf.reduce_all(batched_result[0] == single_result) # => True
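The per-batch semantics of batch_dims=1 can also be written as a plain loop in PyTorch. This is a slow reference sketch (gather_nd_batched_loop is a hypothetical name), assuming inputs of shape [N, H, W, C] and integer pos of shape [N, num_pos, 2]:

```python
import torch


def gather_nd_batched_loop(inputs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """Reference equivalent of tf.gather_nd(inputs, pos, batch_dims=1)
    for inputs [N, H, W, C] and integer pos [N, num_pos, 2]."""
    # batch_dims=1: batch element b of inputs is indexed only with pos[b]
    results = [inputs[b][pos[b, :, 0], pos[b, :, 1]] for b in range(inputs.shape[0])]
    return torch.stack(results)  # [N, num_pos, C]


# With a singleton batch dimension the batched result is just the unbatched one
inputs = torch.randn(1, 4, 4, 128)
pos = torch.randint(4, (1, 12, 2))
batched_result = gather_nd_batched_loop(inputs, pos)
single_result = inputs[0][pos[0, :, 0], pos[0, :, 1]]
```

The loop makes the semantics explicit; the vectorized indexing described in the rest of this answer avoids the Python-level iteration.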

    Single version

    In PyTorch the implementation for [H, W, C] can be done with advanced indexing (indexing with integer tensors). While PyTorch usually uses [C, H, W] for images, it's only a matter of which dimensions to index, so let's keep the TensorFlow layout for the sake of comparison. If you were to index them manually, you would do it like this: inputs[pos_h[0], pos_w[0]], inputs[pos_h[1], pos_w[1]] and so on. PyTorch does that automatically when you provide the indices as tensors (or lists): inputs[pos_h, pos_w], where pos_h and pos_w have the same length. All you need to do is split your pos into two separate tensors, one for the indices along the height dimension and the other along the width dimension, which you also did in the TensorFlow version.

    import tensorflow as tf
    import torch

    inputs = torch.randn(4, 4, 128)
    # Random positions 0-3, shape: [12, 2]
    pos = torch.randint(4, (12, 2))
    # Positions split by dimension
    pos_h = pos[:, 0]
    pos_w = pos[:, 1]
    # Index the inputs with the indices per dimension
    gathered = inputs[pos_h, pos_w]
    # Verify that it's identical to TensorFlow's output
    inputs_tf = tf.convert_to_tensor(inputs.numpy())
    pos_tf = tf.convert_to_tensor(pos.numpy())
    gathered_tf = tf.gather_nd(inputs_tf, pos_tf)
    gathered_tf = torch.from_numpy(gathered_tf.numpy())
    torch.equal(gathered_tf, gathered) # => True
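If TensorFlow is not installed, the same check can be done in pure PyTorch by comparing the advanced-indexing result with the manual per-position indexing described above. A minimal sketch with random example data; pos.unbind(-1) is shown as a compact way to split pos into the per-dimension index tensors:

```python
import torch

inputs = torch.randn(4, 4, 128)
pos = torch.randint(4, (12, 2))
pos_h, pos_w = pos[:, 0], pos[:, 1]

# Advanced indexing: row k of the result is inputs[pos_h[k], pos_w[k]]
gathered = inputs[pos_h, pos_w]  # [12, 128]

# The same thing as an explicit Python loop over the positions
manual = torch.stack([inputs[pos_h[k], pos_w[k]] for k in range(pos.shape[0])])

# unbind(-1) returns the tuple (pos[:, 0], pos[:, 1]), usable directly as an index
also_gathered = inputs[pos.unbind(-1)]
```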

    If you want to apply it to a tensor of size [C, H, W] instead, you only need to change the dimensions you want to index:

    # For [H, W, C]
    gathered = inputs[pos_h, pos_w]
    # For [C, H, W]
    gathered = inputs[:, pos_h, pos_w]
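Note that the two layouts give transposed results: inputs[pos_h, pos_w] on [H, W, C] returns [num_pos, C], while inputs[:, pos_h, pos_w] on [C, H, W] keeps the channel dimension first and returns [C, num_pos]. A short sketch with example sizes showing the shapes and the transpose that makes them match:

```python
import torch

inputs_hwc = torch.randn(4, 4, 128)
inputs_chw = inputs_hwc.permute(2, 0, 1)  # same data in [C, H, W] layout
pos = torch.randint(4, (12, 2))
pos_h, pos_w = pos[:, 0], pos[:, 1]

gathered_hwc = inputs_hwc[pos_h, pos_w]     # [12, 128]: one row per position
gathered_chw = inputs_chw[:, pos_h, pos_w]  # [128, 12]: channel dim stays first
# Transpose to get the same [num_pos, C] layout as the [H, W, C] case
gathered_chw_t = gathered_chw.t()
```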

    Batched version

    Making it a batched version (for [N, H, W, C] or [N, C, H, W]) is not that difficult, and using it is more appropriate, since you're dealing with batches anyway. The only tricky part is that each element in the batch must only be indexed with its own positions. For this the batch dimension needs to be enumerated, which can be done with torch.arange. The batch enumeration is just the list of batch indices, which is combined with the pos_h and pos_w indices, resulting in inputs[0, pos_h[0, 0], pos_w[0, 0]], inputs[0, pos_h[0, 1], pos_w[0, 1]] ... inputs[1, pos_h[1, 0], pos_w[1, 0]] etc.

    import tensorflow as tf
    import torch

    batch_size = 3
    inputs = torch.randn(batch_size, 4, 4, 128)
    # Random positions 0-3, different for each batch, shape: [3, 12, 2]
    pos = torch.randint(4, (batch_size, 12, 2))
    # Positions split by dimension
    pos_h = pos[:, :, 0]
    pos_w = pos[:, :, 1]
    batch_enumeration = torch.arange(batch_size) # => [0, 1, 2]
    # pos_h and pos_w have shape [3, 12], so the batch enumeration needs to be
    # repeated 12 times per batch.
    # Unsqueeze to get shape [3, 1], now the 1 could be repeated to 12, but
    # broadcasting will do that automatically.
    batch_enumeration = batch_enumeration.unsqueeze(1)
    # Index the inputs with the indices per dimension
    gathered = inputs[batch_enumeration, pos_h, pos_w]
    # Again, verify that it's identical to TensorFlow's output
    inputs_tf = tf.convert_to_tensor(inputs.numpy())
    pos_tf = tf.convert_to_tensor(pos.numpy())
    # This time with batch_dims=1
    gathered_tf = tf.gather_nd(inputs_tf, pos_tf, batch_dims=1)
    gathered_tf = torch.from_numpy(gathered_tf.numpy())
    torch.equal(gathered_tf, gathered) # => True
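An alternative to the batch enumeration (not part of the approach above, just another option) is to flatten the two spatial dimensions into one linear index and use torch.gather. A minimal sketch for [N, H, W, C] inputs with example sizes, checked against the advanced-indexing version:

```python
import torch

batch_size, H, W, C = 3, 4, 4, 128
inputs = torch.randn(batch_size, H, W, C)
pos = torch.randint(4, (batch_size, 12, 2))
pos_h, pos_w = pos[:, :, 0], pos[:, :, 1]

# Flatten the spatial dims so every (h, w) pair becomes one linear index
flat_inputs = inputs.reshape(batch_size, H * W, C)  # [N, H*W, C]
linear_idx = pos_h * W + pos_w                      # [N, 12]
# torch.gather needs an index with as many dims as its input,
# so repeat the linear index across the channel dimension
gather_idx = linear_idx.unsqueeze(-1).expand(-1, -1, C)   # [N, 12, C]
gathered_flat = torch.gather(flat_inputs, 1, gather_idx)  # [N, 12, C]

# Same result as indexing with the batch enumeration
batch_enumeration = torch.arange(batch_size).unsqueeze(1)
gathered = inputs[batch_enumeration, pos_h, pos_w]
```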

    Again, for [N, C, H, W], only the dimensions that are indexed need to be changed:

    # For [N, H, W, C]
    gathered = inputs[batch_enumeration, pos_h, pos_w]
    # For [N, C, H, W]
    gathered = inputs[batch_enumeration, :, pos_h, pos_w]
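Unlike the single [C, H, W] case, where inputs[:, pos_h, pos_w] returns [C, num_pos], the batched [N, C, H, W] version returns [N, num_pos, C]. Because the slice separates the index tensors, PyTorch (like NumPy) moves the broadcast index dimensions to the front. A sketch with example sizes confirming that both batched layouts then give identical results:

```python
import torch

batch_size = 3
inputs_nhwc = torch.randn(batch_size, 4, 4, 128)
inputs_nchw = inputs_nhwc.permute(0, 3, 1, 2)  # same data, [N, C, H, W]
pos = torch.randint(4, (batch_size, 12, 2))
pos_h, pos_w = pos[:, :, 0], pos[:, :, 1]
batch_enumeration = torch.arange(batch_size).unsqueeze(1)

gathered_nhwc = inputs_nhwc[batch_enumeration, pos_h, pos_w]  # [3, 12, 128]
# Index tensors separated by ":" -> broadcast dims [3, 12] come first, C last
gathered_nchw = inputs_nchw[batch_enumeration, :, pos_h, pos_w]  # [3, 12, 128]
```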

    Just a little side note on the interpolate implementation: if your positions are already integer indices, rounding them with floor and ceil has no effect, because both return the same integer. That also makes i_top_left and i_bottom_left the same value, and even when they are rounded differently (fractional positions), they are exactly 1 position apart before clipping. Furthermore, i_top_left and i_top_right are literally the same expression. For integer positions every weight except w_top_left is zero, so the function just returns inputs at those positions, and I don't think it produces a meaningful interpolation. I don't know what you're trying to achieve, but if you're looking for image interpolation (resizing) you could have a look at torch.nn.functional.interpolate.
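For fractional positions, floor and ceil do select the four neighbouring pixels, which is what bilinear interpolation needs. A minimal sketch of a PyTorch port of the question's TensorFlow interpolate (the nd=True path), using the [C, H, W] indexing from this answer. It assumes a single [C, H, W] map and float positions of shape [num_pos, 2] holding (i, j); interpolate_chw is a hypothetical name:

```python
import torch


def interpolate_chw(pos: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    """Bilinear sampling of a [C, H, W] map at float positions pos [num_pos, 2] = (i, j).
    Returns [num_pos, C]. Mirrors the clipping of the TF version."""
    _, h, w = inputs.shape
    i, j = pos[:, 0], pos[:, 1]
    # Integer corner coordinates, clamped like tf.clip_by_value
    i_top = i.floor().long().clamp(0, h - 1)
    j_left = j.floor().long().clamp(0, w - 1)
    i_bottom = i.ceil().long().clamp(0, h - 1)
    j_right = j.ceil().long().clamp(0, w - 1)
    # Fractional offsets from the top-left corner
    di = i - i_top.to(pos.dtype)
    dj = j - j_left.to(pos.dtype)
    # Indexing [C, H, W] gives [C, num_pos]; .t() turns it into [num_pos, C]
    top_left = inputs[:, i_top, j_left].t()
    top_right = inputs[:, i_top, j_right].t()
    bottom_left = inputs[:, i_bottom, j_left].t()
    bottom_right = inputs[:, i_bottom, j_right].t()
    # Bilinear weights, unsqueezed to broadcast over the channel dimension
    return (
        ((1 - di) * (1 - dj)).unsqueeze(1) * top_left
        + ((1 - di) * dj).unsqueeze(1) * top_right
        + (di * (1 - dj)).unsqueeze(1) * bottom_left
        + (di * dj).unsqueeze(1) * bottom_right
    )
```

For integer positions this reduces to a plain lookup, and a position halfway between four pixels returns their average.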