Given a torch tensor:
# example tensor size 2 x 4
a = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
and another where every n rows are repeated:
# example tensor size 4 x 3 where every 2 rows repeated
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])
how can one perform matrix multiplication:
>>> torch.mm(a, b)
tensor([[ 28.,  38.,  48.],
        [ 68.,  94., 120.]])
without copying the whole repeated-row tensor into memory or iterating? That is, only the first 2 rows should be stored, since the remaining rows are just repeats of them:
# example tensor size 2 x 3 where only the first two rows from b are actually stored in memory
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])
There is a function torch.expand(), but it does not work when repeating more than a single row. Also, as the question "Repeating a pytorch tensor without copying memory" indicates, and as my own tests confirm, it often ends up copying the whole tensor into memory anyway when calling .to(device).
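A small sketch of the expand() limitation described above, assuming a recent PyTorch version (the variable names are illustrative). expand() can only broadcast dimensions of size 1, so it cannot turn the 2 x 3 b_abbreviated into the 4 x 3 b directly. It can build a stride-0 view over a new leading dimension without copying, but flattening that view into the 2-D tensor that torch.mm requires forces a copy.

```python
import torch

b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])

# expand() only broadcasts size-1 dimensions, so 2 x 3 -> 4 x 3 raises an error.
try:
    b_abbreviated.expand(4, 3)
    expand_failed = False
except RuntimeError:
    expand_failed = True

# With an extra leading size-1 dimension, a 2 x 2 x 3 view with stride 0
# on the new dimension is possible; it shares memory with b_abbreviated.
b_view = b_abbreviated.unsqueeze(0).expand(2, -1, -1)
shares_memory = b_view.data_ptr() == b_abbreviated.data_ptr()

# Merging the stride-0 dimension with the row dimension cannot be expressed
# with strides alone, so reshape() has to allocate a full 4 x 3 copy.
b_full = b_view.reshape(4, 3)
copied = b_full.data_ptr() != b_abbreviated.data_ptr()

print(expand_failed, b_view.stride(), shares_memory, copied)
```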
It is also possible to do this iteratively, but this is relatively slow.
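One possible reading of the iterative approach, as a sketch: loop over the blocks of n columns of a, multiply each block by the n stored rows, and accumulate. The helper name matmul_repeated_loop and the period parameter n are hypothetical, chosen only for illustration.

```python
import torch

def matmul_repeated_loop(a, b_abbreviated, n):
    """Hypothetical helper: a @ b, where b is b_abbreviated stacked
    a.shape[1] // n times. Only the n distinct rows of b are stored."""
    m, cols = a.shape
    assert cols % n == 0 and b_abbreviated.shape[0] == n
    out = torch.zeros(m, b_abbreviated.shape[1], dtype=a.dtype, device=a.device)
    # Every block of n columns of a multiplies the same n rows of b.
    for start in range(0, cols, n):
        out += torch.mm(a[:, start:start + n], b_abbreviated)
    return out

a = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])
result = matmul_repeated_loop(a, b_abbreviated, n=2)
print(result)
```

The Python loop runs a.shape[1] // n times, which is where the slowdown comes from when b repeats many times.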
Is there some way to perform this operation efficiently without storing the whole repeated row tensor in memory?
Edit explanation:
Sorry for not clarifying initially: I used 1 as the first dimension of the first tensor to keep the example simple, but I am actually looking for a solution to the general case of any two tensors a and b whose dimensions are compatible for matrix multiplication and where the rows of b repeat every n rows. I have updated the example to reflect this.
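Stated a bit more formally (the symbols here are introduced only for illustration): let B be the n x p block of distinct rows, so that b stacks k copies of B, and split a (m x kn) into k column blocks A_1, ..., A_k of width n. Block multiplication then gives the identity below, which is why only B ever needs to be stored:

```latex
b = \begin{bmatrix} B \\ B \\ \vdots \\ B \end{bmatrix} \in \mathbb{R}^{kn \times p},
\qquad
a = \begin{bmatrix} A_1 & A_2 & \cdots & A_k \end{bmatrix} \in \mathbb{R}^{m \times kn},
\qquad
a\,b = \sum_{j=1}^{k} A_j B = \Big( \sum_{j=1}^{k} A_j \Big) B .
```

For the example above, n = 2, k = 2, and A_1 + A_2 = [[4, 6], [12, 14]], which times b_abbreviated gives [[28, 38, 48], [68, 94, 120]].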
Assuming that the first dimension of a is 1 as in your example, you could do the following:
a = torch.Tensor([[1, 2, 3, 4]])
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])
torch.mm(a.reshape(-1, 2), b_abbreviated).sum(dim=0, keepdim=True)
Here, instead of repeating the rows of b, you split a into chunks of 2 columns, multiply each chunk by b_abbreviated, and then add the partial results up column-wise to get the same result.
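A quick check of the single-row version against the full product; the full b is materialized with repeat() here only to have a reference value.

```python
import torch

a = torch.Tensor([[1, 2, 3, 4]])
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])

# Split the single row of a into chunks of 2 columns: [[1, 2], [3, 4]].
chunks = a.reshape(-1, 2)
# Multiply every chunk by the 2 stored rows, then add the partial results.
chunked = torch.mm(chunks, b_abbreviated).sum(dim=0, keepdim=True)

# Reference only: build the full repeated b and multiply directly.
b = b_abbreviated.repeat(2, 1)
reference = torch.mm(a, b)
print(chunked, reference)
```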
If the first dimension of a is not necessarily 1, you could try the following:
torch.cat(torch.split(torch.mm(a.reshape(-1, 2), b_abbreviated), a.shape[1] // 2), dim=1).sum(
    dim=0, keepdim=True).reshape(a.shape[0], -1)
Here, you do the following:
- torch.mm(a.reshape(-1, 2), b_abbreviated): you again split each row of a into chunks of size 2, stack the chunks of each row one over the other, then stack the rows one over the other, and multiply every chunk by b_abbreviated.
- torch.split(..., a.shape[1] // 2): these stacks are then separated row-wise into pieces of a.shape[1] // 2 rows (the number of chunks per row of a), so that each piece of the split corresponds to the chunks of a single row of a.
- torch.cat(..., dim=1): these pieces are then concatenated column-wise.
- .sum(dim=0, keepdim=True): the results corresponding to the separate chunks of each individual row of a are added up.
- .reshape(a.shape[0], -1): the results for the rows of a, which were concatenated column-wise, are stacked row-wise again.

It seems quite slow compared to direct matrix multiplication, which is not surprising, but I have not yet checked it against explicit iteration. There are likely better ways of doing this; I will edit if I think of any.