
Differentiable convolution between two 1d signals in pytorch


I need to implement a convolution between a signal and a window in PyTorch, and I want it to be differentiable. Since I couldn't find an existing function for plain tensors (I could only find the ones with learnable parameters), I wrote one myself, but I'm unable to make it work without breaking the computation graph. How could I do it? The function I made is:

import torch

def Convolve(a, b):
  # clone() so we can write into the buffer without modifying a leaf in place
  conv = torch.zeros(a.shape[0], a.shape[1], requires_grad=True).clone()
  l = b.shape[0]
  r = (l - 1) // 2
  l2 = a.shape[1]
  for x in range(a.shape[0]):      # for every signal
    for x2 in range(a.shape[1]):   # for every time instant
      for x4 in range(l):          # compute the convolution (for every window value)
        if x2 - r + x4 < 0 or x2 - r + x4 >= l2:
          continue                 # out-of-bounds indices contribute 0 (avoids explicit zero padding)
        conv[x][x2] += a[x][x2 - r + x4] * b[x4]  # otherwise it is window * signal
  return conv

Where 'a' is a two-dimensional tensor (signal index, time) and 'b' is a Hann window. The length of the window is odd.
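As an aside, the looped pattern above, cloning a fresh buffer and then accumulating into it with indexed in-place adds, is itself differentiable in PyTorch. A minimal sketch of just that pattern (the shapes and the factor 3.0 are illustrative, not from the question):

```python
import torch

# Indexed in-place writes into a cloned buffer stay on the computation
# graph, so gradients flow back to 'a'.
a = torch.randn(2, 4, requires_grad=True)
out = torch.zeros(2, 4).clone()   # writable buffer, as in the question's code
for i in range(2):
    for j in range(4):
        out[i][j] += a[i][j] * 3.0
out.sum().backward()
print(a.grad)  # every entry is 3.0
```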


Solution

  • It is (fortunately!) possible to achieve this with PyTorch primitives. You are probably looking for the functional conv1d. Below is how it works. I was not sure whether you wanted the derivative with respect to the input or the weights, so you get both; just keep the requires_grad that fits your needs:

    import torch
    import torch.nn.functional as F

    # batch of 3 signals, each of length 11 with 1 channel
    signal = torch.randn(3, 1, 11, requires_grad=True)
    # convolution kernel of size 3, expecting 1 input channel and 1 output channel
    kernel = torch.randn(1, 1, 3, requires_grad=True)
    # convolving signal with kernel, padding by 1 to keep the output length at 11
    output = F.conv1d(signal, kernel, stride=1, padding=1, bias=None)
    # output.shape is (3, 1, 11)
    # backpropagating some kind of loss through the computational graph
    output.sum().backward()
    print(kernel.grad)
    >>> tensor([[[...]]])
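For the original setting, a 2-D (signal index, time) tensor convolved with an odd-length Hann window, the same call works after reshaping to conv1d's (batch, channels, length) layout. A sketch with made-up sizes; note that conv1d technically computes a cross-correlation, which coincides with true convolution here because the Hann window is symmetric:

```python
import torch
import torch.nn.functional as F

# Setup mirroring the question: 3 signals of length 11 (signal index, time),
# convolved with a symmetric Hann window of odd length 5.
signals = torch.randn(3, 11, requires_grad=True)
window = torch.hann_window(5, periodic=False)

# conv1d expects (batch, channels, length) for the input and
# (out_channels, in_channels, kernel_size) for the weight.
out = F.conv1d(signals.unsqueeze(1),                # -> (3, 1, 11)
               window.view(1, 1, -1),               # -> (1, 1, 5)
               padding=(window.shape[0] - 1) // 2)  # keeps output length at 11
out = out.squeeze(1)                                # back to (3, 11)

out.sum().backward()
print(signals.grad.shape)  # torch.Size([3, 11])
```

Gradients flow back to `signals` because the whole pipeline is built from differentiable primitives; no manual loop is needed.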