
Removing the for loops for mean calculation with batch with pytorch


B = spec_x.size(0)
H = spec_x.size(1)
T = spec_x.size(2)

# Initialize output tensor z with zeros
z = torch.zeros(B, 256, H).to(pitch.device)

# Iterate over each batch element
for b in range(B):
    # Iterate over each pitch index
    for i in range(256):
        # Select the values of spec_x[b] at frames where pitch equals i
        # (masked_select flattens the result across both H and T)
        masked_spec_x = spec_x[b].masked_select(pitch[b] == i)
        
        # Compute the mean of the selected values
        mean_spec_x = torch.mean(masked_spec_x, dim=0)
        
        # Assign the (scalar) mean to the corresponding position in z
        z[b, i] = mean_spec_x

The above code uses two tensors, spec_x and pitch. pitch has shape (B, T): a 2D tensor giving an index from 0 to 255, corresponding to the pitch of the spectrogram at each frame. spec_x has shape (B, H, T).

The goal is to build tensor z which is B, 256, H, where H is the hidden size of spec_x.

z[b][i] = average of spec_x[b] where pitch == i

The above code works, but it's very slow because of the loops. I'm just not sure if there's a way to remove the loops using PyTorch built-ins.

Thanks!


Solution

  • One solution I see is to use a scatter operation to distribute the values from spec_x at the indices given by pitch. The torch.scatter function seems complex to set up, but all you need to do is make sure that:

    1. All three tensors (z, src, and index) have the same number of dimensions;

    2. The indexing tensor (index) has values smaller than the dimension size of the output tensor (z) at the scattering dimension (dim).

    To account for the difference in dimensions, we can unsqueeze and expand all three tensors. The intermediate shape of the output tensor z is (B, C, H, T), with C = 256:

    >>> C = 256
    >>> z = torch.zeros(B, C, H, T)
    >>> index = pitch[:,None,None].expand_as(z)
    >>> src = spec_x[:,None].expand_as(z)
    

    The scattering operation will be applied on dim=1 (the dimension indexed by the pitch values, i.e. integers in [0, 256)). In pseudo-code, that corresponds to:

    # z[b][index[b][c][h][t]][h][t] = src[b][c][h][t]
    
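    For intuition, here is that scatter rule run on tiny made-up tensors (the shapes B=1, C=4, H=2, T=3 are chosen purely for illustration):

```python
import torch

# Toy shapes: B=1 batch, C=4 pitch classes, H=2 hidden, T=3 frames
B, C, H, T = 1, 4, 2, 3

pitch = torch.tensor([[0, 2, 2]])           # (B, T): pitch class of each frame
spec_x = torch.arange(6.).reshape(1, 2, 3)  # (B, H, T)

z = torch.zeros(B, C, H, T)
index = pitch[:, None, None].expand_as(z)   # (B, C, H, T)
src = spec_x[:, None].expand_as(z)          # (B, C, H, T)

# z[b][index[b][c][h][t]][h][t] = src[b][c][h][t]
o = z.scatter(dim=1, index=index, src=src)

# Each frame t lands in the channel of its pitch class, zeros elsewhere:
# frame 0 goes to channel 0; frames 1 and 2 go to channel 2;
# channels 1 and 3 stay all-zero.
print(o[0, 0])
print(o[0, 2])
```

    Note that duplicate writes are harmless here: for a fixed (b, h, t), every c writes the same value to the same target position.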

    The first step is to scatter the values:

    >>> o = z.scatter(dim=1, index=index, src=src) 
    

    A trick to get the correct average is to apply the same operation to a tensor of ones with the same shape as src:

    >>> count = z.scatter(dim=1, index=index, src=torch.ones_like(src)) 
    

    Then simply sum o and count over their last two dimensions and divide o by the counts. Note that a pitch index that never occurs in pitch[b] yields 0/0 = NaN here, which matches the loop version taking the mean of an empty selection:

    >>> out = o.sum(dim=(-1,-2)) / count.sum(dim=(-1,-2))
    
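    As an aside, if the (B, 256, H, T) intermediates are too large, the same per-pitch sums and counts can be accumulated directly into (B, 256) tensors with Tensor.scatter_add_. This is a sketch of an equivalent, more memory-frugal variant (with made-up random data), not the method above:

```python
import torch

B, C, H, T = 2, 256, 8, 100
spec_x = torch.randn(B, H, T)
pitch = torch.randint(0, C, (B, T))

# Per-frame sums over the hidden dimension, shape (B, T)
frame_sums = spec_x.sum(dim=1)

# Accumulate sums and frame counts per pitch class, shape (B, C):
# sums[b, pitch[b, t]] += frame_sums[b, t]
sums = torch.zeros(B, C).scatter_add_(1, pitch, frame_sums)
counts = torch.zeros(B, C).scatter_add_(1, pitch, torch.ones(B, T))

# Each mean pools H * counts[b, i] values, matching the loop version
# (absent pitch classes again give 0/0 = NaN)
out = sums / (counts * H)
```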

    You may notice that the output tensor has shape (B, 256) rather than the desired (B, 256, H). You can fix that by repeating along a new hidden-state dimension, since the loop version assigns the same mean value across all H positions of z[b][i]:

    >>> out[:,:,None].repeat(1,1,H)
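    Putting the steps above together, here is a self-contained sketch on random data (shapes are assumptions for illustration), which also checks the result against the loop version from the question:

```python
import torch

torch.manual_seed(0)
B, C, H, T = 2, 256, 8, 100

spec_x = torch.randn(B, H, T)
pitch = torch.randint(0, C, (B, T))

# Vectorized version
z = torch.zeros(B, C, H, T)
index = pitch[:, None, None].expand_as(z)
src = spec_x[:, None].expand_as(z)

o = z.scatter(dim=1, index=index, src=src)
count = z.scatter(dim=1, index=index, src=torch.ones_like(src))
out = (o.sum(dim=(-1, -2)) / count.sum(dim=(-1, -2)))[:, :, None].repeat(1, 1, H)

# Loop version from the question, for comparison
ref = torch.zeros(B, C, H)
for b in range(B):
    for i in range(C):
        ref[b, i] = spec_x[b].masked_select(pitch[b] == i).mean()

# Both versions produce NaN for pitch classes absent from a batch element,
# so compare NaN patterns first and values only where defined
mask = ~torch.isnan(ref)
assert torch.isnan(out).equal(torch.isnan(ref))
assert torch.allclose(out[mask], ref[mask], atol=1e-5)
```

    One caveat: the intermediate (B, 256, H, T) tensors returned by scatter are materialized in memory, so for long spectrograms this trades memory for speed.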