I am wondering if someone can help me understand how to translate a short TF model into Torch.
Consider this TF setup:
import tensorflow as tf
from tensorflow.keras import layers, Model
inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp) # [None, 386, 1024, 2]
start, end = tf.split(x, 2, axis=-1)
start = tf.squeeze(start, axis = -1)
end = tf.squeeze(end, axis = -1)
model = Model(inputs = inp, outputs = [start, end])
Specifically, I am not sure which Torch command would turn my data from shape (386, 1024, 1) into (386, 1024, 2), nor do I understand what this line is doing: Model(inputs = inp, outputs = [start, end])
Is:
inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp) # [None, 386, 1024, 2]
equivalent to:
X = torch.randn(386, 1024, 1)
X = X.expand(386, 1024, 2)
X.shape  # torch.Size([386, 1024, 2])
Converting a TF model to Torch is usually straightforward: for most TF functions you can find an equivalent Torch function in the PyTorch documentation. Here is an example of converting your TF code:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp) # [None, 386, 1024, 2]
start, end = tf.split(x, 2, axis=-1)
start = tf.squeeze(start, axis=-1)
end = tf.squeeze(end, axis=-1)
model = models.Model(inputs = inp, outputs = [start, end])
X = np.random.randn(3, 386, 1024, 1)
output = model(X)
print(output[0].shape, output[1].shape)
# Outputs: (3, 386, 1024) (3, 386, 1024)
And here is the equivalent Torch code:
import torch
from torch import nn
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # nn.Linear(1, 2) is the Torch counterpart of layers.Dense(2):
        # it maps the last dimension from 1 feature to 2 features
        self.fc = nn.Linear(1, 2)
    def forward(self, x):
        x = self.fc(x)  # [batch, 386, 1024, 2]
        # split the last dimension into pieces of size 1
        start, end = torch.split(x, 1, dim=-1)
        start = torch.squeeze(start, dim=-1)
        end = torch.squeeze(end, dim=-1)
        return [start, end]
net = Net()
X = torch.randn(3, 386, 1024, 1)
output = net(X)
print(output[0].size(), output[1].size())
# Outputs: torch.Size([3, 386, 1024]) torch.Size([3, 386, 1024])
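As for Model(inputs = inp, outputs = [start, end]): in Keras that call only declares which tensors are the model's inputs and which are its outputs. In Torch the same role is played by the forward method, whose argument is the input and whose return value is the output. One detail worth noting is that tf.split(x, 2, axis=-1) takes the number of pieces, while torch.split(x, 1, dim=-1) takes the size of each piece; torch.chunk(x, 2, dim=-1) takes the number of pieces, so it mirrors the TF call more literally. A minimal sketch of the same model written that way (the class name NetChunk is hypothetical, chosen just for this example):

```python
import torch
from torch import nn


class NetChunk(nn.Module):
    """Same model as Net above, but using torch.chunk instead of torch.split."""

    def __init__(self):
        super().__init__()
        # Counterpart of layers.Dense(2): maps the last dimension from 1 to 2 features.
        self.fc = nn.Linear(1, 2)

    def forward(self, x):
        x = self.fc(x)  # [batch, 386, 1024, 2]
        # Number of pieces (2), like tf.split(x, 2, axis=-1).
        start, end = torch.chunk(x, 2, dim=-1)
        # Returning both tensors plays the role of outputs=[start, end] in Keras.
        return start.squeeze(-1), end.squeeze(-1)


net = NetChunk()
X = torch.randn(3, 386, 1024, 1)
start, end = net(X)
print(start.shape, end.shape)
```

Because forward returns a tuple, the two outputs can be unpacked directly, much like calling a Keras model with two outputs.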
And the following TF code:
inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp) # [None, 386, 1024, 2]
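is not equivalent to the following Torch code:
X = torch.randn(386, 1024, 1)
X = X.expand(386, 1024, 2)
X.shape  # torch.Size([386, 1024, 2])
since layers.Dense in TF is equivalent to nn.Linear in Torch: both apply a learned linear transformation (a weight matrix plus a bias) to the last dimension of the input. expand, on the other hand, only repeats the existing value along the size-1 dimension; it has no learnable parameters, so both output channels are just copies of the input.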
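To make the difference concrete, here is a small sketch that applies both to a tensor with the question's shape and recomputes the nn.Linear output by hand (the variable names are only for illustration):

```python
import torch
from torch import nn

torch.manual_seed(0)
X = torch.randn(386, 1024, 1)

# expand only repeats the single existing value along the last dimension;
# it returns a view and has no learnable parameters.
expanded = X.expand(386, 1024, 2)

# nn.Linear(1, 2), like layers.Dense(2), computes x @ W.T + b on the last dimension.
fc = nn.Linear(1, 2)
projected = fc(X)

# Both results have shape [386, 1024, 2] ...
print(expanded.shape, projected.shape)
# ... but only the expanded tensor is a plain copy of the input in both channels.
print(torch.equal(expanded[..., 0], expanded[..., 1]))  # True

# Recompute the Linear output by hand: out[..., k] = x * W[k, 0] + b[k]
manual = X * fc.weight[:, 0] + fc.bias
print(torch.allclose(projected, manual, atol=1e-6))
```

If you ever copy trained weights from a Keras Dense layer into nn.Linear, keep in mind that the Keras kernel has shape (input_dim, units) while nn.Linear stores its weight as (out_features, in_features), so the kernel needs to be transposed.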