
How to translate TF Dense layer to PyTorch?


I am wondering if someone can help me understand how to translate a short TF model into Torch.

Consider this TF setup:

inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp)  # [None, 386, 1024, 2]
start, end = tf.split(x, 2, axis=-1)
start = tf.squeeze(start, axis = -1)
end = tf.squeeze(end, axis = -1)
model = Model(inputs = inp, outputs = [start, end])

Specifically, I am not sure which Torch operation would turn my data from shape (386, 1024, 1) into (386, 1024, 2), nor do I understand what Model(inputs = inp, outputs = [start, end]) is doing.

Is:

inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp)  # [None, 386, 1024, 2]

Equivalent to:

X = torch.randn(386, 1024, 1)
X = X.expand(386, 1024, 2)
X.shape  # [386, 1024, 2]

Solution

  • Translating a TF model to Torch is usually straightforward: most TF functions have a direct Torch equivalent that you can find in the PyTorch documentation. Here is the original TF code as a runnable example:

    import tensorflow as tf
    from tensorflow.keras import layers, models
    import numpy as np
    
    inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
    x = layers.Dense(2)(inp)  # [None, 386, 1024, 2]
    start, end = tf.split(x, 2, axis=-1)
    start = tf.squeeze(start, axis=-1)
    end = tf.squeeze(end, axis=-1)
    model = models.Model(inputs = inp, outputs = [start, end])
    
    X = np.random.randn(3, 386, 1024, 1)
    output = model(X)
    print(output[0].shape, output[1].shape)
    
    # Outputs: (3, 386, 1024) (3, 386, 1024)
    

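    The equivalent Torch code is below. The forward method takes over the role of Model(inputs = inp, outputs = [start, end]): it defines how the input is mapped to the two returned outputs.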

    import torch
    from torch import nn
    
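    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # Dense(2) on a last dimension of size 1 corresponds to nn.Linear(1, 2)
            self.fc = nn.Linear(1, 2)
    
        def forward(self, x):
            x = self.fc(x)  # [N, 386, 1024, 2]
            # torch.split takes the chunk size (1); tf.split takes the number of splits (2)
            start, end = torch.split(x, 1, dim=-1)
            start = torch.squeeze(start, dim=-1)
            end = torch.squeeze(end, dim=-1)
            return [start, end]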
    
    net = Net()
    
    X = torch.randn(3, 386, 1024, 1)
    output = net(X)
    print(output[0].size(), output[1].size())
    
    # Outputs: torch.Size([3, 386, 1024]) torch.Size([3, 386, 1024])
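
If the two models should also produce the same numbers, not just the same shapes, the Dense weights have to be copied into the Linear layer. The sketch below assumes the usual layouts: a Keras Dense kernel has shape (in_features, units), while nn.Linear stores its weight as (out_features, in_features), so the kernel is transposed on the way in. The arrays keras_kernel and keras_bias are hypothetical stand-ins for what the Dense layer's get_weights() would return; they are hard-coded here so the example runs without TensorFlow.

```python
import numpy as np
import torch
from torch import nn

# Hypothetical stand-ins for [kernel, bias] from a Keras Dense(2) layer on a 1-feature input.
# Keras layout: kernel is (in_features, units), bias is (units,).
keras_kernel = np.array([[0.5, -1.5]], dtype=np.float32)
keras_bias = np.array([0.1, 0.2], dtype=np.float32)

fc = nn.Linear(1, 2)
with torch.no_grad():
    # nn.Linear stores weight as (out_features, in_features), so transpose the kernel
    fc.weight.copy_(torch.from_numpy(keras_kernel.T))
    fc.bias.copy_(torch.from_numpy(keras_bias))

x = torch.tensor([[2.0]])
y = fc(x)  # 2.0 * 0.5 + 0.1 and 2.0 * -1.5 + 0.2
print(y)
```

With a real model, the two arrays would come from the trained Dense layer instead of being written out by hand.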
    

    And the following TF code:

    inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
    x = layers.Dense(2)(inp)  # [None, 386, 1024, 2]
    

    is not equivalent to the following Torch code:

    X = torch.randn(386, 1024, 1)
    X = X.expand(386, 1024, 2)
    X.shape  # [386, 1024, 2]
    

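    This is because layers.Dense in TF corresponds to nn.Linear in Torch: it applies a learned weight and bias to the last dimension, whereas expand only repeats the existing values along that dimension and has no learnable parameters.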
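
The difference between expand and nn.Linear can be checked directly. The following is a minimal sketch, using the unbatched shape from the question; the variable names are illustrative only. It shows that expand yields two identical channels, while nn.Linear(1, 2) computes x * w + b with a separate weight and bias per output channel.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(386, 1024, 1)

# expand broadcasts the size-1 last dimension: both channels are copies of the input
expanded = x.expand(386, 1024, 2)
same_channels = torch.equal(expanded[..., 0], expanded[..., 1])

# nn.Linear(1, 2) applies a learned affine map to the last dimension
fc = nn.Linear(1, 2)
projected = fc(x)

# The same result written out by hand: one weight and one bias per output channel
manual = x * fc.weight.squeeze(-1) + fc.bias

print(expanded.shape, projected.shape)
print(same_channels, torch.allclose(projected, manual, atol=1e-5))
```

Both tensors have shape [386, 1024, 2], which is why the two snippets look interchangeable, but only the Linear output depends on trainable parameters.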