pytorch, attention-model

(Efficiently) expanding a feature mask tensor to match embedding dimensions


I have a B (batch size) by F (feature count) mask tensor M, which I'd like to apply (element-wise multiply) to an input x.

The thing is, my x has had its raw feature columns transformed into embeddings of non-constant width, so its total dimensions are B by E (embedded size).

My draft code is along the lines of:

import torch

# Given something like:
M = torch.tensor([[0.2, 0.8], [0.5, 0.5], [0.6, 0.4]])  # (B=3, F=2) mask
x = torch.tensor([[1., 2., 3., 4., 5.],
                  [6., 7., 8., 9., 0.],
                  [11., 12., 13., 14., 15.]])  # (B=3, E=5) embedded input

feature_sizes = [2, 3]  # (Feature 0 embedded to 2 cols, feature 1 to 3)

# In forward() pass:
components = []
for ix, size in enumerate(feature_sizes):
    # Broadcast mask column ix across the width of its embedding:
    components.append(M[:, ix].view(-1, 1).expand(-1, size))
M_x = torch.cat(components, dim=1)

# Now M_x is (B, E) and can be element-wise multiplied with x, giving:

> M_x * x = tensor([
>     [0.2, 0.4, 2.4, 3.2, 4.0],
>     [3.0, 3.5, 4.0, 4.5, 0.0],
>     [6.6, 7.2, 5.2, 5.6, 6.0],
> ])
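
As a quick sanity check (just illustrative, not part of the actual forward pass), applying the expanded mask reproduces the hand-computed values above:

expected = torch.tensor([
    [0.2, 0.4, 2.4, 3.2, 4.0],
    [3.0, 3.5, 4.0, 4.5, 0.0],
    [6.6, 7.2, 5.2, 5.6, 6.0],
])
assert torch.allclose(M_x * x, expected)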

My question is: are there any obvious optimizations I'm missing here? Is that for loop the right way to go about it, or is there some more direct way to achieve this?

I have control over the embedding process, so I can store whatever representation is helpful; I'm not tied to the feature_sizes list of ints, for example.


Solution

  • Duh, I forgot: An indexing operation could do this!

    Given the above situation (but with a more complex feature_sizes, to illustrate more clearly), we can pre-compute an index tensor with something like:

    # Given that:
    feature_sizes = [1, 3, 1, 2]
    
    # Produce nested list e.g. [[0], [1, 1, 1], [2], [3, 3]]:
    ixs_per_feature = [[ix] * size for ix, size in enumerate(feature_sizes)]
    
    # Flatten out into a vector e.g. [0, 1, 1, 1, 2, 3, 3]:
    mask_ixs = torch.LongTensor(
        [item for sublist in ixs_per_feature for item in sublist]
    )
    
    # Now can directly produce M_x by indexing M:
    M_x = M[:, mask_ixs]
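
    To make this concrete: with a matching mask of F=4 columns (values made up purely for illustration), the indexing yields the (B, E=7) expanded mask directly:

    M = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                      [0.9, 0.8, 0.7, 0.6]])  # (B=2, F=4), made-up values
    M_x = M[:, mask_ixs]
    # tensor([[0.1000, 0.2000, 0.2000, 0.2000, 0.3000, 0.4000, 0.4000],
    #         [0.9000, 0.8000, 0.8000, 0.8000, 0.7000, 0.6000, 0.6000]])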
    

    I got a modest speedup by using this method instead of the for loop.
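
    For what it's worth, torch.repeat_interleave (available since PyTorch 1.1) should express the same expansion in a single call; I haven't benchmarked it against the indexing approach, so treat this as a sketch:

    sizes = torch.tensor(feature_sizes)  # e.g. [1, 3, 1, 2]
    # Repeat each mask column as many times as its feature's embedded width:
    M_x = M.repeat_interleave(sizes, dim=1)  # (B, E)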