
PyTorch: bitwise OR all elements along a certain dimension


I'm new to PyTorch and tensors in general, so I could use some guidance :) I'll do my best to write a correct question, but I may use terms incorrectly here and there. Feel free to correct any of this :)

Say I have a tensor of shape (n, 3, 3), i.e. n matrices of size 3x3. Every cell of these matrices contains either 0 or 1.
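For reference, here is a minimal sketch of how n separate 3x3 matrices can be packed into one (n, 3, 3) tensor with torch.stack. The variable names m1, m2, m3 and stacked are just illustrative; the values are the three matrices from the example below.

```python
import torch

# Three illustrative 3x3 binary matrices (same values as the example below).
m1 = torch.tensor([[0, 0, 1], [0, 0, 0], [1, 0, 0]])
m2 = torch.tensor([[1, 0, 0], [0, 0, 0], [1, 0, 1]])
m3 = torch.tensor([[0, 1, 1], [0, 1, 0], [1, 0, 1]])

# torch.stack adds a new leading dimension, giving shape (n, 3, 3) with n = 3.
stacked = torch.stack([m1, m2, m3], dim=0)
print(stacked.shape)  # torch.Size([3, 3, 3])
```

The first dimension is the one to reduce over: stacked[i] is the i-th matrix.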

What's the best (fastest, easiest?) way to compute an elementwise bitwise OR across all of these matrices, i.e. to reduce the tensor along its first dimension to a single 3x3 matrix?

For example, if I have 3 matrices:

0 0 1
0 0 0
1 0 0

--

1 0 0
0 0 0
1 0 1

--

0 1 1
0 1 0
1 0 1

I want the final result to be:

1 1 1
0 1 0
1 0 1
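Since the expected result is exactly the elementwise OR of the matrices, one way to check it is to fold torch.bitwise_or pairwise over the slices along the first dimension. This is a minimal sketch, not necessarily the fastest option: it loops in Python over the n matrices, so a single built-in reduction is likely quicker for large n. torch.bitwise_or only works on integer or bool tensors, not floats.

```python
import functools

import torch

tensor = torch.tensor([[[0, 0, 1], [0, 0, 0], [1, 0, 0]],
                       [[1, 0, 0], [0, 0, 0], [1, 0, 1]],
                       [[0, 1, 1], [0, 1, 0], [1, 0, 1]]])

# unbind(0) splits the (n, 3, 3) tensor into n separate 3x3 tensors;
# reduce then computes ((m0 | m1) | m2) | ... elementwise.
result = functools.reduce(torch.bitwise_or, tensor.unbind(0))
print(result.tolist())  # [[1, 1, 1], [0, 1, 0], [1, 0, 1]]
```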

Solution

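  • Sum the matrices along the first dimension (dim 0) and check where the sum is greater than 0. Because every entry is 0 or 1, a cell's sum is positive exactly when at least one matrix has a 1 there, which is the OR: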

    import torch
    
    tensor = torch.tensor([[[0, 0, 1],
                           [0, 0, 0],
                           [1, 0, 0]],
                          [[1, 0, 0],
                           [0, 0, 0],
                           [1, 0, 1]],
                          [[0, 1, 1],
                           [0, 1, 0],
                           [1, 0, 1]]])
    
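    # Sum along dim 0: each cell counts how many of the n matrices have a 1 there.
    # Every entry is 0 or 1, so a count > 0 means "at least one 1", i.e. the OR.
    tensor2 = torch.sum(tensor, dim=0) > 0   # bool tensor of shape (3, 3)
    tensor2 = tensor2.to(torch.uint8)        # back to 0/1 values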
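Two other single-call reductions should give the same result here; this is a sketch assuming a reasonably recent PyTorch (amax is not available in very old releases). any() along dim 0 is a logical OR reduction, and for 0/1 data the maximum along dim 0 is 1 exactly where some matrix has a 1. The input is cast to bool before any() so the code does not rely on any() accepting integer dtypes.

```python
import torch

tensor = torch.tensor([[[0, 0, 1], [0, 0, 0], [1, 0, 0]],
                       [[1, 0, 0], [0, 0, 0], [1, 0, 1]],
                       [[0, 1, 1], [0, 1, 0], [1, 0, 1]]])

# Logical OR over dim 0: True where at least one matrix has a nonzero entry.
or_any = tensor.bool().any(dim=0).to(torch.uint8)

# For 0/1 values, the max over dim 0 equals the OR over dim 0.
or_max = tensor.amax(dim=0).to(torch.uint8)

print(or_any.tolist())  # [[1, 1, 1], [0, 1, 0], [1, 0, 1]]
```

any() expresses the intent most directly; amax() relies on the values being exactly 0 or 1, just like the sum-based approach.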