lua · neural-network · convolution · torch

Custom Spatial Convolution In Torch


I need to perform a custom spatial convolution in Torch. Rather than simply multiplying each input pixel by a weight for that pixel and adding the results together with the filter's bias to form each output pixel, I need to apply a more complex mathematical function to the input pixels before adding them together.

I know how to do this, but I do not know a GOOD way to do it. The best way I've come up with is to take the full input tensor, create a bunch of secondary tensors that are "views" of the original without allocating additional memory, put those into a Replicate layer (with the output filter count as the replication count), and feed that into a ParallelTable layer containing a bunch of regular layers whose parameters are shared between filters.

The trouble is, even though this is fine memory-wise with a very manageable overhead, we're talking inputwidth × inputheight × inputdepth × outputdepth mini-networks here. Maybe there's some way to create massive "long and tall" networks that work on the entire replicated input set at once, but how do I create layers that are partially connected (like convolutions) instead of fully connected?

I would have liked to just use inheritance to create a special copy of the regular SpatialConvolution "class" and modify it, but I can't even try because it's implemented in an external C library. I can't just use regular layers before a regular SpatialConvolution layer because I need to do my math with different weights and biases for each filter (shared between applications of the same filter to different input coordinates).


Solution

  • Good question. You made me give it some serious thought. Your approach has a flaw: it does not allow you to take advantage of vectorized computations, since each mini-network works independently.

    My idea is as follows:

    Suppose the network's input and output are 2D tensors. We can produce (efficiently, without memory copying) an auxiliary 4D tensor rf_input (kernel_size x kernel_size x output_h x output_w) such that rf_input[:, :, k, l] is a 2D tensor of size kernel_size x kernel_size containing the receptive field from which output[k, l] will be computed. We then iterate over positions (i, j) inside the kernel: rf_input[i, j, :, :] holds the pixels at position (i, j) of all receptive fields, so their contribution to every output[k, l] can be computed at once using vectorization.
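
    For instance, sub-tensor indexing in Torch returns views that share storage with the original tensor, which is why rf_input costs no extra memory. A minimal sketch (the sizes here are arbitrary, chosen just for illustration):

    require 'torch'
    local K, H, W = 3, 5, 5
    local input = torch.rand(H, W)
    -- all pixels at kernel position (i, j), one per receptive field:
    -- a view of size (H - K + 1) x (W - K + 1), i.e. output_h x output_w
    local i, j = 2, 2
    local rf_ij = input[{{i, H - K + i}, {j, W - K + j}}]
    print(rf_ij:size())
    rf_ij:fill(0)
    print(input[{i, j}])    -- prints 0: the slice shares memory with `input`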

    Example:

    Let our "convolving" function be, for example, a product of tangents of sums:

    $$\mathrm{output}[k,\,l] \;=\; \prod_{i=1}^{K} \prod_{j=1}^{K} \tan\bigl(\mathrm{input}[k+i-1,\; l+j-1] + \mathrm{weight}[i,\,j]\bigr), \qquad K = \mathrm{kernel\_size}$$

    Then its partial derivative w.r.t. the input pixel x at position (s, t) of its receptive field is

    $$\frac{\partial\,\mathrm{output}[k,\,l]}{\partial x} \;=\; \frac{\mathrm{output}[k,\,l]}{\tan\bigl(x + \mathrm{weight}[s,\,t]\bigr)\,\cos^{2}\bigl(x + \mathrm{weight}[s,\,t]\bigr)}$$

    The derivative w.r.t. weight[s, t] is the same, since tan is applied to the sum x + weight[s, t].

    At the end, of course, we must sum up the gradients coming from the different output[k, l] points. Each input[m, n] contributes to at most kernel_size^2 outputs (as part of their receptive fields), and each weight[i, j] contributes to all output_h x output_w outputs.
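
    In symbols, with L the loss and RF(k, l) the receptive field of output[k, l]:

    $$\frac{\partial L}{\partial\,\mathrm{input}[m,\,n]} \;=\; \sum_{(k,\,l)\,:\,(m,\,n)\,\in\,\mathrm{RF}(k,\,l)} \frac{\partial L}{\partial\,\mathrm{output}[k,\,l]} \cdot \frac{\partial\,\mathrm{output}[k,\,l]}{\partial\,\mathrm{input}[m,\,n]}, \qquad \frac{\partial L}{\partial\,\mathrm{weight}[i,\,j]} \;=\; \sum_{k,\,l} \frac{\partial L}{\partial\,\mathrm{output}[k,\,l]} \cdot \frac{\partial\,\mathrm{output}[k,\,l]}{\partial\,\mathrm{weight}[i,\,j]}$$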

    A simple implementation may look like this:

    require 'nn'
    local CustomConv, parent = torch.class('nn.CustomConv', 'nn.Module')
    
    -- This module takes and produces a 2D map. 
    -- To work with multiple input/output feature maps and batches, 
    -- you have to iterate over them or further vectorize computations inside the loops.
    
    function CustomConv:__init(ker_size)
        parent.__init(self)
    
        self.ker_size = ker_size
        self.weight = torch.rand(self.ker_size, self.ker_size):add(-0.5)   -- uniform in [-0.5, 0.5)
        self.gradWeight = torch.Tensor(self.weight:size()):zero()
    end
    
    function CustomConv:_get_recfield_input(input)
        -- rf_input[i][j] is a view (no copy) of size output_h x output_w
        -- holding the pixel at kernel position (i, j) of every receptive field;
        -- a negative range end counts from the end of the dimension
        local rf_input = {}
        for i = 1, self.ker_size do
            rf_input[i] = {}
            for j = 1, self.ker_size do
                rf_input[i][j] = input[{{i, i - self.ker_size - 1}, {j, j - self.ker_size - 1}}]
            end
        end
        return rf_input
    end
    
    function CustomConv:updateOutput(input)
        local output = torch.Tensor(self.rf_input[1][1]:size())
        --  Kernel-specific: our kernel is multiplicative, so we start with ones
        output:fill(1)                                              
        --
        for i = 1, self.ker_size do
            for j = 1, self.ker_size do
                local ker_pt = self.rf_input[i][j]:clone()
                local w = self.weight[i][j]
                -- Kernel-specific
                output:cmul(ker_pt:add(w):tan())
                --
            end
        end
        return output
    end
    
    function CustomConv:updateGradInput_and_accGradParameters(input, gradOutput)
        local gradInput = torch.Tensor(self.input:size()):zero()
        for i = 1, self.ker_size do
            for j = 1, self.ker_size do
                local ker_pt = self.rf_input[i][j]:clone()
                local w = self.weight[i][j]
                -- Kernel-specific: d output / d input = output / (tan(x + w) * cos^2(x + w)).
                -- tan and cos must be taken of fresh copies of (x + w): chaining the
                -- in-place variants on a single tensor, as in
                -- ker_pt:add(w):tan():cmul(ker_pt:add(w):cos():pow(2)),
                -- would reuse the already-mutated values and give a wrong gradient.
                local xw = ker_pt + w
                local subGradInput = torch.cmul(gradOutput, torch.cdiv(self.output, torch.cmul(torch.tan(xw), torch.cos(xw):pow(2))))
                local subGradWeight = subGradInput   -- derivative w.r.t. weight is the same
                --
                gradInput[{{i, i - self.ker_size - 1}, {j, j - self.ker_size - 1}}]:add(subGradInput)
                self.gradWeight[{i, j}] = self.gradWeight[{i, j}] + torch.sum(subGradWeight)
            end
        end
        return gradInput
    end
    
    function CustomConv:forward(input)
        self.input = input
        self.rf_input = self:_get_recfield_input(input)
        self.output = self:updateOutput(input)
        return self.output
    end
    
    function CustomConv:backward(input, gradOutput)
        -- store gradInput on self (as nn.Module does) instead of leaking a global
        self.gradInput = self:updateGradInput_and_accGradParameters(input, gradOutput)
        return self.gradInput
    end
    
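    A quick finite-difference sanity check of the analytic gradient may look like this (a hypothetical test, not part of the implementation above; the sizes and the probed position are arbitrary, and the input is scaled so that input + weight stays away from the poles of tan):

    local K = 3
    local m = nn.CustomConv(K)
    local x = torch.rand(5, 5):mul(0.5)   -- keeps input + weight well below pi/2
    
    local y = m:forward(x)
    local gradIn = m:backward(x, torch.ones(y:size()))
    
    -- numerically estimate d(sum of outputs) / d input[s, t]
    local eps, s, t = 1e-5, 2, 2
    local xp = x:clone(); xp[{s, t}] = xp[{s, t}] + eps
    local xm = x:clone(); xm[{s, t}] = xm[{s, t}] - eps
    local numGrad = (m:forward(xp):sum() - m:forward(xm):sum()) / (2 * eps)
    print(gradIn[{s, t}], numGrad)        -- the two numbers should agree to ~1e-5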

    If you change this code a bit:

    updateOutput:                                             
        output:fill(0)
        [...]
        output:add(ker_pt:mul(w))
    
    updateGradInput_and_accGradParameters:
        local subGradInput = torch.mul(gradOutput, w)
        local subGradWeight = torch.cmul(gradOutput, ker_pt)
    

    then it will work exactly like nn.SpatialConvolutionMM with zero bias (I've tested it).
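
    Such a comparison may look like this (a sketch that assumes the linear modifications above have been applied; the sizes are arbitrary):

    require 'nn'
    local K, H, W = 3, 6, 6
    local custom = nn.CustomConv(K)
    -- reference: 1 input plane, 1 output plane, bias zeroed out
    local ref = nn.SpatialConvolutionMM(1, 1, K, K)
    ref.bias:zero()
    ref.weight:copy(custom.weight:view(1, K * K))
    
    local x = torch.rand(H, W)
    local y1 = custom:forward(x)
    local y2 = ref:forward(x:view(1, H, W)):squeeze()
    print((y1 - y2):abs():max())   -- should be ~0 (up to floating-point error)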