python, deep-learning, pytorch, mxnet

How do I use a function like torch.nn.functional.conv2d() in MXNet?


I want to convolve some input data with a kernel.
In PyTorch, I can write a function like this:

import torch
def torch_conv_func(x, num_groups):
    batch_size, num_channels, height, width = x.size()
    conv_kernel = torch.ones(num_channels, num_channels, 1, 1)
    
    return torch.nn.functional.conv2d(x, conv_kernel)

This works well. Now I need to reimplement it in MXNet, so I wrote this:


from mxnet import nd
from mxnet.gluon import nn

def mxnet_conv_func(x, num_groups):
    batch_size, num_channels, height, width = x.shape
    conv_kernel = nd.ones((num_channels, num_channels, 1, 1))

    return nd.Convolution(x, conv_kernel)

But I get this error:

mxnet.base.MXNetError: Required parameter kernel of Shape(tuple) is not presented, in operator Convolution(name="")

How can I fix it?


Solution

  • mxnet.nd.Convolution requires a few arguments that torch.nn.functional.conv2d infers from the weight tensor. You can do it like this:

    from mxnet import nd
    
    def mxnet_convolve(x):
        B, C, H, W = x.shape
        weight = nd.ones((C, C, 1, 1))
        return nd.Convolution(x, weight, no_bias=True, kernel=(1,1), num_filter=C)
    
    x = nd.ones((16, 3, 32, 32))
    mxnet_convolve(x)
    

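    Because no bias array is passed, you need to set no_bias=True (it defaults to False, in which case the operator expects a bias input). MXNet also requires you to state the kernel's shape explicitly: kernel is the spatial size, here (1, 1), and num_filter is the number of output channels, here C. Unlike PyTorch, it does not infer these from the shape of the weight.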
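
    To check the result independently of either framework, here is a minimal NumPy sketch of what a 1x1 convolution with an all-ones (C, C, 1, 1) weight and no bias computes. It assumes NumPy is installed; the function name ones_1x1_conv_reference is made up for illustration and is not part of MXNet or PyTorch.

```python
import numpy as np

def ones_1x1_conv_reference(x):
    """Reference for a 1x1 convolution with an all-ones (C, C, 1, 1) weight, no bias."""
    batch, channels, height, width = x.shape
    weight = np.ones((channels, channels, 1, 1), dtype=x.dtype)
    # A 1x1 convolution is a matrix multiply over the channel axis at each pixel:
    # out[b, o, h, w] = sum_i weight[o, i, 0, 0] * x[b, i, h, w]
    return np.einsum("oi,bihw->bohw", weight[:, :, 0, 0], x)

x = np.ones((16, 3, 32, 32), dtype=np.float32)
out = ones_1x1_conv_reference(x)
print(out.shape)        # (16, 3, 32, 32)
print(out[0, 0, 0, 0])  # 3.0, the sum over the 3 input channels
```

    With an all-ones weight, every output channel is simply the sum of the input channels, so the values returned by the PyTorch function and the MXNet function should both match this array (after converting with .numpy() or .asnumpy() respectively).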