
Is there a bitcount operation in PyTorch?


I am trying to write a simple Hamming distance function in PyTorch:

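xorimg = torch.bitwise_xor(img1, img2)
hdist = torch.zeros_like(xorimg)  # running per-element count of set bits
for i in range(bitlen):  # bitlen = bit width of the dtype, e.g. 8 for uint8
    hdist = hdist + (xorimg & 1)  # add the current lowest bit
    xorimg = xorimg >> 1          # shift the next bit into the lowest position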
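As a minimal sketch, assuming `img1` and `img2` are integer tensors of the same shape and dtype and that `bitlen` is known, the Python loop over bit positions can be replaced by one broadcast shift: every element is shifted by all amounts `0 … bitlen - 1` at once, and the extracted low bits are summed. `hamming_shift_mask` is a hypothetical name, not a PyTorch function.

```python
import torch


def hamming_shift_mask(img1: torch.Tensor, img2: torch.Tensor, bitlen: int) -> torch.Tensor:
    """Hypothetical helper: per-element Hamming distance using the same
    shift-and-mask idea as the loop, with all bit positions in one broadcast."""
    xorimg = torch.bitwise_xor(img1, img2)
    # One shift amount per bit position: 0, 1, ..., bitlen - 1.
    shifts = torch.arange(bitlen, device=xorimg.device)
    # A trailing axis lets every element be shifted by every amount at once;
    # "& 1" then keeps only the bit that landed in the lowest position.
    bits = (xorimg.unsqueeze(-1) >> shifts) & 1
    # Summing the extracted bits over the new axis gives the set-bit count.
    return bits.sum(dim=-1)
```

This trades the Python loop for memory: the intermediate `bits` tensor is `bitlen` times larger than the input, and type promotion with the int64 `shifts` tensor makes the result int64.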

I wonder if there is a simple bitcount operation that counts the 1 bits, so I can get rid of the for loop, such as:

xorimg = torch.bitwise_xor(img1,img2)
hdist = torch.bitcount(xorimg)

Or is there any other equivalent way to get rid of the time-consuming for loop?

Or does PyTorch support Hamming distance directly, such as:

hdist = torch.hamming(img1,img2)

That would be even better.

Thanks for your help.
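Another loop-free option, sketched here under the assumption that the inputs are integer tensors and that your PyTorch version's `Tensor.view(dtype)` accepts a dtype with a different element size, is a byte lookup table: reinterpret each element as raw bytes, look up the popcount of every byte in a 256-entry table, and add up the bytes that belong to each element. `bitcount` and `hamming` are hypothetical helpers modeled on the names wished for above, not PyTorch built-ins.

```python
import torch

# _POPCOUNT8[v] holds the number of 1 bits in the byte value v (0..255).
_POPCOUNT8 = torch.tensor([bin(v).count("1") for v in range(256)], dtype=torch.int64)


def bitcount(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-element count of set bits in an integer tensor."""
    flat = x.contiguous().reshape(-1)
    # Reinterpret the raw storage as bytes: N elements -> N * element_size bytes.
    as_bytes = flat.view(torch.uint8)
    # Look up the popcount of every byte (indices must be int64, not uint8).
    per_byte = _POPCOUNT8.to(x.device)[as_bytes.long()]
    # Group the byte counts by original element, add them, restore the shape.
    per_elem = per_byte.reshape(flat.numel(), x.element_size()).sum(dim=-1)
    return per_elem.reshape(x.shape)


def hamming(img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-element Hamming distance of two integer tensors."""
    return bitcount(torch.bitwise_xor(img1, img2))
```

Byte order does not matter because the count is summed over all bytes of an element, so the same code covers uint8, int32 and int64 inputs.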


Solution

  • See the answer here.

    If you want to use the CUDA intrinsic function __popc() in the CuPy part, replace the following while loop (bitcount):

    while(x != 0){
      x = x & (x - 1);
      dist[elem_idx]++;
    }
    

    with the CUDA intrinsic function for int32:

    dist[elem_idx] = __popc(x);
    

    or, with a small modification in both torch.py and cupy.py, for int64:

    dist[elem_idx] = __popcll(x);
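For reference, `__popc()` returns the number of set bits in a 32-bit unsigned integer and `__popcll()` does the same for a 64-bit unsigned long long. The while loop they replace uses Kernighan's trick: `x & (x - 1)` clears the lowest set bit, so the loop runs once per 1 bit. The plain-Python sketch below mirrors that loop and could serve as a CPU reference when checking a kernel's output; `popcount_kernighan` and its `bits` parameter are hypothetical.

```python
def popcount_kernighan(x: int, bits: int = 32) -> int:
    """Hypothetical helper: count set bits the same way as the CUDA while loop."""
    # Python ints are unbounded, so mask to a fixed width to mimic a C integer;
    # without this, a negative value would never reach zero.
    x &= (1 << bits) - 1
    count = 0
    while x != 0:
        x &= x - 1  # clear the lowest set bit
        count += 1
    return count
```

Calling it with `bits=32` matches what `__popc()` counts, and `bits=64` matches `__popcll()`.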