
How to split multi-dimensional arrays based on the unique indices of another array?

I have two torch tensors a and b:

import torch
torch.manual_seed(0) # for reproducibility

a = torch.rand(size = (5, 10, 1))
b = torch.tensor([3, 3, 1, 5, 3, 1, 0, 2, 1, 2])

I want to split a along its 2nd dimension (dim = 1, since PyTorch dimensions are zero-indexed) according to the unique values in b.

What I have tried so far:

# find the unique values and unique indices of b
unique_values, unique_indices = torch.unique(b, return_inverse = True)

# split a along dim = 1, based on the unique indices
l = torch.tensor_split(a, unique_indices, dim = 1)

I was expecting l to be a list of n tensors, where n is the number of unique values in b. I was also expecting each tensor to have the shape (5, number of elements of b equal to that unique value, 1).
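For reference, the grouping described above can be stated directly with boolean masks: the piece for each unique value v of b is a[:, b == v]. This is a minimal sketch (a plain Python loop over the unique values, meant only to pin down the expected shapes, not as the most efficient method):

```python
import torch

torch.manual_seed(0)  # for reproducibility
a = torch.rand(size=(5, 10, 1))
b = torch.tensor([3, 3, 1, 5, 3, 1, 0, 2, 1, 2])

# One boolean mask per unique value of b selects the matching columns of a along dim=1
unique_values = torch.unique(b)
expected = [a[:, b == v] for v in unique_values]

for v, piece in zip(unique_values.tolist(), expected):
    print(v, tuple(piece.shape))
# prints:
# 0 (5, 1, 1)
# 1 (5, 3, 1)
# 2 (5, 2, 1)
# 3 (5, 3, 1)
# 5 (5, 1, 1)
```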

However, I get the following (abridged; "..." marks parts of the printed output that are missing):

...
         [0.8579]]]), tensor([], size=(5, 0, 1)), tensor([], size=(5, 0, 1)), tensor([[[0.9971],
         ...
         [0.6870]]]), tensor([], size=(5, 0, 1)), tensor([], size=(5, 0, 1)), tensor([], size=(5, 0, 1)), tensor([[[0.8198],
         ...
         [0.7529]]]), tensor([], size=(5, 0, 1)), tensor([[[0.9971]],
         ...
        [[0.7529]]]), tensor([[[0.6984],
         ...

Why do I get empty tensors like tensor([], size=(5, 0, 1)) and how would I achieve what I want to achieve?
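The empty tensors come from how torch.tensor_split interprets its second argument: a 1-D integer tensor is treated as a sequence of cut positions along dim, not as a group label per column. unique_indices (the return_inverse output) holds one label per column, [3, 3, 1, 4, 3, 1, 0, 2, 1, 2], so tensor_split cuts at those positions in order and returns 11 slices a[:, 0:3], a[:, 3:3], a[:, 3:1], a[:, 1:4], ...; every slice whose end does not exceed its start is empty. A small sketch that reproduces the same pieces with plain slicing:

```python
import torch

torch.manual_seed(0)  # for reproducibility
a = torch.rand(size=(5, 10, 1))
b = torch.tensor([3, 3, 1, 5, 3, 1, 0, 2, 1, 2])

_, unique_indices = torch.unique(b, return_inverse=True)
print(unique_indices.tolist())  # [3, 3, 1, 4, 3, 1, 0, 2, 1, 2]

# tensor_split uses each entry as a cut position along dim=1, in the given order
pieces = torch.tensor_split(a, unique_indices, dim=1)
print([p.shape[1] for p in pieces])  # [3, 0, 0, 3, 0, 0, 0, 2, 0, 1, 8]

# The same pieces via explicit slicing: each slice runs from the previous cut to the next one,
# so any slice with end <= start (e.g. 3:1) is empty
cuts = [0] + unique_indices.tolist() + [a.shape[1]]
manual = [a[:, start:end] for start, end in zip(cuts[:-1], cuts[1:])]
print(all(torch.equal(p, m) for p, m in zip(pieces, manual)))  # True
```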


  • From your description of the desired result:

    I was also expecting the tensors to have the shape (5, number of elements corresponding to unique_values, 1).

    I believe you are looking for the count (or frequency) of each unique value. If you want to keep using torch.unique, pass return_counts=True and combine the counts with a call to torch.cumsum to get the split points.

    Something like this should work:

    >>> unique_values, counts = torch.unique(b, return_counts=True)
    >>> indices = torch.cumsum(counts, dim=0)
    >>> splits = torch.tensor_split(a, indices[:-1], dim = 1)

    Let's have a look:

    >>> for x in splits:
    ...     print(x.shape)
    torch.Size([5, 1, 1])
    torch.Size([5, 3, 1])
    torch.Size([5, 2, 1])
    torch.Size([5, 3, 1])
    torch.Size([5, 1, 1])
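One caveat: splitting by cumulative counts only puts the right columns in each group when the columns of a are already ordered by their value in b. Here b is not sorted, so the pieces above have the right shapes but not necessarily the right contents; for example, the first piece is a[:, 0:1], which is the column where b is 3, not the column where b is 0. A minimal sketch that first reorders the columns with a stable torch.sort and then applies the same counts/cumsum split (counts [1, 3, 2, 3, 1] give a cumulative sum of [1, 4, 6, 9, 10], hence cut points [1, 4, 6, 9]):

```python
import torch

torch.manual_seed(0)  # for reproducibility
a = torch.rand(size=(5, 10, 1))
b = torch.tensor([3, 3, 1, 5, 3, 1, 0, 2, 1, 2])

# Reorder the columns of a so that equal values of b are adjacent;
# stable=True keeps the original column order within each group
_, order = torch.sort(b, stable=True)
a_sorted = a[:, order]

# Group sizes -> cut points: counts [1, 3, 2, 3, 1] -> cumsum [1, 4, 6, 9, 10]
unique_values, counts = torch.unique(b, return_counts=True)
indices = torch.cumsum(counts, dim=0)
splits = torch.tensor_split(a_sorted, indices[:-1], dim=1)

# Each piece now holds exactly the columns where b equals the matching unique value
for v, piece in zip(unique_values.tolist(), splits):
    print(v, tuple(piece.shape), torch.equal(piece, a[:, b == v]))
```

Equivalently, torch.split(a_sorted, counts.tolist(), dim=1) takes the group sizes directly and skips the cumsum step.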