
How to split multi-dimensional arrays based on the unique values of another array?


I have two torch tensors a and b:

import torch
torch.manual_seed(0) # for reproducibility

a = torch.rand(size = (5, 10, 1))
b = torch.tensor([3, 3, 1, 5, 3, 1, 0, 2, 1, 2])

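I want to split a along its second dimension (dim = 1, since PyTorch counts dimensions from 0) according to the unique values in b. b has one entry per position along that dimension, and positions that share the same value in b should end up in the same piece.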

What I have tried so far:

# find the unique values and unique indices of b
unique_values, unique_indices = torch.unique(b, return_inverse = True)

# split a along dim = 1, based on the unique indices
l = torch.tensor_split(a, unique_indices, dim = 1)

I was expecting l to contain n tensors, one per unique value in b, where the tensor for a value v has shape (5, k, 1) and k is the number of times v occurs in b.
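For reference, the grouping described above can be written directly with one boolean mask per unique value of b. This is a minimal sketch: `groups` is a name introduced for illustration, and the Python-level loop over unique values is simple but not a single vectorized split.

```python
import torch

torch.manual_seed(0)  # same setup as above
a = torch.rand(size=(5, 10, 1))
b = torch.tensor([3, 3, 1, 5, 3, 1, 0, 2, 1, 2])

# One boolean mask per unique value of b; a[:, mask] keeps the positions
# along dim 1 where b equals that value, in their original order.
groups = [a[:, b == v] for v in torch.unique(b)]

for v, g in zip(torch.unique(b).tolist(), groups):
    print(v, tuple(g.shape))
# 0 (5, 1, 1)
# 1 (5, 3, 1)
# 2 (5, 2, 1)
# 3 (5, 3, 1)
# 5 (5, 1, 1)
```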

However, I get the following:

print(l)

(tensor([[[0.8198],
         [0.9971],
         [0.6984]],

        [[0.7262],
         [0.7011],
         [0.2038]],

        [[0.1147],
         [0.3168],
         [0.6965]],

        [[0.0340],
         [0.9442],
         [0.8802]],

        [[0.6833],
         [0.7529],
         [0.8579]]]), tensor([], size=(5, 0, 1)), tensor([], size=(5, 0, 1)), tensor([[[0.9971],
         [0.6984],
         [0.5675]],

        [[0.7011],
         [0.2038],
         [0.6511]],

        [[0.3168],
         [0.6965],
         [0.9143]],

        [[0.9442],
         [0.8802],
         [0.0012]],

        [[0.7529],
         [0.8579],
         [0.6870]]]), tensor([], size=(5, 0, 1)), tensor([], size=(5, 0, 1)), tensor([], size=(5, 0, 1)), tensor([[[0.8198],
         [0.9971]],

        [[0.7262],
         [0.7011]],

        [[0.1147],
         [0.3168]],

        [[0.0340],
         [0.9442]],

        [[0.6833],
         [0.7529]]]), tensor([], size=(5, 0, 1)), tensor([[[0.9971]],

        [[0.7011]],

        [[0.3168]],

        [[0.9442]],

        [[0.7529]]]), tensor([[[0.6984],
         [0.5675],
         [0.8352],
         [0.2056],
         [0.5932],
         [0.1123],
         [0.1535],
         [0.2417]],

        [[0.2038],
         [0.6511],
         [0.7745],
         [0.4369],
         [0.5191],
         [0.6159],
         [0.8102],
         [0.9801]],

        [[0.6965],
         [0.9143],
         [0.9351],
         [0.9412],
         [0.5995],
         [0.0652],
         [0.5460],
         [0.1872]],

        [[0.8802],
         [0.0012],
         [0.5936],
         [0.4158],
         [0.4177],
         [0.2711],
         [0.6923],
         [0.2038]],

        [[0.8579],
         [0.6870],
         [0.0051],
         [0.1757],
         [0.7497],
         [0.6047],
         [0.1100],
         [0.2121]]]))

Why do I get empty tensors like tensor([], size=(5, 0, 1)), and how can I get the result I described?


Solution

  • From your description of the desired result:

    I was also expecting the tensors to have the shape (5, number of elements corresponding to unique_values, 1).

    I believe you are looking for the count (or frequency) of each unique value. If you want to keep using torch.unique, pass return_counts=True to get those counts, then turn them into split positions with torch.cumsum.

    Something like this should work:

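    >>> _, counts = torch.unique(b, return_counts=True)
    >>> order = torch.argsort(b, stable=True)  # column order that puts equal values of b next to each other
    >>> indices = torch.cumsum(counts, dim=0)
    >>> splits = torch.tensor_split(a[:, order], indices[:-1], dim=1)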
    

    Let's have a look:

    >>> for x in splits:
    ...     print(x.shape)
    torch.Size([5, 1, 1])
    torch.Size([5, 3, 1])
    torch.Size([5, 2, 1])
    torch.Size([5, 3, 1])
    torch.Size([5, 1, 1])
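As for the first part of the question: the empty tensors come from passing the inverse indices to torch.tensor_split, which reads them as positions at which to cut, not as group labels. Piece i is the slice between consecutive cut points, so every time a cut point is not larger than the previous one, that slice is empty. The sketch below reuses the question's seeded setup and reproduces the 11 piece sizes from the printed output. It then checks that splitting by the cumulative counts picks out the right columns once a's columns are reordered by b. `cuts`, `expected_sizes` and `order` are names introduced for illustration, and it assumes a recent PyTorch in which torch.argsort accepts stable=True.

```python
import torch

torch.manual_seed(0)  # same setup as in the question
a = torch.rand(size=(5, 10, 1))
b = torch.tensor([3, 3, 1, 5, 3, 1, 0, 2, 1, 2])

# return_inverse labels every element of b with the position of its value in
# the sorted unique values: these are group labels, not split points.
unique_values, inverse = torch.unique(b, return_inverse=True)
print(inverse.tolist())  # [3, 3, 1, 4, 3, 1, 0, 2, 1, 2]

# tensor_split treats each entry as a cut point along dim 1, so piece i is
# a[:, cuts[i]:cuts[i + 1]]; a slice whose stop <= start is empty.
cuts = [0, *inverse.tolist(), a.shape[1]]
expected_sizes = [max(stop - start, 0) for start, stop in zip(cuts[:-1], cuts[1:])]
pieces = torch.tensor_split(a, inverse, dim=1)
print(expected_sizes)                # [3, 0, 0, 3, 0, 0, 0, 2, 0, 1, 8]
print([p.shape[1] for p in pieces])  # [3, 0, 0, 3, 0, 0, 0, 2, 0, 1, 8]

# Cutting at the cumulative counts only groups the right columns once columns
# sharing a value of b are adjacent; a stable argsort keeps their original order.
_, counts = torch.unique(b, return_counts=True)
order = torch.argsort(b, stable=True)
splits = torch.tensor_split(a[:, order], torch.cumsum(counts, dim=0)[:-1], dim=1)

# Each piece should hold exactly the columns of a whose value in b matches.
for v, piece in zip(unique_values, splits):
    assert torch.equal(piece, a[:, b == v])
```

Without the reordering step the piece shapes are the same, but each piece simply holds the next few columns of a rather than the columns that share a value in b.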