Tags: python, neural-network, pytorch, mnist

pytorch MNIST neural network produces several non-zero outputs


I tried to build a neural network that operates on the MNIST data set. I was mostly following the pytorch.nn tutorial. As a result, I got a model that learns, but there's something wrong with the process or with the model itself. Instead of one active neuron at the output, I receive multiple ones.

Here's the model itself:

from torch import nn

model = nn.Sequential(
    nn.Linear(784, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
    nn.ReLU(),
)

And here's the training process:

import numpy as np
import torch
from torch import optim

loss_func = nn.CrossEntropyLoss()
opt = optim.SGD(model.parameters(), lr=lr)

for epoch in range(epochs):
    model.train()
    for xbt, ybt in train_dl:
        pred = model(xbt)
        loss = loss_func(pred, ybt)
        opt.zero_grad()
        loss.backward()
        opt.step()

    model.eval()
    # Validation
    if epoch % 10 == 0:
        with torch.no_grad():
            losses, nums = zip(
                *[(loss_func(model(xbv), ybv), len(xbv)) for xbv, ybv in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)

        print(epoch, val_loss)

Here's the average loss at every 10th epoch:

0 0.13384412774592638
10 0.0900113809091039
20 0.09795805384699234
30 0.10341344920364791
40 0.10804545368137551

And this is what the result of applying the model to the validation set looks like:

[[ 0.         0.         0.        ... 28.436266   0.         5.001435 ]
 [ 7.3331523 12.666427  31.898096  ...  0.         0.         0.       ]
 [ 0.        18.116354   8.049953  ...  4.330721   0.         0.       ]
 ...
 [ 8.504517   0.         6.302228  ...  0.         0.         0.       ]
 [ 1.7339934  0.         0.        ...  0.         2.1565871  0.       ]
 [45.750134   0.         6.2685804 ...  2.247082   0.         0.       ]]
 Shape: (9984, 10)

I tried changing the learning rate, the model layers, and the number of epochs, but nothing seems to work.


Solution

  • You have 10 neurons with ReLU in the last layer, and yes, all of them can fire/activate. Each of these neurons applies a ReLU to the output of its linear activation, i.e. ReLU(w·x + b). There are 10 such neurons, each producing some output based on its input, so several of them ending up non-zero is expected. The way you infer a prediction from this is by taking the class corresponding to the neuron with the largest activation (using np.argmax or torch.max).
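A minimal sketch of that inference step (the tensor names x_valid and y_valid are illustrative, not from the question):

import torch

# Assumed: `model` is the trained nn.Sequential from the question and
# `x_valid` is a float tensor of shape (n_samples, 784).
model.eval()
with torch.no_grad():
    out = model(x_valid)              # (n_samples, 10): one activation per class
    preds = torch.argmax(out, dim=1)  # index of the largest activation = predicted digit

# Accuracy against the true labels `y_valid` of shape (n_samples,):
accuracy = (preds == y_valid).float().mean().item()
print(accuracy)

Equivalently, torch.max(out, dim=1) returns both the maximum values and their indices, so torch.max(out, 1)[1] yields the same predicted classes.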