Tags: pytorch, linear-algebra, convolution

How do I merge 2D Convolutions in PyTorch?


From linear algebra we know that linear operators compose associatively, and that the composition of two convolutions is itself a single convolution.

In the deep learning world, this concept is used to justify the introduction of non-linearities between NN layers, to prevent a phenomenon colloquially known as linear lasagna (reference).

In signal processing, this also leads to a well-known trick for optimizing memory and/or runtime requirements (reference).

So merging convolutions is a very useful tool from different perspectives. How can it be implemented in PyTorch?
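
Concretely, here is a minimal sketch of the setup I mean (shapes and paddings chosen arbitrarily); I would like to replace the two chained conv2d calls with a single one:

    import torch

    # Two kernels applied back to back, using "full" padding (kernel size - 1)
    k1 = torch.rand(8, 3, 3, 3)    # (out1, in1, 3, 3)
    k2 = torch.rand(4, 8, 5, 5)    # (out2, in2=out1, 5, 5)

    x = torch.rand(1, 3, 32, 32)
    y = torch.conv2d(torch.conv2d(x, k1, padding=2), k2, padding=4)

    # Goal: a single kernel k such that one torch.conv2d(x, k, padding=?) call
    # reproduces y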


Solution

  • If we have y = x * a * b (where * means convolution and a, b are your kernels), we can define c = a * b such that y = x * c = x * a * b as follows:

    import torch
    
    def merge_conv_kernels(k1, k2):
        """
        :param k1: A tensor of shape ``(out1, in1, s1, s1)``
        :param k2: A tensor of shape ``(out2, in2, s2, s2)``, with ``in2 == out1``
        :returns: A tensor of shape ``(out2, in1, s1+s2-1, s1+s2-1)``
          so that convolving with it equals convolving with k1 and
          then with k2.
        """
        padding = k2.shape[-1] - 1
        # conv2d actually computes cross-correlation, so flip k2 to get a true
        # convolution, and permute k1 so that its input channels act as the
        # batch dimension (NCHW layout).
        k3 = torch.conv2d(k1.permute(1, 0, 2, 3), k2.flip(-1, -2),
                          padding=padding).permute(1, 0, 2, 3)
        return k3
    

    To illustrate the equivalence, this example combines two kernels with 900 and 5000 parameters respectively into an equivalent kernel of 98 parameters:

    # Create 2 conv. kernels
    out1, in1, s1 = (100, 1, 3)
    out2, in2, s2 = (2, 100, 5)
    kernel1 = torch.rand(out1, in1, s1, s1, dtype=torch.float64)
    kernel2 = torch.rand(out2, in2, s2, s2, dtype=torch.float64)
    
    # Propagate a random tensor through them. Note that a padding of
    # s - 1 corresponds to the "full" mathematical convolution.
    b, c, h, w = 1, 1, 6, 6
    x = torch.rand(b, c, h, w, dtype=torch.float64) * 10
    c1 = torch.conv2d(x, kernel1, padding=s1 - 1)
    c2 = torch.conv2d(c1, kernel2, padding=s2 - 1)
    
    # Check that the collapsed conv2d gives the same result as c2:
    kernel3 = merge_conv_kernels(kernel1, kernel2)
    c3 = torch.conv2d(x, kernel3, padding=kernel3.shape[-1] - 1)
    print(kernel3.shape)
    print((c2 - c3).abs().sum() < 1e-5)
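
    The same helper can also be used to fold two bias-free nn.Conv2d layers into a single one. The following is only a sketch under the same assumptions (stride 1, "full" padding, no bias), reusing merge_conv_kernels from above; the layer sizes are just for illustration:

    conv1 = torch.nn.Conv2d(1, 100, kernel_size=3, padding=2, bias=False).double()
    conv2 = torch.nn.Conv2d(100, 2, kernel_size=5, padding=4, bias=False).double()
    merged = torch.nn.Conv2d(1, 2, kernel_size=7, padding=6, bias=False).double()
    with torch.no_grad():
        merged.weight.copy_(merge_conv_kernels(conv1.weight, conv2.weight))

    # The folded layer matches the two stacked layers up to float64 round-off
    x = torch.rand(1, 1, 6, 6, dtype=torch.float64)
    print((conv2(conv1(x)) - merged(x)).abs().max())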
    

    Note: the equivalence assumes unlimited numerical precision. I think there was research on stacking many low-precision-float linear operations and showing that the networks benefited from the numerical error, but I am unable to find it. Any reference is appreciated!
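
    To make that caveat concrete, one quick check (again just a sketch, reusing merge_conv_kernels and the shapes from above) is to repeat the comparison in float64 and float32 and look at the residuals; exact numbers vary by hardware, but the float32 error is typically several orders of magnitude larger while still being small:

    for dtype in (torch.float64, torch.float32):
        k1 = torch.rand(100, 1, 3, 3, dtype=dtype)
        k2 = torch.rand(2, 100, 5, 5, dtype=dtype)
        x = torch.rand(1, 1, 6, 6, dtype=dtype) * 10
        direct = torch.conv2d(torch.conv2d(x, k1, padding=2), k2, padding=4)
        merged = torch.conv2d(x, merge_conv_kernels(k1, k2), padding=6)
        # Maximum absolute deviation between the stacked and the merged result
        print(dtype, (direct - merged).abs().max().item())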