Tags: python, pytorch, cosine-similarity, array-broadcasting, numpy-einsum

In PyTorch, how can I avoid an expensive broadcast when adding two tensors then immediately collapsing?


I have two 2-d tensors which align via broadcasting, so if I add/subtract them, I incur a huge 3-d intermediate tensor. I don't really need that intermediate, though, since I'll only be taking a mean over one of its dimensions. In this demo I unsqueeze the tensors to show how they align, but they are 2-d otherwise.

x = torch.tensor(...)              # (batch, 1, B)
y = torch.tensor(...)              # (1,     A, B)
out = torch.cos(x - y).mean(dim=2) # (batch, A)
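
To make the cost concrete, here is a rough illustration with made-up sizes (the broadcasted difference alone is a full (batch, A, B) float32 tensor, even though only a reduced result is needed):

import torch

batch, A, B = 512, 2048, 1024          # made-up sizes, purely illustrative
x = torch.randn(batch, 1, B)
y = torch.randn(1, A, B)
# (x - y) would materialize a (batch, A, B) intermediate before the mean collapses it:
intermediate_bytes = batch * A * B * x.element_size()
print(intermediate_bytes / 2**30)      # 4.0 -> about 4 GiB just for this one step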

Possible Solutions:

  • An algebraic simplification, but for the life of me I haven't found one yet.

  • Some PyTorch primitive that'll help? This is cosine similarity, but a bit different from torch.cosine_similarity: I'm applying it to the .angle()s of complex numbers (see the sketch after this list).

  • Custom C/CPython code that loops efficiently.

  • Other?
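
For context, here is a minimal sketch of the kind of setup I mean (names, sizes, and the random data are purely illustrative):

import torch

batch, A, B = 32, 64, 128                        # illustrative sizes
z1 = torch.randn(batch, B, dtype=torch.cfloat)   # some complex-valued data
z2 = torch.randn(A, B, dtype=torch.cfloat)
x, y = z1.angle(), z2.angle()                    # phase angles, both 2-d
# the expensive version: broadcast to (batch, A, B), then average over B
out = torch.cos(x[:, None, :] - y[None, :, :]).mean(dim=2)  # (batch, A)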


Solution

  • To save memory I recommend using torch.einsum: We can make use of the trigonometric identity

    cos(x-y) = cos(x)*cos(y) + sin(x)*sin(y)
    

    In this case we can let einsum's implicit summation do the contraction over the B dimension (the averaging, up to the final division by B), while the + between the two products is applied afterwards as a separate element-wise addition. Averaging cos(x[i,k] - y[j,k]) over k thus reduces to two matrix-product-like contractions; in short:

    xs, ys = torch.sin(x), torch.sin(y)
    xc, yc = torch.cos(x), torch.cos(y)
    # einsum contracts the sin/sin and cos/cos products over k (the B dimension);
    # + adds the two product terms, and dividing by B turns the sum into the mean:
    out = (torch.einsum('i k, j k -> i j', xs, ys) + torch.einsum('i k, j k -> i j', xc, yc)) / y.shape[1]  # (batch, A)
    
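    If it helps readability, the same computation can be written with plain matrix products, since torch.einsum('i k, j k -> i j', a, b) is simply a @ b.T; this is just an equivalent spelling, not a further optimization:

    # same result as the einsum version above, written as matrix products
    out = (xs @ ys.T + xc @ yc.T) / y.shape[1]  # (batch, A)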

    Measuring memory consumption directly is a little tedious, so I resorted to measuring run time as a proxy. Below you can see your original method and my proposal for various input sizes. (The script for generating these plots is attached below, followed by a short note on measuring peak GPU memory directly.)

    [Figure: log-log plot of run time versus input size n for the naive broadcast method and the einsum method, with x^1 and x^2 reference slopes.]

    import matplotlib.pyplot as plt
    import torch
    import time
    
    def main():
        ns = torch.logspace(1, 3.2, 20).to(torch.long)
        tns = []; tes = []
        for n in ns:
            tn, te = compare(n)
            tns.append(tn); tes.append(te)
        plt.loglog(ns, tns, ':.')
        plt.loglog(ns, tes, '.-')
        plt.loglog(ns, 1e-6 * ns**1, ':')   # linear reference slope
        plt.loglog(ns, 1e-6 * ns**2, ':')   # quadratic reference slope
        plt.legend(['naive', 'einsum', 'x^1', 'x^2'])
        plt.show()
    
    def compare(n):
        batch = a = b = n
        x = torch.rand((batch, b))  # (batch, B); random data so the equality check below is meaningful
        y = torch.rand((a, b))      # (A, B)
        t = time.perf_counter(); ra = af(x.unsqueeze(1), y.unsqueeze(0)); print('naive method', tn := time.perf_counter() - t)
        t = time.perf_counter(); rb = bf(x, y); print('einsum method', te := time.perf_counter() - t)
        print((ra-rb).abs().max()) # verify we have same results
        return tn, te
    
    def af(x, y):
        return torch.cos(x - y).mean(dim=2) 
    
    def bf(x, y):
        xs, ys = torch.sin(x), torch.sin(y)
        xc, yc = torch.cos(x), torch.cos(y)
        return (torch.einsum('i k, j k -> i j', xs, ys) + torch.einsum('i k, j k -> i j', xc, yc)) / y.shape[1]
    
    main()
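
    If you do want to measure memory directly and a CUDA device is available, the allocator statistics in torch.cuda make it fairly painless. A minimal sketch, separate from the benchmark above (it assumes the inputs have already been moved to the GPU):

    def peak_mem_mib(fn, *args):
        # run fn(*args) and report the peak GPU memory seen by the allocator, in MiB
        torch.cuda.reset_peak_memory_stats()
        fn(*args)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated() / 2**20

    # e.g. compare peak_mem_mib(af, x.cuda().unsqueeze(1), y.cuda().unsqueeze(0))
    #      with    peak_mem_mib(bf, x.cuda(), y.cuda())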