Tags: python, neural-network, pytorch, typing

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'mat2'. How do I fix it?



import torch.nn as nn 
import torch 
import torch.optim as optim
import itertools

class net1(nn.Module):
    def __init__(self):
        super(net1,self).__init__()

        self.pipe = nn.Sequential(
            nn.Linear(10,10),
            nn.ReLU()
        )

    def forward(self,x):
        return self.pipe(x.long())

class net2(nn.Module):
    def __init__(self):
        super(net2,self).__init__()

        self.pipe = nn.Sequential(
            nn.Linear(10,20),
            nn.ReLU(),
            nn.Linear(20,10)
        )

    def forward(self,x):
        return self.pipe(x.long())



netFIRST = net1()
netSECOND = net2()

learning_rate = 0.001

opt = optim.Adam(itertools.chain(netFIRST.parameters(),netSECOND.parameters()), lr=learning_rate)

epochs = 1000

x = torch.tensor([1,2,3,4,5,6,7,8,9,10],dtype=torch.long)
y = torch.tensor([10,9,8,7,6,5,4,3,2,1],dtype=torch.long)


for epoch in range(epochs):
    opt.zero_grad()

    prediction = netSECOND(netFIRST(x))
    loss = (y.long() - prediction)**2
    loss.backward()

    print(loss)
    print(prediction)
    opt.step()

Error:

line 49: prediction = netSECOND(netFIRST(x))

line 1371, in linear: output = input.matmul(weight.t())

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'mat2'

I don't really see what I'm doing wrong. I have tried to turn everything into a Long in every way I could think of, and I don't understand how typing works in PyTorch. The last time I tried something with just one layer, it forced me to use an integer type. Could someone explain how tensor types are determined in PyTorch, and how to prevent and fix errors like this? Many thanks in advance; this problem really bothers me and I can't fix it no matter what I try.
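On the question of how the types are decided, here is a small sketch, assuming a reasonably recent PyTorch release. Error messages differ between versions, so the code only catches RuntimeError instead of matching the text. It shows dtype inference in torch.tensor, the default parameter dtype of nn.Linear, the same Long/Float mismatch as in the question, and why trainable parameters cannot be integers. The variable names (ints, floats, layer, layer64) are only for illustration.

```python
import torch
import torch.nn as nn

# torch.tensor infers the dtype from the Python values it receives.
ints = torch.tensor([1, 2, 3])          # whole numbers -> torch.int64 ("Long")
floats = torch.tensor([1.0, 2.0, 3.0])  # decimals -> the default float dtype
print(ints.dtype, floats.dtype)         # torch.int64 torch.float32

# Layers create their parameters in the default floating-point dtype...
layer = nn.Linear(3, 1)
print(layer.weight.dtype)               # torch.float32
# ...unless the whole module is converted, e.g. to float64.
layer64 = nn.Linear(3, 1).double()
print(layer64.weight.dtype)             # torch.float64

# A Long input into a float32 layer reproduces the dtype mismatch
# (the exact message depends on the PyTorch version).
mismatch_error = False
try:
    layer(ints)
except RuntimeError as err:
    mismatch_error = True
    print("RuntimeError:", err)

# Casting the input to the parameters' dtype removes the mismatch.
out = layer(ints.to(layer.weight.dtype))
print(out.dtype)                        # torch.float32

# Autograd only tracks floating-point (and complex) tensors,
# which is why trainable parameters cannot be integers.
int_grad_error = False
try:
    torch.zeros(3, dtype=torch.long, requires_grad=True)
except RuntimeError as err:
    int_grad_error = True
    print("RuntimeError:", err)
```

The practical rule is to convert the data to the model's dtype (or the model to the data's floating-point dtype) once, rather than casting inside forward().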


Solution

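  • The weights are Floats (float32 by default), while the inputs are Longs (int64). A matrix multiplication needs both operands to have the same dtype, so this is not allowed. Layer parameters also have to be floating point, because autograd only computes gradients for floating-point (and complex) tensors. float32 is the default, although you can convert a model to, e.g., float64 with .double().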

    If you remove all the .long() calls and create x and y as floats (dtype=torch.float32), it works (I tried it).

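    (You will then hit another, unrelated error, "grad can be implicitly created only for scalar outputs": backward() without arguments only works on a scalar, so you need to reduce your loss, e.g. with loss.sum() or loss.mean().)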
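For reference, here is a sketch of the question's script with both fixes applied: no .long() calls, float32 x and y, and the squared error summed to a scalar before backward(). The torch.manual_seed(0) call and the losses list are additions only to make runs repeatable and to show that the loss goes down; the rest mirrors the original code and is not tuned in any way.

```python
import itertools

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)  # only to make runs repeatable


class net1(nn.Module):
    def __init__(self):
        super().__init__()
        self.pipe = nn.Sequential(nn.Linear(10, 10), nn.ReLU())

    def forward(self, x):
        return self.pipe(x)  # no .long(): x must match the float32 weights


class net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.pipe = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))

    def forward(self, x):
        return self.pipe(x)


netFIRST = net1()
netSECOND = net2()
opt = optim.Adam(
    itertools.chain(netFIRST.parameters(), netSECOND.parameters()), lr=0.001
)

# Inputs and targets as float32, the same dtype as the parameters.
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=torch.float32)
y = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.float32)

epochs = 1000
losses = []
for epoch in range(epochs):
    opt.zero_grad()
    prediction = netSECOND(netFIRST(x))
    # Sum the element-wise squared errors so backward() receives a scalar.
    loss = ((y - prediction) ** 2).sum()
    loss.backward()
    opt.step()
    losses.append(loss.item())
    if epoch % 200 == 0:
        print(epoch, loss.item())
```

nn.MSELoss(reduction="sum") computes the same quantity; its default reduction="mean" divides by the number of elements instead, which only rescales the gradients.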