pytorch, coreml, coremltools

How to implement a replication_pad2d layer in the coremltools converters as a torch op


I've been trying to convert my PyTorch model to Core ML format, but one of its layers, replication_pad2d, is not currently supported. So I tried to reimplement the layer for coremltools.converters using the @register_torch_op decorator. However, I'm struggling to understand the input types well enough to implement the function correctly. This is what I have so far, roughly translated from the PyTorch implementation, but it doesn't work:

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil.frontend.torch.ops import _get_inputs

@register_torch_op
def replication_pad2d(context, node):
  inputs = _get_inputs(context, node)
  x = inputs[0]
  a = len(x)
  L_list, R_list = [], []
  U_list, D_list = [], []
  for i in range(a):#i:0, 1
    l = x[:, :, :, (a-i):(a-i+1)]
    L_list.append(l)
    r = x[:, :, :, (i-a-1):(i-a)]
    R_list.append(r)
  L_list.append(x)
  x = mb.concat(L_list+R_list[::-1], axis=3, name=node.name)
  for i in range(a):
    u = x[:, :, (a-i):(a-i+1), :]
    U_list.append(u)
    d = x[:, :, (i-a-1):(i-a), :]
    D_list.append(d)
  U_list.append(x)
  x = mb.concat(U_list+D_list[::-1], axis=3, name=node.name)
  context.add(x)

but I'm getting the following error:

<ipython-input-12-cf14ed84cb93> in replication_pad2d(context, node)
     59   inputs = _get_inputs(context, node)
     60   x = inputs[0]
---> 61   a = len(x)
     62   L_list, R_list = [], []
     63   U_list, D_list = [], []

TypeError: object of type 'Var' has no len()

It would be great if someone could help me understand this better, especially the types of the `node` and `context` inputs.
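For reference, replication padding repeats the outermost column (or row) of the input; it does not mirror interior values the way reflection padding does. The minimal NumPy sketch below assumes an NCHW layout and plain arrays rather than MIL Vars (the function name `replication_pad2d_concat` is hypothetical). It builds the padding from edge slices with concatenation, padding the width on axis 3 and then the height on axis 2, and checks the result against `np.pad` with `mode="edge"`:

```python
import numpy as np


def replication_pad2d_concat(x, left, right, top, bottom):
    """Hypothetical helper: replicate the edges of an NCHW array by concatenation."""
    # Width (axis 3): repeat the first column `left` times and the last column `right` times.
    first_col = x[:, :, :, :1]
    last_col = x[:, :, :, -1:]
    x = np.concatenate([first_col] * left + [x] + [last_col] * right, axis=3)
    # Height (axis 2): same idea on the width-padded array, so the corners are filled too.
    first_row = x[:, :, :1, :]
    last_row = x[:, :, -1:, :]
    x = np.concatenate([first_row] * top + [x] + [last_row] * bottom, axis=2)
    return x


x = np.arange(6, dtype=np.float32).reshape(1, 1, 2, 3)
y = replication_pad2d_concat(x, left=1, right=2, top=1, bottom=0)
# NumPy's "edge" mode is the same operation, with pad widths given per axis.
assert np.array_equal(y, np.pad(x, ((0, 0), (0, 0), (1, 0), (1, 2)), mode="edge"))
```

Inside a converter this Python logic cannot run as-is, because `x` there is a symbolic MIL Var rather than a concrete array.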


Solution

  • I think you can use the existing MIL padding op instead:

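    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil import register_torch_op
    from coremltools.converters.mil.frontend.torch.ops import _get_inputs
    
    # Register this function as the converter for torch's replication_pad2d op.
    @register_torch_op(torch_alias=["replication_pad2d"])
    def HackedReplication_pad2d(context, node):
        # _get_inputs returns the node's inputs as MIL Vars (symbolic tensors),
        # which is why len() fails on them in the question's code.
        inputs = _get_inputs(context, node)
        x = inputs[0]
        # The padding amounts are a constant; .val gives the concrete value,
        # in PyTorch order: (left, right, top, bottom).
        pad = inputs[1].val
        # Reorder to (top, bottom, left, right) and use the built-in replicate mode.
        x_pad = mb.pad(x=x, pad=[pad[2], pad[3], pad[0], pad[1]], mode='replicate')
        # Register the output under the torch node's name so later ops can find it.
        context.add(x_pad, node.name)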
    

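    The documentation of the padding op is sparse, so the ordering of the padding parameters takes some guesswork. PyTorch's replication_pad2d lists the last dimension first, (left, right, top, bottom), while the code above passes (top, bottom, left, right) to mb.pad.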
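The reordering in the answer amounts to reversing the (before, after) pairs of the PyTorch pad list. The sketch below assumes, as the answer's `[pad[2], pad[3], pad[0], pad[1]]` implies, that `mb.pad` expects pairs ordered from the first padded axis to the last. The helper `torch_pad_to_axis_order` is hypothetical and not part of coremltools. It generalizes the swap to any number of padded dimensions and cross-checks the resulting shape with NumPy's edge padding:

```python
import numpy as np


def torch_pad_to_axis_order(pad):
    """Hypothetical helper: reverse the (before, after) pairs of a PyTorch pad list.

    PyTorch lists the LAST dimension first, e.g. (left, right, top, bottom) for 2-D;
    the result lists the first padded axis first, e.g. (top, bottom, left, right).
    """
    if len(pad) % 2 != 0:
        raise ValueError("pad must contain (before, after) pairs")
    pairs = [(pad[i], pad[i + 1]) for i in range(0, len(pad), 2)]
    pairs.reverse()
    return [int(p) for pair in pairs for p in pair]


torch_pad = [1, 2, 3, 4]  # (left, right, top, bottom)
mil_pad = torch_pad_to_axis_order(torch_pad)  # same as [pad[2], pad[3], pad[0], pad[1]]

# Cross-check the semantics with NumPy's edge padding on an NCHW array.
x = np.arange(6).reshape(1, 1, 2, 3)
top, bottom, left, right = mil_pad
y = np.pad(x, ((0, 0), (0, 0), (top, bottom), (left, right)), mode="edge")
```

Reversing pairs rather than hard-coding indices keeps the same helper usable for 1-D or 3-D padding, where the PyTorch pad list has two or six entries.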