
How to divide a 2-D tensor into smaller blocks using PyTorch?


I've downloaded the EMNIST letters dataset and converted each image to a torch.Tensor of shape torch.Size([28, 28]). Now I would like to divide each 28x28 image into a 7x7 grid of non-overlapping 4x4 blocks, i.e. 49 blocks of 16 pixels each.

That is, if the image pixels are labeled 1, 2, ..., 784 from left to right and top to bottom:

[
  [1, 2, 3, ..., 28],
  ...
  [        ..., 784]
]

I expect the output to be a torch.Tensor of shape torch.Size([7, 7, 16]), where each innermost list is one 4x4 block flattened row by row:

[
  [
    [1, 2, 3, 4, 29, 30, 31, 32, 57, 58, 59, 60, 85, 86, 87, 88],
    ...
    [25, 26, 27, 28, 53, 54, 55, 56, 81, 82, 83, 84, 109, 110, 111, 112]
  ],
  ...
  [
    ...
    [697, 698, 699, 700, 725, 726, 727, 728, 753, 754, 755, 756, 781, 782, 783, 784]
  ]
]
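Put as a formula (block and in-block indices 0-based, pixel labels 1-based as above), the entry for block row $b_i$, block column $b_j$ and position $(i, j)$ inside a 4x4 block should be:

```latex
\text{out}[b_i,\, b_j,\, 4i + j] = 28\,(4b_i + i) + (4b_j + j) + 1,
\qquad 0 \le b_i, b_j \le 6,\quad 0 \le i, j \le 3
```

For example, $b_i = b_j = 6$ and $i = j = 0$ give $28 \cdot 24 + 24 + 1 = 697$, the first entry of the last block.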

I've tried tensor.view(7, 7, 16), but it does not give the expected result.
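A quick way to see why `view(7, 7, 16)` cannot work: `view` never moves data, it only reinterprets the row-major storage, so each "block" it produces is just 16 consecutive labels from the same row. A minimal sketch, using the same 1..784 labeling as above:

```python
import torch

# label pixels 1..784 row by row, as in the example above
x = torch.arange(1, 28 * 28 + 1).view(28, 28)

# view only regroups consecutive elements of the row-major storage
naive = x.view(7, 7, 16)
print(naive[0, 0])  # the first 16 pixels of row 1: 1, 2, ..., 16

# the wanted first block is the top-left 4x4 square, flattened row by row
wanted = x[:4, :4].reshape(16)
print(wanted)  # 1, 2, 3, 4, 29, 30, 31, 32, 57, ..., 88
```

Getting real 4x4 blocks therefore needs an operation that actually reorders elements (unfold, a permute, or einops), as in the answers below.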

Thanks a lot ^_^


Solution

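  • Native PyTorch: use torch.nn.functional.unfold. It extracts every patch in a single vectorized operation (much faster than looping over blocks in Python), it is differentiable, and with a stride smaller than kernel_size it can also return overlapping patches. Here is an example: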

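    import torch

    x = torch.arange(1, 28*28+1).view(28, 28).float()  # unfold does not accept integer tensors
    x = x.unsqueeze(0).unsqueeze(0)  # add batch and channel dims: (B, C, H, W) = (1, 1, 28, 28)
    out = torch.nn.functional.unfold(x, kernel_size=4, dilation=1, padding=0, stride=4)  # (1, 16, 49)
    out = out.permute(0, 2, 1)  # (1, 49, 16): one row per 4x4 block, blocks in row-major order
    out = out.reshape(7, 7, 16)  # torch.Size([7, 7, 16]), the layout asked for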
    

    Alternative: use einops, which is more flexible but requires installing an extra package. einops is also a good choice if you work across several frameworks (PyTorch, NumPy, TensorFlow, ...), since the same rearrange pattern works for all of them.

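    import torch
    from einops import rearrange

    x = torch.arange(1, 28*28+1).view(28, 28)
    # split H into (h, t1) and W into (w, t2), then merge each 4x4 block (t1, t2) into one axis
    out = rearrange(x, "(h t1) (w t2) -> h w (t1 t2)", t1=4, t2=4)
    out.shape  # torch.Size([7, 7, 16])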
    

    Alternative 2: if for some reason you prefer plain tensor operations (no torch.nn.functional), you can also build the blocks with split and stack:

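    import torch

    x = torch.arange(1, 28*28+1).view(28, 28)

    # split into 4-row strips, stack, then split into 4-column strips: (4, 4, 7, 7) indexed [i, j, bj, bi]
    blocks = torch.stack(torch.split(torch.stack(torch.split(x, 4, 0), -1), 4, 1), -2)
    # or, splitting columns first (same (4, 4, 7, 7) layout):
    blocks = torch.stack(torch.split(torch.stack(torch.split(x, 4, 1), -1), 4, 0), -1)
    # move the block indices to the front as (block row, block column) and flatten each 4x4 block
    out = blocks.permute(3, 2, 0, 1).reshape(7, 7, 16)
    out.shape  # torch.Size([7, 7, 16])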
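Since EMNIST images usually come in batches, here is a minimal sketch (assuming a batch tensor of shape (B, H, W); `image_to_blocks` is a hypothetical helper name, not a PyTorch API) that uses reshape/permute instead of split/stack for any block size dividing H and W, and cross-checks the result against `unfold`:

```python
import torch
import torch.nn.functional as F


def image_to_blocks(images: torch.Tensor, block: int = 4) -> torch.Tensor:
    """Split (B, H, W) images into non-overlapping block x block patches.

    Returns shape (B, H // block, W // block, block * block), each patch
    flattened row by row. Hypothetical helper, for illustration only.
    """
    b, h, w = images.shape
    if h % block or w % block:
        raise ValueError("H and W must be divisible by the block size")
    # (B, H, W) -> (B, H/block, block, W/block, block): split each spatial axis
    x = images.reshape(b, h // block, block, w // block, block)
    # bring the two block-grid axes together: (B, H/block, W/block, block, block)
    x = x.permute(0, 1, 3, 2, 4)
    # flatten each block x block patch into one axis
    return x.reshape(b, h // block, w // block, block * block)


# a fake batch of two 28x28 images labeled 1..784
images = torch.arange(1, 28 * 28 + 1).view(1, 28, 28).repeat(2, 1, 1)
blocks = image_to_blocks(images)
print(blocks.shape)  # torch.Size([2, 7, 7, 16])

# cross-check against unfold, which needs a float (B, C, H, W) input
cols = F.unfold(images.unsqueeze(1).float(), kernel_size=4, stride=4)  # (2, 16, 49)
ref = cols.permute(0, 2, 1).reshape(2, 7, 7, 16)
print(torch.equal(blocks.float(), ref))  # True
```

The reshape/permute route is the same split, reorder, merge that the einops pattern "(h t1) (w t2) -> h w (t1 t2)" describes, so it needs no extra package; unfold remains the option when patches should overlap (stride smaller than kernel_size), which a plain reshape cannot express.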