python, tensorflow, pytorch

How can I convert this TensorFlow code to PyTorch?


# TensorFlow (Keras)
Conv2D(
    self.filter_1, (1, 64), 
    activation='elu', 
    padding="same",
    kernel_constraint=max_norm(2., axis=(0, 1, 2))
)

# PyTorch (my attempt)
nn.Sequential(
    nn.Conv2D(16, (1, 64),
        padding="same",
        kernel_constraint=max_norm(2., axis=(0, 1, 2))),
    nn.ELU()
)

Solution

  • You need two things:

    1. You need to know the number of input channels. Your example only gives the number of output channels, 16. Keras infers the input channels from the input shape when the layer is built, but PyTorch's nn.Conv2d requires in_channels when the layer is constructed.
    2. You need to implement the max_norm constraint on the conv kernel yourself, because nn.Conv2d has no kernel_constraint argument.
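
To illustrate the first point, here is a minimal sketch. The input shape (4, 8, 1, 128) is an assumption made for illustration and is not given in the question. nn.LazyConv2d is shown as one possible way to get Keras-like behaviour, since it infers in_channels on the first forward pass.

```python
import torch
from torch import nn

# Hypothetical input: batch of 4, 8 channels, height 1, width 128 (assumed shape).
x = torch.randn(4, 8, 1, 128)

# Explicit layer: in_channels must match x.shape[1].
explicit = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(1, 64), padding="same")

# Lazy layer: in_channels is inferred from the first input it sees.
lazy = nn.LazyConv2d(out_channels=16, kernel_size=(1, 64), padding="same")

y_explicit = explicit(x)
y_lazy = lazy(x)  # the first call creates the weight tensor

print(tuple(y_explicit.shape))   # (4, 16, 1, 128)
print(tuple(lazy.weight.shape))  # (16, 8, 1, 64), i.e. (out, in, kH, kW)
```

The printed weight shape also shows PyTorch's kernel layout, which matters when choosing the dimensions for the norm constraint.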

    With this in mind, let's write a small subclass of nn.Conv2d that rescales the weights to satisfy the constraint each time forward is called:

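    import torch
    from torch import nn
    import torch.nn.functional as F

    class Conv2D_Norm_Constrained(nn.Conv2d):
        """nn.Conv2d whose kernel is rescaled to a maximum L2 norm on every forward pass."""

        def __init__(self, max_norm_val, norm_dim, **kwargs):
            # Remaining keyword arguments (in_channels, out_channels, ...) go to nn.Conv2d.
            super().__init__(**kwargs)
            self.max_norm_val = max_norm_val
            self.norm_dim = norm_dim

        def get_constrained_weights(self, epsilon=1e-8):
            # L2 norm over norm_dim; keepdim=True so it broadcasts against the weight.
            norm = self.weight.norm(2, dim=self.norm_dim, keepdim=True)
            # Scale factor is ~1 where norm <= max_norm_val, and max_norm_val / norm otherwise.
            return self.weight * (torch.clamp(norm, 0, self.max_norm_val) / (norm + epsilon))

        def forward(self, input):
            # Convolve with the constrained weights (assumes the default padding_mode='zeros').
            return F.conv2d(input, self.get_constrained_weights(), self.bias, self.stride, self.padding, self.dilation, self.groups)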
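
When choosing the dimensions for the norm, keep in mind that Keras and PyTorch store conv kernels in different layouts. The sketch below uses a random, hypothetical kernel to check that Keras-style axis=(0, 1, 2) on a (kH, kW, in, out) kernel gives the same per-filter norms as dims (1, 2, 3) on PyTorch's (out, in, kH, kW) layout. The helper name rescale is made up for this example and applies the same rescaling formula as above.

```python
import torch

torch.manual_seed(0)

# Hypothetical kernel in PyTorch layout: (out_channels, in_channels, kH, kW).
w_torch = torch.randn(16, 8, 1, 64)
# The same values rearranged into Keras layout: (kH, kW, in_channels, out_channels).
w_keras = w_torch.permute(2, 3, 1, 0)

# Keras axis=(0, 1, 2) reduces over everything except the output-filter axis.
norms_keras = w_keras.norm(2, dim=(0, 1, 2))
# The matching reduction in PyTorch layout is over dims (1, 2, 3).
norms_torch = w_torch.norm(2, dim=(1, 2, 3))
same_norms = torch.allclose(norms_keras, norms_torch)

def rescale(weight, max_norm_val, dims, epsilon=1e-8):
    # Shrink any slice whose L2 norm exceeds max_norm_val; leave the rest (almost) unchanged.
    norm = weight.norm(2, dim=dims, keepdim=True)
    return weight * (torch.clamp(norm, 0, max_norm_val) / (norm + epsilon))

constrained = rescale(w_torch, 2.0, dims=(1, 2, 3))
constrained_norms = constrained.norm(2, dim=(1, 2, 3))

print(same_norms)                     # whether both layouts give the same per-filter norms
print(constrained_norms.max().item()) # largest per-filter norm after rescaling
```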
    

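    Assuming your input has 8 channels, we can write the following. Note that Keras stores the kernel as (height, width, in_channels, out_channels), while PyTorch stores it as (out_channels, in_channels, height, width), so the Keras axis=(0, 1, 2) corresponds to norm_dim=(1, 2, 3) here:

    nn.Sequential(
        Conv2D_Norm_Constrained(in_channels=8, out_channels=16, kernel_size=(1, 64), padding="same", max_norm_val=2.0, norm_dim=(1, 2, 3)),
        nn.ELU()
    )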
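
One difference worth knowing: Keras applies a kernel constraint to the stored weights after each gradient update, whereas the wrapper above rescales the weights on the fly in forward and leaves the stored parameter untouched. If you prefer the Keras-style behaviour, a possible alternative is to project the weights in place after each optimizer step. The following is only a sketch under assumptions: apply_max_norm_ is a hypothetical helper, the input shape and loss are placeholders, and the limit of 0.5 is used only so that clipping is visible with default initialization (the question uses 2.0).

```python
import torch
from torch import nn

def apply_max_norm_(conv, max_norm_val=2.0, dims=(1, 2, 3), epsilon=1e-8):
    """Rescale conv.weight in place so each output filter has L2 norm <= max_norm_val."""
    with torch.no_grad():
        norm = conv.weight.norm(2, dim=dims, keepdim=True)
        conv.weight.mul_(torch.clamp(norm, 0, max_norm_val) / (norm + epsilon))

torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(1, 64), padding="same"),
    nn.ELU(),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(4, 8, 1, 128)   # hypothetical input batch
loss = model(x).pow(2).mean()   # placeholder loss, for illustration only

optimizer.zero_grad()
loss.backward()
optimizer.step()
# Project the kernel back onto the constraint set after the update.
apply_max_norm_(model[0], max_norm_val=0.5)

filter_norms = model[0].weight.norm(2, dim=(1, 2, 3))
print(filter_norms.max().item())  # largest per-filter norm after projection
```

With this variant a plain nn.Conv2d is enough, at the cost of having to remember the extra call in the training loop.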