python, deep-learning, pytorch

PyTorch Embedding: Expected Long/Int indices, got FloatTensor


Even though the input to the embedding layer is of type torch.int64, I am still getting the error: expected argument of scalar type Long or Int, but got torch.FloatTensor.

for epoch in range(inital_epoch, config['number_epochs']):
    model.train()
    batch_iterator = tqdm(train_dataloader, desc = f'Processing epoch {epoch:02d}')
    for batch in batch_iterator:
        encoder_input = batch['encoder_input'].to(device) # (batch_size, context_len)
        decoder_input = batch['decoder_input'].to(device) # (batch_size, context_len)
        encoder_mask = batch['encoder_mask'].to(device) # (batch_size, 1, context_len)
        decoder_mask = batch['decoder_mask'].to(device) # (batch_size, context_len, context_len)

        print(f"encoder_input dtype: {encoder_input.dtype}")
        print(f"encoder_input shape: {encoder_input.shape}")

        encoder_output = model.encoder(encoder_input, encoder_mask)

In the above code, I have added print statements to inspect the dtype and shape of encoder_input.

When I execute the file, I get the following error:

Using device cpu
Processing epoch 00:   0%|                                                                                                       | 0/12771 [00:00<?, ?it/s]
encoder_input dtype: torch.int64
encoder_input shape: torch.Size([8, 440])
Processing epoch 00:   0%|                                                                                                       | 0/12771 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "E:\Fahad\College\BE\Project\MT using Transformers\train.py", line 137, in <module>
    train_model(config)
  File "E:\Fahad\College\BE\Project\MT using Transformers\train.py", line 107, in train_model
    encoder_output = model.encoder(encoder_input, encoder_mask) # (batch_size, context_len, embedding_size)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\Fahad\College\BE\Project\MT using Transformers\model.py", line 246, in encoder
    return self.encoder(src, src_mask)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\Fahad\College\BE\Project\MT using Transformers\model.py", line 240, in encoder
    src = self.input_embedding(src)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\fahad\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\fahad\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\Fahad\College\BE\Project\MT using Transformers\model.py", line 18, in forward
    return self.embedding(x) * math.sqrt(self.embedding_dim)
           ^^^^^^^^^^^^^^^^^
  File "C:\Users\fahad\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\fahad\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\fahad\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\sparse.py", line 163, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "C:\Users\fahad\AppData\Roaming\Python\Python311\site-packages\torch\nn\functional.py", line 2237, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

Solution

  • There is no way to know where the issue lies unless you share your model.py file, but from my experience, you could try tracking every tensor operation inside the encoder's forward function of your model. Tensor dtypes are converted automatically in most tensor operations (type promotion), so I suspect that src was unexpectedly combined with a float tensor somewhere before it reached the embedding layer.
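
    Below is a minimal sketch of how this can happen. The shapes mirror the printout above; the offending float operation and the variable names are my own assumptions, since model.py isn't shown:

    import torch
    import torch.nn as nn

    emb = nn.Embedding(num_embeddings=100, embedding_dim=16)

    src = torch.randint(0, 100, (8, 440))   # int64 indices, as in the question's printout
    print(src.dtype)                         # torch.int64

    # An innocent-looking operation with a float tensor (e.g. an additive mask or
    # positional term) silently promotes the indices to float32:
    src = src + torch.zeros(8, 440)          # hypothetical float op; result is float32
    print(src.dtype)                         # torch.float32

    # emb(src)                               # would raise the same RuntimeError as above
    out = emb(src.long())                    # casting the indices back to int64 works
    print(out.shape)                         # torch.Size([8, 440, 16])

    Since your print shows encoder_input is still torch.int64 when it enters model.encoder, any such conversion has to happen inside model.py before self.input_embedding(src) is called, so printing src.dtype right before that line should pinpoint the culprit.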