Merging two tensors with channel norm


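I want to merge two tensors of shape (B, C, H, W) using their channel norms: for each batch element and each channel, the output should take that channel from whichever tensor has the larger norm.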
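Written out, the selection rule implemented by the loop below is as follows. It assumes the norm is the Frobenius norm over the two spatial dimensions, which is what `torch.norm` computes by default when given `dim=(2, 3)`:

```latex
\|x_{b,c}\|_F = \sqrt{\sum_{h=1}^{H}\sum_{w=1}^{W} x_{b,c,h,w}^{2}}, \qquad
z_{b,c} =
\begin{cases}
x_{b,c} & \text{if } \|x_{b,c}\|_F \ge \|y_{b,c}\|_F, \\
y_{b,c} & \text{otherwise.}
\end{cases}
```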

My approach is below:

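import torch

# Example inputs of shape (B, C, H, W)
x = torch.rand(2, 3, 4, 5)
y = torch.rand(2, 3, 4, 5)

B, C, H, W = x.size()
z = torch.zeros_like(x)

# Frobenius norm of each (H, W) channel -> shape (B, C)
x_norm = torch.norm(x, dim=(2, 3))
y_norm = torch.norm(y, dim=(2, 3))

# Keep the channel with the larger norm
for b in range(B):
    for c in range(C):
        if x_norm[b, c] >= y_norm[b, c]:
            z[b, c] = x[b, c]
        else:
            z[b, c] = y[b, c]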
    

But this method is too slow because it uses two nested Python for loops.

How can I modify the code to run faster?


Solution

  • You can remove the loops by creating a boolean mask from the condition and using it to index the tensors:

    import torch
    
    x = torch.rand(20, 30, 40, 50)
    y = torch.rand(20, 30, 40, 50)
    
    B, C, H, W = x.size()
    z = torch.zeros_like(x)
    
    x_norm = torch.norm(x, dim=(2, 3))
    y_norm = torch.norm(y, dim=(2, 3))
    
    condition = x_norm >= y_norm  # Create a boolean tensor indicating the condition
    
    # Use the condition to assign values to z without loops
    z[condition] = x[condition]
    z[~condition] = y[~condition]
    

    Although the question does not say so explicitly, I am assuming that x and y have the same shape.
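A sketch of an alternative that also avoids the two masked assignments, again assuming x and y share the shape (B, C, H, W): compute the (B, C) condition once, add two trailing singleton dimensions so it broadcasts over H and W, and let `torch.where` pick each channel from x or y. The helper names `merge_by_channel_norm` and `merge_with_loops` are hypothetical, chosen for this example; the second is just the question's loop wrapped in a function so the two results can be compared.

```python
import torch


def merge_by_channel_norm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per (batch, channel), keep the channel of x or y
    with the larger Frobenius norm. Assumes x and y have shape (B, C, H, W)."""
    # Frobenius norm over the spatial dims -> shape (B, C)
    x_norm = torch.norm(x, dim=(2, 3))
    y_norm = torch.norm(y, dim=(2, 3))
    # Reshape the (B, C) mask to (B, C, 1, 1) so it broadcasts over (H, W)
    condition = (x_norm >= y_norm)[:, :, None, None]
    return torch.where(condition, x, y)


def merge_with_loops(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Reference version: the nested-loop approach from the question."""
    B, C, H, W = x.size()
    z = torch.zeros_like(x)
    x_norm = torch.norm(x, dim=(2, 3))
    y_norm = torch.norm(y, dim=(2, 3))
    for b in range(B):
        for c in range(C):
            if x_norm[b, c] >= y_norm[b, c]:
                z[b, c] = x[b, c]
            else:
                z[b, c] = y[b, c]
    return z


torch.manual_seed(0)
x = torch.rand(4, 6, 8, 8)
y = torch.rand(4, 6, 8, 8)

z_fast = merge_by_channel_norm(x, y)
z_loop = merge_with_loops(x, y)
print(torch.equal(z_fast, z_loop))  # prints True
```

`torch.where` builds the output in one pass, so no preallocated `z` is needed. On recent PyTorch versions, `torch.linalg.vector_norm(x, dim=(2, 3))` computes the same per-channel norm and is the documented replacement for the deprecated `torch.norm`.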