I am trying to implement the attention described in Luong et al. 2015 in PyTorch myself, but I couldn't get it to work. Below is my code; I am only interested in the "general" attention case for now. I wonder if I am missing any obvious error. It runs, but doesn't seem to learn. (A small standalone sketch of the equations I am targeting follows the code.)
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(
            num_embeddings=self.output_size,
            embedding_dim=self.hidden_size
        )
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size, self.hidden_size)
        # hc: [hidden, context]
        self.Whc = nn.Linear(self.hidden_size * 2, self.hidden_size)
        # s: softmax
        self.Ws = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        gru_out, hidden = self.gru(embedded, hidden)

        # [0] removes the (num_layers * num_directions) dimension for now
        attn_prod = torch.mm(self.attn(hidden)[0], encoder_outputs.t())
        attn_weights = F.softmax(attn_prod, dim=1)  # eq. 7/8
        context = torch.mm(attn_weights, encoder_outputs)

        # hc: [hidden, context]
        out_hc = F.tanh(self.Whc(torch.cat([hidden[0], context], dim=1)))  # eq. 5
        output = F.log_softmax(self.Ws(out_hc), dim=1)  # eq. 6

        return output, hidden, attn_weights
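For reference, what I think I should be computing is: score(h_t, h_s) = h_t^T W_a h_s for the "general" case (eq. 8), a softmax over the scores for the alignment weights (eq. 7), the context c_t as the weighted sum of the encoder states, and h_tilde = tanh(W_c [c_t; h_t]) (eq. 5). Here is a minimal standalone sketch of a single decoder step with batch size 1; the sizes and names are made up for illustration, not taken from my code above:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, seq_len = 4, 6                              # made-up sizes
W_a = nn.Linear(hidden_size, hidden_size, bias=False)    # "general" score matrix
W_c = nn.Linear(hidden_size * 2, hidden_size)            # for eq. 5

h_t = torch.randn(1, hidden_size)            # current decoder hidden state
H_s = torch.randn(seq_len, hidden_size)      # encoder states, one row per source position

scores = W_a(h_t) @ H_s.t()                               # (1, seq_len): h_t^T W_a h_s
a_t = F.softmax(scores, dim=1)                            # alignment weights, eq. 7
c_t = a_t @ H_s                                           # context vector, (1, hidden_size)
h_tilde = torch.tanh(W_c(torch.cat([h_t, c_t], dim=1)))   # eq. 5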
I have studied the attention implemented in
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
and
https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb
The former is not quite the mechanism described in the paper: its attention layer has a fixed output size tied to the maximum sentence length (self.attn = nn.Linear(self.hidden_size * 2, self.max_length)), which could be expensive for long sequences.
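To make concrete why that bothers me, here is roughly what that layer implies (a paraphrase from memory, not the tutorial's exact code): the scores are produced from two concatenated hidden_size-sized vectors, one per slot up to max_length, rather than one per actual encoder state:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, max_length = 4, 10
attn = nn.Linear(hidden_size * 2, max_length)

embedded = torch.randn(1, hidden_size)   # current target-side embedding
hidden = torch.randn(1, hidden_size)     # current decoder hidden state

# one weight per slot 0..max_length-1, regardless of the real source length
attn_weights = F.softmax(attn(torch.cat([embedded, hidden], dim=1)), dim=1)  # (1, max_length)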
The latter is really slow after updating it to the latest version of PyTorch (ref). Also, I don't know why it takes the last context (ref).

This version works, and it follows the definition of Luong attention (general) closely. The main difference from the code in the question is the separation of embedding_size and hidden_size, which turned out to be important for training after some experimentation. Previously, I made both of them the same size (256), which created trouble for learning; it seemed the network could only learn half of each sequence.
import torch
import torch.nn as nn
import torch.nn.functional as F

# assumed to be defined elsewhere in the original script; shown here so the snippet runs
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EncoderRNN(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size,
                 num_layers=1, bidirectional=False, batch_size=1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.batch_size = batch_size

        self.embedding = nn.Embedding(input_size, embedding_size)
        self.gru = nn.GRU(embedding_size, hidden_size, num_layers,
                          bidirectional=bidirectional)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def initHidden(self):
        directions = 2 if self.bidirectional else 1
        return torch.zeros(
            self.num_layers * directions,
            self.batch_size,
            self.hidden_size,
            device=DEVICE
        )
class AttnDecoderRNN(nn.Module):
    def __init__(self, embedding_size, hidden_size, output_size, dropout_p=0):
        super(AttnDecoderRNN, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(
            num_embeddings=output_size,
            embedding_dim=embedding_size
        )
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(embedding_size, hidden_size)
        self.attn = nn.Linear(hidden_size, hidden_size)
        # hc: [hidden, context]
        self.Whc = nn.Linear(hidden_size * 2, hidden_size)
        # s: softmax
        self.Ws = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        gru_out, hidden = self.gru(embedded, hidden)

        # general score over all encoder states (eq. 8), then alignment weights (eq. 7)
        attn_prod = torch.mm(self.attn(hidden)[0], encoder_outputs.t())
        attn_weights = F.softmax(attn_prod, dim=1)
        context = torch.mm(attn_weights, encoder_outputs)

        # hc: [hidden, context]
        hc = torch.cat([hidden[0], context], dim=1)
        out_hc = F.tanh(self.Whc(hc))                      # eq. 5
        output = F.log_softmax(self.Ws(out_hc), dim=1)     # eq. 6

        return output, hidden, attn_weights
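For completeness, here is roughly how the two fit together for a single example. The vocabulary sizes, embedding_size=256 / hidden_size=512, SOS_token, and the fixed number of decoding steps are all made up for illustration; this is not my actual training script, and it assumes the classes and DEVICE defined above:

import torch

# illustrative sizes only; the point is embedding_size != hidden_size
input_vocab, output_vocab = 1000, 1200
embedding_size, hidden_size = 256, 512
SOS_token = 0   # hypothetical start-of-sequence index

encoder = EncoderRNN(input_vocab, embedding_size, hidden_size).to(DEVICE)
decoder = AttnDecoderRNN(embedding_size, hidden_size, output_vocab).to(DEVICE)

input_seq = torch.randint(0, input_vocab, (7,), device=DEVICE)   # a fake 7-token source sentence

# run the encoder token by token, collecting its outputs for the attention
encoder_hidden = encoder.initHidden()
encoder_outputs = torch.zeros(len(input_seq), hidden_size, device=DEVICE)
for i in range(len(input_seq)):
    encoder_output, encoder_hidden = encoder(input_seq[i], encoder_hidden)
    encoder_outputs[i] = encoder_output[0, 0]

# decode a few steps greedily; attn_weights has one column per source token
decoder_input = torch.tensor([SOS_token], device=DEVICE)
decoder_hidden = encoder_hidden
for _ in range(5):
    decoder_output, decoder_hidden, attn_weights = decoder(
        decoder_input, decoder_hidden, encoder_outputs)
    decoder_input = decoder_output.argmax(dim=1)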