I have two torch tensors, a and b:
import torch
torch.manual_seed(0) # for reproducibility
a = torch.rand(size = (5, 10, 1))
b = torch.tensor([3, 3, 1, 5, 3, 1, 0, 2, 1, 2])
I want to split the 2nd dimension of a (dim = 1, since Python indexing is zero-based) based on the unique values in b.
What I have tried so far:
# find the unique values and unique indices of b
unique_values, unique_indices = torch.unique(b, return_inverse = True)
# split a in where dim = 1, based on unique indices
l = torch.tensor_split(a, unique_indices, dim = 1)
I was expecting l to be a list of n tensors, where n is the number of unique values in b. I was also expecting the tensors to have the shape (5, number of elements corresponding to unique_values, 1).
However, I get the following:
print(l)
(tensor([[[0.8198],
[0.9971],
[0.6984]],
[[0.7262],
[0.7011],
[0.2038]],
[[0.1147],
[0.3168],
[0.6965]],
[[0.0340],
[0.9442],
[0.8802]],
[[0.6833],
[0.7529],
[0.8579]]]), tensor([], size=(5, 0, 1)), tensor([], size=(5, 0, 1)), tensor([[[0.9971],
[0.6984],
[0.5675]],
[[0.7011],
[0.2038],
[0.6511]],
[[0.3168],
[0.6965],
[0.9143]],
[[0.9442],
[0.8802],
[0.0012]],
[[0.7529],
[0.8579],
[0.6870]]]), tensor([], size=(5, 0, 1)), tensor([], size=(5, 0, 1)), tensor([], size=(5, 0, 1)), tensor([[[0.8198],
[0.9971]],
[[0.7262],
[0.7011]],
[[0.1147],
[0.3168]],
[[0.0340],
[0.9442]],
[[0.6833],
[0.7529]]]), tensor([], size=(5, 0, 1)), tensor([[[0.9971]],
[[0.7011]],
[[0.3168]],
[[0.9442]],
[[0.7529]]]), tensor([[[0.6984],
[0.5675],
[0.8352],
[0.2056],
[0.5932],
[0.1123],
[0.1535],
[0.2417]],
[[0.2038],
[0.6511],
[0.7745],
[0.4369],
[0.5191],
[0.6159],
[0.8102],
[0.9801]],
[[0.6965],
[0.9143],
[0.9351],
[0.9412],
[0.5995],
[0.0652],
[0.5460],
[0.1872]],
[[0.8802],
[0.0012],
[0.5936],
[0.4158],
[0.4177],
[0.2711],
[0.6923],
[0.2038]],
[[0.8579],
[0.6870],
[0.0051],
[0.1757],
[0.7497],
[0.6047],
[0.1100],
[0.2121]]]))
Why do I get empty tensors like tensor([], size=(5, 0, 1)), and how can I achieve what I want?
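As for the empty tensors: when torch.tensor_split is given a list of indices, it treats them as split points along the dimension, not as group labels. The inverse indices returned by torch.unique are not sorted, so whenever a split point is less than or equal to the previous one, the slice between them is empty. A minimal sketch that reproduces this with the tensors from the question (the names inverse, points, pieces and widths are only illustrative):

```python
import torch

torch.manual_seed(0)
a = torch.rand(size=(5, 10, 1))
b = torch.tensor([3, 3, 1, 5, 3, 1, 0, 2, 1, 2])

# inverse[i] is the position of b[i] within the sorted unique values of b
_, inverse = torch.unique(b, return_inverse=True)
points = inverse.tolist()
print(points)  # → [3, 3, 1, 4, 3, 1, 0, 2, 1, 2]

# tensor_split cuts at each point in turn:
# a[:, :3], a[:, 3:3], a[:, 3:1], a[:, 1:4], ..., a[:, 2:]
pieces = torch.tensor_split(a, points, dim=1)
widths = [p.shape[1] for p in pieces]
print(widths)  # → [3, 0, 0, 3, 0, 0, 0, 2, 0, 1, 8]
```

Every 0 in widths corresponds to a slice whose start is not smaller than its stop, which is exactly where tensor([], size=(5, 0, 1)) shows up in the printed output.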
From your description of the desired result:
"I was also expecting the tensors to have the shape (5, number of elements corresponding to unique_values, 1)."
I believe you are looking for the count (or frequency) of each unique value. If you want to keep using torch.unique, you can pass the return_counts argument and take the cumulative sum of the counts with torch.cumsum to get the split points.
Something like this should work:
>>> _, counts = torch.unique(b, return_counts=True)
>>> indices = torch.cumsum(counts, dim=0)
>>> splits = torch.tensor_split(a, indices[:-1], dim = 1)
Let's have a look:
>>> for x in splits:
...     print(x.shape)
torch.Size([5, 1, 1])
torch.Size([5, 3, 1])
torch.Size([5, 2, 1])
torch.Size([5, 3, 1])
torch.Size([5, 1, 1])
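One caveat: the split above cuts a into consecutive chunks of the right sizes, but since b is not sorted, a chunk does not necessarily hold the columns whose value in b matches. If the goal is for each piece to contain exactly the columns of a that share a value in b (this is an assumption about the intent), one possible approach is to reorder a by a stable sort of b before splitting. A rough sketch, with order, a_sorted and groups as illustrative names:

```python
import torch

torch.manual_seed(0)
a = torch.rand(size=(5, 10, 1))
b = torch.tensor([3, 3, 1, 5, 3, 1, 0, 2, 1, 2])

# Sorted unique values of b and how often each one occurs
values, counts = torch.unique(b, return_counts=True)

# Stable sort keeps the original order among equal values of b
_, order = torch.sort(b, stable=True)

# Reorder dim 1 of a so that columns with equal b values are adjacent
a_sorted = a[:, order]

# Split into consecutive chunks whose sizes are the counts
groups = torch.split(a_sorted, counts.tolist(), dim=1)

for v, g in zip(values.tolist(), groups):
    print(v, tuple(g.shape))
```

Each entry of groups should then match a[:, b == v] for the corresponding unique value v, while the shapes stay the same as in the output above.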