
How to compute product between two sets of features in pytorch using a single loop?


I want to compute the product between two sets of feature matrices X and Y, each of shape (H, W, 12): the element-wise product of every channel of X with every channel of Y.

Inefficiently, I would do:

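import torch

X = torch.randn(4, 5, 12)  # example feature maps of shape (H, W, 12)
Y = torch.randn(4, 5, 12)

out = []
for i in range(12):
    for j in range(12):
        out.append(X[:, :, i] * Y[:, :, j])  # (H, W) product of channel i and channel j
out = torch.stack(out, dim=-1)  # (H, W, 144)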

which, after stacking along the last dimension, gives an output of shape (H, W, 144), one channel per (i, j) pair.
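In index form, with i as the outer and j as the inner loop index, and writing Z for the (H, W, 144) result, each output channel is:

```latex
Z_{h,w,\,12i + j} = X_{h,w,i}\, Y_{h,w,j}, \qquad i, j \in \{0, \dots, 11\}
```

so the 144 channels are the 12 × 12 channel pairs laid out in row-major order over (i, j).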

How can this be done in PyTorch without the two nested loops?

I have tried tensordot-based solutions but can't replicate the behavior.
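A likely reason tensordot does not fit: with dims=0, torch.tensordot takes the outer product over all dimensions, giving shape (H, W, 12, H, W, 12) and pairing every pixel with every other pixel, while dims >= 1 sums over channels. What the loop needs is a per-pixel outer product with H and W kept as batch dimensions. As a minimal sketch, using torch.einsum instead of tensordot (a swapped-in technique) and small made-up sizes:

```python
import torch

H, W, C = 4, 5, 12  # small example sizes (assumed for illustration)
X = torch.randn(H, W, C)
Y = torch.randn(H, W, C)

# tensordot with dims=0 pairs every element of X with every element of Y,
# including across different pixels -> shape (H, W, C, H, W, C)
full = torch.tensordot(X, Y, dims=0)
print(full.shape)  # torch.Size([4, 5, 12, 4, 5, 12])

# einsum keeps h and w as shared (batch) indices and forms the outer
# product only over the channel indices i and j -> (H, W, C, C) -> (H, W, C*C)
Z = torch.einsum('hwi,hwj->hwij', X, Y).reshape(H, W, C * C)
print(Z.shape)  # torch.Size([4, 5, 144])

# Same channel ordering as the nested loop (i outer, j inner): index i * C + j
ref = torch.stack([X[:, :, i] * Y[:, :, j] for i in range(C) for j in range(C)], dim=-1)
print(torch.allclose(Z, ref))  # True
```

The einsum string names the batch indices explicitly, which makes the "same pixel, all channel pairs" intent readable at a glance.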


Solution

  • I am not sure this is the most efficient, but you can do something like this (warning: ugly code ahead =]):

    import torch
    
    # I choose not to use random -- easier to verify, IMO
    a = torch.Tensor([[[1,2],[3,4],[5,6]],[[1,2],[3,4],[5,6]]])
    b = torch.Tensor([[[1,2],[3,4],[5,6]],[[1,2],[3,4],[5,6]]])
    
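    # View each (h, w) position as one batch entry: a becomes (H*W, C, 1)
    # column vectors and b becomes (H*W, 1, C) row vectors, so bmm computes a
    # C x C outer product per position. The final view flattens each C x C
    # block into C*C channels: c[h, w, i*C + j] = a[h, w, i] * b[h, w, j].
    c = torch.bmm(
        a.view(-1, a.size(-1), 1),
        b.view(-1, 1, b.size(-1))
    ).view(*(a.shape[:2]), -1)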
    
    print(c)
    
    print(a.shape)
    print(b.shape)
    print(c.shape)
    

    Output:

    tensor([[[ 1.,  2.,  2.,  4.],
             [ 9., 12., 12., 16.],
             [25., 30., 30., 36.]],
    
            [[ 1.,  2.,  2.,  4.],
             [ 9., 12., 12., 16.],
             [25., 30., 30., 36.]]])
    
    torch.Size([2, 3, 2])  # a
    torch.Size([2, 3, 2])  # b
    torch.Size([2, 3, 4])  # c
    

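    Basically, this computes the outer product of the two channel vectors at every (h, w) position and flattens it into the last dimension. Let me know if you need me to explain.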
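The same per-position outer product can also be written with broadcasting instead of bmm (an alternative, not the answer's method). This sketch, with arbitrary non-square sizes and random inputs, checks that the bmm expression, a broadcasting expression, and the question's nested loop agree, including the channel ordering i * C + j:

```python
import torch

H, W, C = 3, 7, 12  # arbitrary example sizes
a = torch.randn(H, W, C)
b = torch.randn(H, W, C)

# The answer's approach: one C x C outer product per (h, w) via batched matmul
c_bmm = torch.bmm(
    a.view(-1, a.size(-1), 1),
    b.view(-1, 1, b.size(-1))
).view(*(a.shape[:2]), -1)

# Broadcasting alternative: (H, W, C, 1) * (H, W, 1, C) -> (H, W, C, C),
# then merge the last two dims into C * C channels
c_bcast = (a.unsqueeze(-1) * b.unsqueeze(-2)).flatten(start_dim=-2)

# Reference: the nested loop from the question (i outer, j inner)
c_loop = torch.stack(
    [a[:, :, i] * b[:, :, j] for i in range(C) for j in range(C)], dim=-1
)

print(c_bmm.shape)  # torch.Size([3, 7, 144])
print(torch.allclose(c_bmm, c_loop), torch.allclose(c_bcast, c_loop))  # True True
```

Unlike view, the broadcasting form does not require contiguous inputs, at the cost of materializing the (H, W, C, C) intermediate (which bmm also produces before the final view).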


    Timings

    While running torch.bmm on the CPU, 16 of the 32 cores were in use. I used a GeForce RTX 2080 Ti to run the CUDA version (GPU utilization was ~97% during execution). Note that the dimensions used for the GPU timings are 10× larger; otherwise it runs too fast to measure meaningfully.

    Script:

    import timeit
    
    setup = '''
    import torch
    a = torch.randn(({H}, {W}, 12))
    b = torch.randn(({H}, {W}, 12))
    '''
    
    setup_cuda = setup.replace("))", ")).to(torch.device('cuda'))")
    
    bmm = '''
    c = torch.bmm(
        a.view(-1, a.size(-1), 1),
        b.view(-1, 1, b.size(-1))
    ).view(*(a.shape[:2]), -1)
    '''
    
    loop = '''
    c = []
    for i in range(a.size(-1)):
        for j in range(b.size(-1)):
            c.append(a[:, :, i] * b[:, :, j])
    c = torch.stack(c).permute(1, 2, 0)
    '''
    
    min_dim = 10
    max_dim = 100
    num_repeats = 10
    
    print('BMM')
    for d in range(min_dim, max_dim+1, 10):
        print(d, min(timeit.Timer(bmm, setup=setup.format(H=d, W=d)).repeat(num_repeats, 1000)))
    
    print('LOOP')
    for d in range(min_dim, max_dim+1, 10):
        print(d, min(timeit.Timer(loop, setup=setup.format(H=d, W=d)).repeat(num_repeats, 1000)))
    
    print('BMM - CUDA')
    for d in range(min_dim*10, (max_dim*10)+1, 100):
        print(d, min(timeit.Timer(bmm, setup=setup_cuda.format(H=d, W=d)).repeat(num_repeats, 1000)))
    

    Output (best of 10 repeats; each value is the total time in seconds for 1000 calls):

    BMM
    10 0.022082214010879397
    20 0.034024904016405344
    30 0.08957623899914324
    40 0.1376199919031933
    50 0.20248223491944373
    60 0.2657837320584804
    70 0.3533527449471876
    80 0.42361779196653515
    90 0.6103016039123759
    100 0.7161333339754492
    
    LOOP
    10 1.7369094720343128
    20 1.8517447559861466
    30 1.9145489090587944
    40 2.0530637570191175
    50 2.2066439649788663
    60 2.394576688995585
    70 2.6210166650125757
    80 2.9242434420157224
    90 3.5709626079769805
    100 5.413458575960249
    
    BMM - CUDA
    100 0.014253990724682808
    200 0.015094103291630745
    300 0.12792395427823067
    400 0.307440347969532
    500 0.541196970269084
    600 0.8697826713323593
    700 1.2538292426615953
    800 1.6859236396849155
    900 2.2016236428171396
    1000 2.764942280948162
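One caveat when reading the CUDA timings above: CUDA kernels launch asynchronously, so a timeit loop without explicit synchronization may partly reflect launch overhead rather than the finished work. A minimal sketch of a synchronized timing helper, assuming a hypothetical name time_op (not part of the answer) and falling back to the CPU when no GPU is available:

```python
import time
import torch

def time_op(fn, number=100):
    """Return the average wall-clock seconds per call of fn().

    Hypothetical helper: synchronizes before and after the timed loop so
    that queued CUDA kernels are included in the measurement.
    """
    use_cuda = torch.cuda.is_available()
    fn()  # warm-up call (CUDA context setup, kernel selection, caches)
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(number):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for all queued kernels to finish
    return (time.perf_counter() - start) / number

device = 'cuda' if torch.cuda.is_available() else 'cpu'
d = 100
a = torch.randn(d, d, 12, device=device)
b = torch.randn(d, d, 12, device=device)

def bmm_outer():
    # Same bmm expression as in the answer
    return torch.bmm(
        a.view(-1, a.size(-1), 1),
        b.view(-1, 1, b.size(-1))
    ).view(*(a.shape[:2]), -1)

print(f'{device}: {time_op(bmm_outer) * 1e3:.3f} ms per call')
```

The numbers this prints are per call rather than per 1000 calls, so multiply by 1000 before comparing them with the table above.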