
Pytorch: how to (efficiently) apply a function without a “dim” argument to each row of a 2D tensor?

Long story short, I have a 2D matrix of ones and zeros and I need to retrieve, for each row, the indices of the elements set to one. The “standard” way to do this would be torch.nonzero, but that function is well known for two problems: 1) it is a real bottleneck, since it does not know the size of the output in advance, and 2) it cannot be applied to each row of a 2D tensor in one shot, because different rows may contain different numbers of ones.
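To make point 2) concrete, here is a small sketch on made-up toy data: torch.nonzero is called once on the whole matrix, and its column indices are then split per row using the per-row counts. The result is a tuple of 1D tensors of different lengths rather than a single tensor.

```python
import torch

# Toy 0/1 matrix (assumed for illustration); rows hold 2, 1 and 0 ones.
data = torch.tensor([[0, 1, 1, 0],
                     [1, 0, 0, 0],
                     [0, 0, 0, 0]])

idx = data.nonzero()          # [num_ones, 2] (row, col) pairs, in row-major order
counts = data.sum(dim=1)      # number of ones in each row
# Split the column indices into one variable-length 1D tensor per row.
per_row = idx[:, 1].split(counts.tolist())

for r, cols in enumerate(per_row):
    print(r, cols.tolist())
```

Note that `counts.tolist()` copies the counts to the host, and the ragged tuple cannot be stacked into one tensor without padding.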

Recently, at::nonzero_static was introduced, which solves the first problem by letting the caller specify the expected maximum number of nonzero elements (which is fine for my application). However, it has no “dim” argument, so it cannot be applied to each row or column individually. In my opinion that makes little sense: fixing the output size guarantees that every row yields the same number of items, so the per-row results would form a regular tensor.
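For illustration, this is roughly what a dim-aware nonzero_static could return. It is a minimal sketch assuming a 2D 0/1 input; the helper name `rowwise_nonzero_static` is hypothetical (not a PyTorch API), and its `size`/`fill_value` arguments only loosely mirror those of nonzero_static. Zeros are replaced by a sentinel larger than any column index, so a per-row sort moves the nonzero column indices to the front in ascending order.

```python
import torch


def rowwise_nonzero_static(mask: torch.Tensor, size: int, fill_value: int = -1) -> torch.Tensor:
    """Hypothetical helper: for each row of a 2D 0/1 tensor, return the column
    indices of its nonzero entries (ascending), truncated or padded to `size`."""
    m, n = mask.shape
    cols = torch.arange(n, device=mask.device).expand(m, n)
    # Nonzero entries keep their column index; zeros get the sentinel n,
    # which is larger than any valid column, so they sort to the end of the row.
    keys = torch.where(mask != 0, cols, torch.full_like(cols, n))
    k = min(size, n)
    vals = keys.sort(dim=1).values[:, :k]
    # Padded output: every row has exactly `size` entries.
    out = torch.full((m, size), fill_value, dtype=torch.long, device=mask.device)
    out[:, :k] = vals.masked_fill(vals == n, fill_value)
    return out


data = torch.tensor([[0, 1, 1, 0],
                     [1, 0, 0, 1],
                     [0, 0, 0, 0]])
print(rowwise_nonzero_static(data, size=3).tolist())
# [[1, 2, -1], [0, 3, -1], [-1, -1, -1]]
```

The sort costs O(n log n) per row, but it is one batched call with no Python loop and no `.item()` synchronization.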

Using a for loop would obviously solve my issue, but that means calling the function once per row, which is not GPU-efficient. Does anyone know a way to apply nonzero_static efficiently to each row and return a tensor in which each row is the result of applying it to the corresponding row of the input? From my understanding, vmap may be a solution, but I am not sure whether it is optimized for GPU.
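One way to avoid both the loop and vmap is to compute each nonzero's slot within its row with a cumulative sum, then scatter the column indices straight into a padded [m, size] tensor. This is a sketch under the assumptions of a 2D 0/1 input and a fill value of -1; `rowwise_nonzero_scatter` is a hypothetical name, and the extra spare column only exists to absorb zeros and entries beyond the first `size` ones.

```python
import torch


def rowwise_nonzero_scatter(mask: torch.Tensor, size: int, fill_value: int = -1) -> torch.Tensor:
    """Hypothetical helper: per-row nonzero column indices, padded to `size`,
    built with cumsum + scatter_ instead of a loop, vmap or a sort."""
    m, n = mask.shape
    nz = mask != 0
    cols = torch.arange(n, device=mask.device).expand(m, n)
    # Slot of each nonzero within its row: 0 for the first one, 1 for the second, ...
    slot = nz.long().cumsum(dim=1) - 1
    # Zeros, and ones beyond the first `size`, are routed to the spare column `size`.
    slot = torch.where(nz & (slot < size), slot, torch.full_like(slot, size))
    out = torch.full((m, size + 1), fill_value, dtype=torch.long, device=mask.device)
    out.scatter_(1, slot, cols)
    # Drop the spare column; it received arbitrary column indices.
    return out[:, :size]


data = torch.tensor([[0, 1, 1, 0],
                     [1, 0, 0, 1],
                     [0, 0, 0, 0]])
print(rowwise_nonzero_scatter(data, size=3).tolist())
# [[1, 2, -1], [0, 3, -1], [-1, -1, -1]]
```

It uses only elementwise ops, cumsum and scatter_, so the same code runs on CPU or CUDA without a host sync; whether it beats the alternatives depends on shape and device and would need measuring.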


  • I implemented a few solutions. Some preliminaries:

    • nonzero_static() is unfortunately not supported on the CUDA backend (at the time of writing), which may be limiting for your use case.
    • vmap is unlikely to be a good fit, as it "does not provide general autobatching or handle variable-length sequences out of the box", and it creates a batched tensor output. Running vmap on nonzero_static produces the warning UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::nonzero_static.
    • In general, leaving the result in list-like form (i.e. two 1D tensors holding the row and column indices, respectively) is faster than putting these indices into a tensor with the original shape of the data, and sorting so that the useful indices come first adds a bit more time on top of that.
    • The takeaway from my very crude experiment was that, for most reasonable tensor sizes, vanilla nonzero() was the fastest or nearly as fast as the index-broadcasting solution. It seems that not knowing the output size in advance is generally not a large bottleneck compared to the relatively clunky workarounds. It would be interesting to re-evaluate if nonzero_static gets optimized for batched computation with vmap, or gets a CUDA implementation, which will hopefully happen eventually since it is a relatively new function in PyTorch.
    import time
    import torch

    m = 2000
    n = 1000
    trials = 100
    results = {}  # method name -> total seconds over all trials
    for t in range(trials):
        device = torch.device("cpu")
        data = torch.rand([m, n], device=device).round().long()  # random 0/1 matrix

        # use nonzero, scatter into matrix form, then sort
        # (caveat: a one in column 0 and an empty slot are both stored as 0)
        name = "nonzero"
        t1 = time.perf_counter()
        idx = data.nonzero()
        midx = idx[:, 0]
        nidx = idx[:, 1]
        output = torch.zeros([m, n], device=device, dtype=torch.long)
        output[midx, nidx] = nidx
        output = output.sort(dim=1, descending=True)
        results[name] = results.get(name, 0.0) + time.perf_counter() - t1

        # use nonzero_static and leave in "listy" form
        name = "nonzero_static"
        t1 = time.perf_counter()
        count_nonzero = int(data.sum().item())
        d = data.view(-1)
        idx = d.nonzero_static(size=count_nonzero)
        midx, nidx = idx // n, idx % n  # flat index -> (row, column)
        results[name] = results.get(name, 0.0) + time.perf_counter() - t1

        # use nonzero_static and put in matrix form, leave unsorted
        name = "nonzero_static -> matrix"
        t1 = time.perf_counter()
        count_nonzero = int(data.sum().item())
        d = data.view(-1)
        idx = d.nonzero_static(size=count_nonzero)
        midx, nidx = idx // n, idx % n
        output = torch.zeros([m, n], device=device, dtype=torch.long)
        output[midx, nidx] = nidx
        results[name] = results.get(name, 0.0) + time.perf_counter() - t1

        # use nonzero_static and put in matrix form, then sort
        name = "nonzero_static -> sorted matrix"
        t1 = time.perf_counter()
        count_nonzero = int(data.sum().item())
        d = data.view(-1)
        idx = d.nonzero_static(size=count_nonzero)
        midx, nidx = idx // n, idx % n
        output = torch.zeros([m, n], device=device, dtype=torch.long)
        output[midx, nidx] = nidx
        output = output.sort(dim=1, descending=True)
        results[name] = results.get(name, 0.0) + time.perf_counter() - t1

        # vmap nonzero_static (warns that aten::nonzero_static has no batching rule)
        name = "vmap nonzero_static"
        t1 = time.perf_counter()
        test = torch.func.vmap(torch.nonzero_static)
        output = test(data, size=n).squeeze(-1)
        results[name] = results.get(name, 0.0) + time.perf_counter() - t1

        # use index broadcasting then sort (same column-0 caveat as above)
        name = "index broadcasting"
        t1 = time.perf_counter()
        index_array = torch.arange(n).unsqueeze(0).expand(m, n)
        output = data * index_array
        output = output.sort(dim=1, descending=True)
        results[name] = results.get(name, 0.0) + time.perf_counter() - t1

        if torch.cuda.is_available():
            device = torch.device("cuda:0")
            data = data.to(device)

            # use index broadcasting then sort on GPU
            name = "GPU index broadcasting"
            torch.cuda.synchronize()
            t1 = time.perf_counter()
            index_array = torch.arange(n, device=device).unsqueeze(0).expand(m, n)
            output = data * index_array
            output = output.sort(dim=1, descending=True)
            torch.cuda.synchronize()  # GPU kernels run asynchronously; wait before reading the clock
            results[name] = results.get(name, 0.0) + time.perf_counter() - t1
            del output

            # use nonzero, scatter into matrix form, then sort on GPU
            name = "GPU nonzero"
            torch.cuda.synchronize()
            t1 = time.perf_counter()
            idx = data.nonzero()
            midx = idx[:, 0]
            nidx = idx[:, 1]
            output = torch.zeros([m, n], device=device, dtype=torch.long)
            output[midx, nidx] = nidx
            output = output.sort(dim=1, descending=True)
            torch.cuda.synchronize()
            results[name] = results.get(name, 0.0) + time.perf_counter() - t1

    print("Results for [{},{}] over {} trials".format(m, n, trials))
    for key in results:
        print("{:.5f}s for {}".format(results[key] / trials, key))
    Average time per trial (seconds), three runs with different sizes:

    method                            [200,100]    [2000,1000]   [20000,10000]
                                      100 trials   100 trials    20 trials
    nonzero                           0.00051      0.00575       0.67861
    nonzero_static                    0.00035      0.01028       1.10534
    nonzero_static -> matrix          0.00037      0.01036       1.31800
    nonzero_static -> sorted matrix   0.00062      0.01302       1.66106
    vmap nonzero_static               0.00191      0.03645       2.68011
    index broadcasting                0.00033      0.00466       0.55859
    GPU index broadcasting            0.00015      0.00129       0.31346
    GPU nonzero                       0.00019      0.00198       0.30350
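Since plain nonzero() held up well in these measurements, here is a minimal sketch (hypothetical helper `pad_nonzero_rows`, 2D 0/1 input assumed) of turning its list-like output into the padded per-row tensor the question asks for. bincount gives the number of ones per row, a cumulative sum gives where each row starts in the nonzero() output, and the difference gives each entry's position within its row. Unlike `output[midx, nidx] = nidx`, empty slots get a fill value, so a one in column 0 is not confused with padding.

```python
import torch


def pad_nonzero_rows(mask: torch.Tensor, size: int, fill_value: int = -1) -> torch.Tensor:
    """Hypothetical helper: run nonzero() once on the whole matrix, then place
    each row's column indices into a padded [m, size] tensor."""
    m = mask.shape[0]
    idx = mask.nonzero()                          # [num_ones, 2], row-major order
    rows, cols = idx[:, 0], idx[:, 1]
    counts = torch.bincount(rows, minlength=m)    # ones per row
    starts = counts.cumsum(0) - counts            # offset of each row's first entry in idx
    # Position of each entry within its own row (0, 1, 2, ...).
    slot = torch.arange(rows.numel(), device=mask.device) - starts[rows]
    keep = slot < size                            # truncate rows with more than `size` ones
    out = torch.full((m, size), fill_value, dtype=torch.long, device=mask.device)
    out[rows[keep], slot[keep]] = cols[keep]
    return out


data = torch.tensor([[0, 1, 1, 0],
                     [1, 0, 0, 1],
                     [0, 0, 0, 0]])
print(pad_nonzero_rows(data, size=3).tolist())
# [[1, 2, -1], [0, 3, -1], [-1, -1, -1]]
```

This still pays nonzero()'s cost of discovering the output size, but the padding step is fully vectorized and works on CPU or CUDA.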