
Separable convolutions in PyTorch (i.e. 2 1D-vector-tensor "traditional" convolutions)


I'm trying to implement an image filter in PyTorch that takes in two filters of shapes (1,3) and (3,1) that together build up a (3,3) filter. Example applications of this are the Sobel filter and Gaussian blurring.

I have a NumPy implementation ready, but PyTorch's way of working with convolutions is hard to wrap my head around for more traditional applications such as this. How should I proceed?

import numpy as np

def decomposed_conv2d(arr, x_kernel, y_kernel):
  """
  Apply two 1D kernels as a part of a 2D convolution.
  The kernels must be decomposed from the 2D kernel
  that was originally intended to be convolved with the array.
  Inputs:
  - x_kernel: Column vector kernel, to be applied along the x axis (axis 0)
  - y_kernel: Row vector kernel, to be applied along the y axis (axis 1)
  """
  arr = np.apply_along_axis(lambda x: np.convolve(x, x_kernel, mode='same'), 0, arr)
  arr = np.apply_along_axis(lambda x: np.convolve(x, y_kernel, mode='same'), 1, arr)
  return arr

Gaussian blurring example:

ax = np.array([-1.,0.,1.])
stdev = 0.5
kernel = np.exp(-0.5 * np.square(ax) / np.square(stdev)) / (stdev * np.sqrt(2*np.pi))
decomposed_conv2d(np.arange(9).reshape((3,3)),kernel,kernel)
>>> array([[0.39126886, 1.24684326, 1.83682264],
           [2.86471127, 4.11155453, 4.48257929],
           [4.7279302 , 6.1004473 , 6.17348398]])

(Note: the total "energy" of the array may not be preserved here, especially for small arrays like this one, because the convolution is discrete. That isn't critical to this particular problem.)
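
If preservation mattered, normalizing the kernel so its taps sum to 1 would fix the overall gain (the zero padding of mode='same' still loses a little energy at the borders):

kernel /= kernel.sum()  # unit gain per 1D pass; the implied 2D outer-product kernel sums to 1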

Attempting to do the same in PyTorch following this discussion yields an error:

... # define ax,stdev,kernel, etc.
arr_in = torch.arange(9).reshape(3,3) # for example
arr = arr_in.double().unsqueeze(0) # tried both axes and not unsqueezing as well
x_kernel = torch.from_numpy(kernel)
y_kernel = torch.from_numpy(kernel)

x_kernel = x_kernel.view(1,1,-1)
y_kernel = y_kernel.view(1,1,-1)
arr = F.conv1d(arr,x_kernel,padding=x_kernel.shape[2]//2).squeeze(0)
arr = F.conv1d(arr.transpose(0,1),y_kernel, padding=y_kernel.shape[2] // 2).squeeze(0).transpose(2,1).squeeze(1)

>>> RuntimeError: Given groups=1, weight of size [1, 1, 3], expected input[1, 3, 3] to have 1 channels, but got 3 channels instead

I've juggled squeezes and unsqueezes so that the dimensions match, but I still can't get it to do what I want; I can't even get the first convolution to work this way.
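
(The error arises because conv1d interprets its input as (batch, channels, length): the (1, 3, 3) tensor above is read as one sample with three image rows as channels, which does not match the single-channel weight of shape (1, 1, 3). Treating the rows as the batch dimension instead makes the shapes line up; a minimal shape demo, with an arbitrary stand-in kernel:)

import torch
import torch.nn.functional as F

arr = torch.arange(9, dtype=torch.float64).reshape(3, 3)
w = torch.ones(1, 1, 3, dtype=torch.float64)    # stand-in single-channel 1D kernel
out = F.conv1d(arr.unsqueeze(1), w, padding=1)  # (3, 1, 3): rows as batch, 1 channel
print(out.shape)                                # torch.Size([3, 1, 3])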


Solution

    Solution with conv2d

    You can make your life a lot easier by using conv2d rather than conv1d.

    Although we use conv2d below, this is still effectively a 1-d convolution (or rather, two 1-d convolutions), since we apply an n×1 and then a 1×n kernel. Thus, we keep all the benefits of a separable convolution (in particular, 2·n rather than n² multiplications per pixel for a kernel of length n; for n = 3 that is 6 versus 9 multiplications, and the gap widens quickly for larger kernels).

    import numpy as np
    import torch
    from torch.nn.functional import conv2d
    np.set_printoptions(precision=3)  # For better legibility: show fewer float digits
    
    def decomposed_conv2d_np(arr, x_kernel, y_kernel):  # From the question
        arr = np.apply_along_axis(lambda x: np.convolve(x, x_kernel, mode='same'), 0, arr)
        arr = np.apply_along_axis(lambda x: np.convolve(x, y_kernel, mode='same'), 1, arr)
        return arr
    
    def decomposed_conv2d_torch(arr, x_kernel, y_kernel):  # Proposed
        arr = arr.unsqueeze(0).unsqueeze_(0)  # New view with batch & channel dims: 4D for ``conv2d()``
        arr = conv2d(arr, weight=x_kernel.view(1, 1, -1, 1), padding='same')
        arr = conv2d(arr, weight=y_kernel.view(1, 1, 1, -1), padding='same')
        return arr.squeeze_(0).squeeze_(0)  # Make 2D again
    
    ax = np.array([-1.,0.,1.])
    stdev = 0.5
    kernel = np.exp(-0.5 * np.square(ax) / np.square(stdev)) / (stdev * np.sqrt(2 * np.pi))
    array = np.arange(9).reshape((3,3))
    
    print(result_np := decomposed_conv2d_np(array, kernel, kernel))
    # [[0.391 1.247 1.837]
    #  [2.865 4.112 4.483]
    #  [4.728 6.1   6.173]]
    
    array, kernel = torch.from_numpy(array).to(torch.float64), torch.from_numpy(kernel)
    print(result_torch := decomposed_conv2d_torch(array, kernel, kernel).numpy())
    # [[0.391 1.247 1.837]
    #  [2.865 4.112 4.483]
    #  [4.728 6.1   6.173]]
    
    assert np.allclose(result_np, result_torch)
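
    As a quick sanity check, we can also verify that the two 1-D passes match a single full 2-D convolution with the outer-product kernel:

    kernel_2d = torch.outer(kernel, kernel)  # full (3, 3) kernel from its two 1D factors
    full = conv2d(array.unsqueeze(0).unsqueeze(0), weight=kernel_2d.view(1, 1, 3, 3),
                  padding='same')
    assert torch.allclose(full.squeeze(), torch.from_numpy(result_torch))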
    

    This solution is based on my answer to a related, earlier question that asked for an implementation of a Gaussian kernel in PyTorch.

    Solution with conv1d

    Here is the corresponding solution using conv1d instead:

    from torch.nn.functional import conv1d
    ...
    def decomposed_conv2d_with_conv1d(a, x_kernel, y_kernel):
        a = a.unsqueeze(1)  # Unsqueeze channels dimension for ``conv1d()``
        a = conv1d(a, weight=y_kernel.view(1, 1, -1), padding='same')  # Use y kernel
        a = a.transpose(0, -1)  # Swap image dims for using x kernel along last dim
        a = conv1d(a, weight=x_kernel.view(1, 1, -1), padding='same')  # Use x kernel
        return a.squeeze_(1).T  # Make 2D again, reestablish original order of dims
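
    A quick check that this variant reproduces the conv2d result from above:

    result_conv1d = decomposed_conv2d_with_conv1d(array, kernel, kernel).numpy()
    assert np.allclose(result_np, result_conv1d)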
    

    The key ideas here are:

    • We always need to convolve along the last dimension, so before convolving with the appropriate kernel, we need to move the corresponding image dimension there.
    • For the remaining image dimension, we can "misuse" what conv1d assumes to be the batch dimension (dimension 0) to hold its values. What does not work here is the channels dimension (dimension 1), since we would then need to repeat the kernel to match the number of channels. We simply keep the channels dimension at 1 (meaning one image channel), but we could use it for the actual image channels if we had a multichannel image (say, RGB); see the sketch after this list.
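
    To sketch that multichannel case (an illustrative helper of my own, shown with conv2d for brevity rather than conv1d): repeat each 1-D kernel once per channel and pass groups=C, so that every channel is filtered independently, i.e. a depthwise separable convolution:

    def separable_conv_multichannel(img, x_kernel, y_kernel):
        """img: (C, H, W) tensor; x_kernel, y_kernel: 1D tensors of matching dtype."""
        c = img.shape[0]
        img = img.unsqueeze(0)  # (1, C, H, W) for ``conv2d()``
        img = conv2d(img, weight=x_kernel.view(1, 1, -1, 1).repeat(c, 1, 1, 1),
                     padding='same', groups=c)  # one kernel copy per channel
        img = conv2d(img, weight=y_kernel.view(1, 1, 1, -1).repeat(c, 1, 1, 1),
                     padding='same', groups=c)
        return img.squeeze(0)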

    To me, it appears less straightforward than the conv2d solution, since it also involves reordering the image dimensions. As for performance, I don't know which version is faster; I did not time them, but I assume the differences are negligible.
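
    If you want to check, a minimal timing sketch (my addition; it uses the standard torch.utils.benchmark utility, and the image size is arbitrary) could look like this:

    import torch.utils.benchmark as benchmark

    img = torch.rand(1024, 1024, dtype=torch.float64)  # arbitrary test image
    for fn in (decomposed_conv2d_torch, decomposed_conv2d_with_conv1d):
        timer = benchmark.Timer(stmt='fn(img, kernel, kernel)',
                                globals={'fn': fn, 'img': img, 'kernel': kernel})
        print(fn.__name__, timer.timeit(100))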