I'm studying NLP with a simple toy project (just generating text) in PyTorch. While referencing some example code online, I ran into a problem I can't understand.
Here is the code (some parts have been omitted, and it is not finished yet):
def __init__(self, vocab_size, seq_size, embedding_size, hidden_size, num_layers=2):
    """Build an embedding -> LSTM -> linear next-token model.

    Args:
        vocab_size: number of distinct token ids; sizes both the embedding
            table and the output layer.
        seq_size: sequence length; stored on the module, not used here.
        embedding_size: dimensionality of each token embedding.
        hidden_size: number of features in the LSTM hidden state.
        num_layers: number of stacked LSTM layers. Defaults to 2, the value
            that used to be hard-coded.
    """
    super(RNNModule, self).__init__()
    self.seq_size = seq_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.embedding = nn.Embedding(vocab_size, embedding_size)
    # batch_first=True makes the LSTM's input/output tensors batch-major.
    # Per the nn.LSTM docs, the (h, c) states are still laid out as
    # (num_layers, batch, hidden_size).
    self.lstm = nn.LSTM(
        input_size=embedding_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        batch_first=True,
    )
    self.dense = nn.Linear(hidden_size, vocab_size)
def forward(self, x, prev_state):
    """Run the model over a batch of token ids.

    Args:
        x: LongTensor of token ids. Presumably (batch, seq_len), given the
            LSTM is built with batch_first=True — TODO confirm at call sites.
        prev_state: (h, c) tuple passed to the LSTM as its initial state.

    Returns:
        (logits, state): one vocab-sized score vector per input position,
        and the LSTM's updated (h, c) state.
    """
    embed = self.embedding(x)
    output, state = self.lstm(embed, prev_state)
    # nn.Linear acts on the last dimension only, so logits keep the leading
    # dimensions of `output`. The class scores therefore sit on the LAST
    # axis, which is why the caller transposes before nn.CrossEntropyLoss.
    logits = self.dense(output)
    return logits, state
# Produce an all-zero initial state for the first input.
def zero_state(self, batch_size, device=None):
    """Return all-zero (h0, c0) LSTM states for starting a new sequence.

    Args:
        batch_size: number of sequences in the batch the state will be used
            with; must equal the batch dimension of the next input.
        device: optional device to allocate the tensors on. None means
            torch's default device, which matches the old behavior.

    Returns:
        A pair (h0, c0) of float tensors, each shaped
        (num_layers, batch_size, hidden_size).
    """
    # Take the layer count from the LSTM itself instead of hard-coding 2,
    # so the state always matches how self.lstm was built.
    shape = (self.lstm.num_layers, batch_size, self.hidden_size)
    return (torch.zeros(shape, device=device),
            torch.zeros(shape, device=device))
def make_data_label(corpus):
    """Split token sequences into (input, target) pairs for next-token prediction.

    Row ``c`` of ``corpus`` becomes input ``c[:-1]`` and target ``c[1:]``.
    Position ``t`` of the target is therefore the token that follows
    position ``t`` of the input.

    Args:
        corpus: iterable of equal-length token-id sequences, e.g. a 2-D
            numpy array or a list of lists.

    Returns:
        (data, label): two int64 tensors of identical shape.
    """
    import numpy as np

    data = [c[:-1] for c in corpus]
    label = [c[1:] for c in corpus]
    # Stack into a single ndarray before converting. Building a tensor from a
    # Python list of ndarrays is very slow, and torch warns about it.
    data = torch.as_tensor(np.array(data), dtype=torch.long)
    label = torch.as_tensor(np.array(label), dtype=torch.long)
    return data, label
if __name__ == "__main__":
    # --- Load data ---
    corpus, word2id, id2word, weight = load_data()
    corpus = torch.LongTensor(corpus)

    # --- Hyperparameters ---
    # Training
    epochs = 10
    learning_rate = 0.003
    batch_size = 16
    hidden_size = 32  # dimensionality of the LSTM hidden state
    gradients_norm = 5  # max gradient norm used for clipping
    # Sentences
    seq_size = len(corpus[0])  # length of one sentence
    embedding_size = len(weight[0])  # embedding vector size
    vocab_size = len(word2id)
    # Testing
    # initial_words=['I', 'am']
    predict_top_k = 5
    checkpoint_path = './checkpoint'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('# corpus size : {} / vocab_size : {}'.format(len(corpus), vocab_size))
    print('# batch size : {} / num of cell : {}'.format(batch_size, hidden_size))
    print("# 디바이스 : ", device)
    print('-'*30+"데이터 불러오기 및 하이퍼 파라미터 설정 분할 완료.")

    # --- Split into data / label ---
    # corpus is a Tensor; convert to numpy so row slicing behaves as intended.
    c = corpus.numpy()
    data, label = make_data_label(c)
    dataset = CommentDataset(data, label)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print('-'*30 + "Data Loader 준비 완료.")

    # --- Define and build the model ---
    net = RNNModule(vocab_size, seq_size,
                    embedding_size, hidden_size)
    net = net.to(device)
    loss_f = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        # Size the initial state from the actual batch. The last batch from
        # the DataLoader can be smaller than batch_size (drop_last=False).
        state_h, state_c = net.zero_state(inputs.size(0))  # initial h, c
        state_h = state_h.to(device)
        state_c = state_c.to(device)

        # Clear gradients left over from any previous backward() call.
        optimizer.zero_grad()
        # Call the module rather than .forward() so registered hooks run.
        logits, (state_h, state_c) = net(inputs, (state_h, state_c))

        # logits are (batch, seq_len, vocab); targets are (batch, seq_len)
        # class indices. nn.CrossEntropyLoss expects input (N, C, d1) with
        # class scores on dim 1 and target (N, d1), so move vocab to dim 1.
        scores = logits.transpose(1, 2)
        print(scores.size())
        print(targets.size())
        loss = loss_f(scores, targets)
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), gradients_norm)
        optimizer.step()
        # NOTE(review): deliberate debug stop after one batch. Remove this
        # (and loop over range(epochs)) to actually train.
        break
What I can't understand is why the tensor `logits`
(near the end of main.py) must be transposed.
The shapes of logits and label are:
logits : torch.Size([16, 19, 10002]) # [batch_size, sentence_length, vocab_size]
label : torch.Size([16, 19]) # [batch_size, sentence_length]
I thought that, to compute the loss with CrossEntropyLoss, the label and the logits had to have the same number of dimensions. That would mean the label's shape going from [batch_size, sentence_length] to [batch_size, sentence_length, vocab_size]. Here, though, they don't match.
How should I understand this, and why does it work?
P.S. I referenced the website below:
: https://machinetalk.org/2019/02/08/text-generation-with-pytorch/
`nn.CrossEntropyLoss()` does not take one-hot vectors as targets; it takes class indices. Therefore your logits and targets will not have the same dimensions. Logits must have shape `(num_examples, vocab_size)`. Your label only has to contain the index of the true class, so it has shape `(num_examples)`, not `(num_examples, vocab_size)`. That second shape would be needed only if you were feeding in one-hot encoded (or, since PyTorch 1.10, class-probability) targets.
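As for why you need to transpose your logits: for inputs with extra dimensions, `nn.CrossEntropyLoss()` expects the logits to have shape `(batch_size, num_classes, d1, ...)` and the target to have shape `(batch_size, d1, ...)`.
Here `d1` is the number of tokens in each sequence (sentence_length), not in each batch. So logits of shape `[16, 19, 10002]` must be transposed to `[16, 10002, 19]` to line up with the label's `[16, 19]`.
Equivalently, you could flatten both: `loss_f(logits.reshape(-1, vocab_size), label.reshape(-1))`.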