New to PyTorch and tensors in general, so I could use some guidance :) I'll do my best to phrase the question correctly, but I may use some terms incorrectly here and there. Feel free to correct any of this :)
Say I have a tensor of shape (n, 3, 3), i.e. n matrices of size 3x3. Every cell of each matrix is either 0 or 1.
What's the best (fastest? easiest?) way to compute an element-wise bitwise OR across all of these matrices, producing a single 3x3 result?
For example, if I have 3 matrices:
0 0 1
0 0 0
1 0 0
--
1 0 0
0 0 0
1 0 1
--
0 1 1
0 1 0
1 0 1
I want the final result to be:
1 1 1
0 1 0
1 0 1
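For reference, the operation described here can be written literally as a pairwise `torch.bitwise_or` folded over the n matrices. The sketch below assumes the example matrices are stored as a `uint8` tensor named `masks` (a hypothetical name). It loops over the first dimension in Python, so for large n a single vectorized reduction is likely to be faster.

```python
import functools

import torch

# The three example matrices from the question, stacked into shape (3, 3, 3).
masks = torch.tensor([[[0, 0, 1], [0, 0, 0], [1, 0, 0]],
                      [[1, 0, 0], [0, 0, 0], [1, 0, 1]],
                      [[0, 1, 1], [0, 1, 0], [1, 0, 1]]], dtype=torch.uint8)

# Iterating over a tensor yields its slices along dim 0 (each of shape (3, 3)),
# so reduce() applies torch.bitwise_or pairwise: (m0 | m1) | m2.
result = functools.reduce(torch.bitwise_or, masks)
print(result)
```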
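Sum the matrices along the first dimension (dim 0) and check where the sum is greater than 0. Since every entry is 0 or 1, the sum at a cell is positive exactly when at least one matrix has a 1 there, which is the OR: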
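import torch

# Stack of n = 3 binary 3x3 matrices, shape (3, 3, 3)
tensor = torch.tensor([[[0, 0, 1],
                        [0, 0, 0],
                        [1, 0, 0]],
                       [[1, 0, 0],
                        [0, 0, 0],
                        [1, 0, 1]],
                       [[0, 1, 1],
                        [0, 1, 0],
                        [1, 0, 1]]])

# Sum over dim 0 and compare with 0; this yields a bool tensor of shape (3, 3)
tensor2 = torch.sum(tensor, dim=0) > 0
# Convert the bool result back to 0/1 integers
tensor2 = tensor2.to(torch.uint8)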
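Two other single-call reductions should give the same result for 0/1 data: `torch.any` over dim 0 on a bool view of the tensor, and `torch.amax` over dim 0 (the maximum of 0/1 values is their OR). Below is a minimal sketch comparing them with the sum-based approach, assuming the matrices are stored as `uint8`. Which one is fastest will depend on tensor size and device, so it is worth timing on your own data.

```python
import torch

# Same example stack of 0/1 matrices, shape (n, 3, 3) with n = 3.
tensor = torch.tensor([[[0, 0, 1], [0, 0, 0], [1, 0, 0]],
                       [[1, 0, 0], [0, 0, 0], [1, 0, 1]],
                       [[0, 1, 1], [0, 1, 0], [1, 0, 1]]], dtype=torch.uint8)

# OR across dim 0: a cell is True if any matrix is nonzero there.
any_or = torch.any(tensor.bool(), dim=0).to(torch.uint8)

# For values restricted to 0/1, the maximum along dim 0 is also the OR.
max_or = torch.amax(tensor, dim=0)

# The sum-based version from the answer, for comparison.
sum_or = (torch.sum(tensor, dim=0) > 0).to(torch.uint8)

print(any_or)
```

`torch.any` states the intent directly and does not need the `> 0` comparison; `torch.amax` keeps the input dtype, so no conversion is needed afterwards.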