B = spec_x.size(0)
H = spec_x.size(1)
T = spec_x.size(2)
# Initialize z tensor with zeros
z = torch.zeros(B, 256, H).to(pitch.device)
# Iterate over each batch element
for b in range(B):
    # Iterate over each pitch index
    for i in range(256):
        # Mask spec_x where pitch equals i
        masked_spec_x = spec_x[b].masked_select(pitch[b] == i)
        # Compute the mean of the selected values (masked_select
        # flattens, so this averages over both H and time)
        mean_spec_x = torch.mean(masked_spec_x, dim=0)
        # Assign the mean to the corresponding position in z
        z[b, i] = mean_spec_x
The above code uses two tensors, spec_x and pitch. pitch is a 2D tensor of shape (B, T); it gives an index from 0 to 255 corresponding to the pitch of the spectrogram at each frame.
The goal is to build a tensor z of shape (B, 256, H), where H is the hidden size of spec_x:
z[b][i] = average of spec_x[b] where pitch == i
The above code works, but it's very slow because of the loops. I'm just not sure if there's a way to remove the loops using PyTorch built-ins.
Thanks!
One solution I see is to use a reduce function to distribute the values from `spec_x` at indices given by `pitch`. The `torch.scatter` function seems complex to set up, but all you need to do is make sure that:

- all three tensors (`z`, `src`, and `index`) have the same number of dimensions;

- the indexing tensor (`index`) has values smaller than the dimension size of the output tensor (`z`) at the scattering dimension (`dim`).
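To see those two requirements in action, here is a minimal sketch of `torch.scatter` on toy tensors (the shapes and values are illustrative only):

```python
import torch

# Both tensors have the same number of dimensions (2),
# and every value in index is < out.size(1) = 4.
out = torch.zeros(2, 4)
index = torch.tensor([[0, 3], [1, 2]])
src = torch.tensor([[10.0, 20.0], [30.0, 40.0]])

# Along dim=1: result[i][index[i][j]] = src[i][j]
result = out.scatter(dim=1, index=index, src=src)
# -> [[10., 0., 0., 20.],
#     [ 0., 30., 40., 0.]]
```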
To accommodate the dimension difference, we can unsqueeze and expand all three tensors. The intermediate shape of the output tensor `z` is (B, C, H, T), where C = 256 is the number of pitch values:
>>> C = 256  # number of pitch values
>>> z = torch.zeros(B,C,H,T)
>>> index = pitch[:,None,None].expand_as(z)
>>> src = spec_x[:,None].expand_as(z)
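As a sanity check on the unsqueeze/expand step, a minimal sketch with toy sizes (the values of B, C, H, T here are assumptions for illustration):

```python
import torch

# Illustrative sizes only
B, C, H, T = 2, 4, 3, 5
spec_x = torch.randn(B, H, T)
pitch = torch.randint(0, C, (B, T))

z = torch.zeros(B, C, H, T)
index = pitch[:, None, None].expand_as(z)  # (B, T) -> (B, 1, 1, T) -> (B, C, H, T)
src = spec_x[:, None].expand_as(z)         # (B, H, T) -> (B, 1, H, T) -> (B, C, H, T)
```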
The scattering operation will be applied on `dim=1` (the dimension indexed by the pitch values, i.e. integers in [0, 256)). In pseudo-code, that corresponds to:
# z[b][index[b][c][h][t]][h][t] = src[b][c][h][t]
The first step is to scatter the values:
>>> o = z.scatter(dim=1, index=index, src=src)
A trick to get the correct average is to apply the same operation on a tensor of ones of the same shape as `src`:
>>> count = z.scatter(dim=1, index=index, src=torch.ones_like(src))
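On a small example, scattering a tensor of ones marks which output slots received a value, so summing over time counts the occurrences of each class (toy shapes for illustration):

```python
import torch

# 1 batch, 5 time steps, 3 possible classes
index = torch.tensor([[0, 2, 2, 1, 2]])   # shape (1, 5), class per time step
z = torch.zeros(1, 3, 5)                  # (B, C, T)
idx = index[:, None, :].expand_as(z)      # (1, 3, 5)

count = z.scatter(dim=1, index=idx, src=torch.ones_like(z))
# Summing over time counts how often each class occurs
per_class = count.sum(dim=-1)             # -> [[1., 1., 3.]]
```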
Then simply sum `o` and `count` over their last two dimensions and divide `o` by the counts:
>>> out = o.sum(dim=(-1,-2)) / count.sum(dim=(-1,-2))
You may notice that the output tensor is not of the desired (B, 256, H) shape. You can fix that by repeating along the hidden-state dimension, since all values are equal row-wise (the loop version's `masked_select` flattens H and T, so every hidden unit gets the same mean). Note that if a pitch value never occurs, the count is zero and the result is nan, just as in the loop version:

>>> out = out[:,:,None].repeat(1,1,H)
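Putting all the steps together, a self-contained sketch that checks the scatter version against the loop baseline on small random tensors (sizes here are illustrative, and every pitch value is forced to occur to avoid nans):

```python
import torch

torch.manual_seed(0)
B, H, T, C = 2, 3, 10, 4              # small illustrative sizes
spec_x = torch.randn(B, H, T)
pitch = torch.randint(0, C, (B, T))
pitch[:, :C] = torch.arange(C)        # make sure every pitch value occurs

# Loop baseline, same behaviour as the question's code
z_loop = torch.zeros(B, C, H)
for b in range(B):
    for i in range(C):
        z_loop[b, i] = spec_x[b].masked_select(pitch[b] == i).mean()

# Vectorised scatter version from above
z = torch.zeros(B, C, H, T)
index = pitch[:, None, None].expand_as(z)
src = spec_x[:, None].expand_as(z)
o = z.scatter(dim=1, index=index, src=src)
count = z.scatter(dim=1, index=index, src=torch.ones_like(src))
out = (o.sum(dim=(-1, -2)) / count.sum(dim=(-1, -2)))[:, :, None].repeat(1, 1, H)
```

The two results agree because both average over H and T together; the scatter version just does all (b, i) pairs at once instead of 256·B masked reductions.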