I've been having trouble getting my data into the shape required by PyTorch's GRU.
What this network is supposed to do is take a 256-dimensional encoded vector representation of a molecule and learn to generate the corresponding SELFIES string (a text-based molecule representation), padded to a length of 128, with tokens drawn from an alphabet of 42 'letters'.
Now, I have no idea how to reshape the input tensor so that the GRU accepts it, as shown in the drawing I attached.
Thanks in advance for your help.
I tried calling unsqueeze(1) on the input tensor. This resulted in an output of shape [64, 1, 256], which in my model would be a batch of 64 one-token outputs (see the reproduction sketch after the class definition below).
import torch
import torch.nn as nn

class DecoderNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, output_len):
        super(DecoderNet, self).__init__()
        # GRU parameters
        self.input_size = input_size    # = 256
        self.hidden_size = hidden_size  # = 256
        self.num_layers = num_layers    # = 1
        # output token count (alphabet size)
        self.output_size = output_size  # = 42
        # output length, i.e. GRU time step count
        self.output_len = output_len    # = 128
        # pytorch.nn layers
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.fc = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=2)
        self.relu = nn.ReLU()

    def forward(self, x, h):
        out, h = self.gru(x, h)
        return out, h

    def init_hidden(self, batch_size):
        # hidden state shape: (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size)
        return h0
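For reference, a minimal sketch of what I tried; the hidden state shape here is my assumption, since it is the only one the GRU accepts for this input:

import torch

model = DecoderNet(input_size=256, hidden_size=256, num_layers=1,
                   output_size=42, output_len=128)
x = torch.randn(64, 256)          # batch of 64 encoded molecule vectors
h = torch.zeros(1, 1, 256)        # (num_layers, batch_size=1, hidden_size)
out, h = model(x.unsqueeze(1), h)
print(out.shape)                  # torch.Size([64, 1, 256])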
By default, nn.GRU expects input of shape (seq_len, batch_size, input_size). You need to create the layer with batch_first=True for it to accept (batch_size, seq_len, input_size) instead.
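A minimal sketch of the batch_first variant, using the sizes from your model; note that the hidden state stays (num_layers, batch_size, hidden_size) either way:

import torch
import torch.nn as nn

gru = nn.GRU(input_size=256, hidden_size=256, num_layers=1, batch_first=True)
x = torch.randn(64, 128, 256)     # (batch_size, seq_len, input_size)
h0 = torch.zeros(1, 64, 256)      # hidden state layout is unaffected by batch_first
out, h = gru(x, h0)
print(out.shape)                  # torch.Size([64, 128, 256])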
If your x has a shape of (batch_size, seq_len), then you first need to add the input size dimension with x = x.unsqueeze(2) to get a shape of (batch_size, seq_len, input_size=1).
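For example (shapes are illustrative):

import torch

x = torch.randn(64, 128)          # (batch_size, seq_len)
x = x.unsqueeze(2)                # add a trailing feature dimension
print(x.shape)                    # torch.Size([64, 128, 1])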
Alternatively, you can keep batch_first=False (the default) and swap the batch size and sequence length dimensions, before or after the unsqueeze(), like this:

x = x.transpose(1, 0)
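For the same illustrative shapes, this moves the sequence dimension to the front:

import torch

x = torch.randn(64, 128, 1)       # (batch_size, seq_len, input_size)
x = x.transpose(1, 0)             # swap dims 0 and 1
print(x.shape)                    # torch.Size([128, 64, 1])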
Important: Do not use reshape() or view() to "fix" the shape of x (as suggested by the title of your post), as this will mess up your tensor! Unlike transpose(), these functions do not move data between axes; they only re-slice the tensor's flat element order, so values end up attached to the wrong sequence and batch positions.