Tags: python, opencv, pytorch, face-recognition

ValueError: sampler option is mutually exclusive with shuffle pytorch


I'm working on a face recognition project using PyTorch and MTCNN. After training on my training dataset, I now want to make predictions on the test dataset.

This is my training code:

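import numpy as np
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms
from facenet_pytorch import fixed_image_standardization

# resnet, data_dir, workers and batch_size are defined earlier (not shown)
optimizer = optim.Adam(resnet.parameters(), lr=0.001)
scheduler = MultiStepLR(optimizer, [5, 10])

trans = transforms.Compose([
    np.float32,
    transforms.ToTensor(),
    fixed_image_standardization
])
dataset = datasets.ImageFolder(data_dir, transform=trans)
img_inds = np.arange(len(dataset))
np.random.shuffle(img_inds)
train_inds = img_inds[:int(0.8 * len(img_inds))]
val_inds = img_inds[int(0.8 * len(img_inds)):]

train_loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(train_inds)
)
val_loader = DataLoader(
    dataset,
    shuffle=True,  # this line triggers the ValueError: sampler and shuffle cannot be combined
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(val_inds)
)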

If I remove sampler=SubsetRandomSampler(val_inds) and pass val_inds instead, it raises this error:

    val_inds
    ^
SyntaxError: positional argument follows keyword argument

I want to make predictions on samples selected randomly from the test dataset in PyTorch; that's why I thought I should use shuffle=True. I followed the facenet-pytorch repo.


Solution

  • TL;DR: Remove shuffle=True in this case, as SubsetRandomSampler already shuffles the data.
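
    A minimal sketch of the corrected loaders. It assumes a small stand-in TensorDataset (hypothetical, in place of the ImageFolder dataset) so that it runs on its own; the only real change is dropping shuffle=True from the validation loader. It also reproduces the error from the question title when both options are passed:

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Hypothetical stand-in for the ImageFolder dataset: 10 samples, 3 features each.
# The labels are simply 0..9, so they double as sample indices.
dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))

indices = torch.randperm(len(dataset)).tolist()
train_inds, val_inds = indices[:8], indices[8:]

train_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(train_inds))
# No shuffle=True here: SubsetRandomSampler already randomizes the order.
val_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(val_inds))

# Collect which samples the validation loader actually visits.
seen_val = sorted(label.item() for _, labels in val_loader for label in labels)
print(seen_val == sorted(val_inds))  # True

# Passing both options reproduces the error from the question.
try:
    DataLoader(dataset, batch_size=4, shuffle=True, sampler=SubsetRandomSampler(val_inds))
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)
```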

    torch.utils.data.SubsetRandomSampler (please consult the documentation when in doubt) takes a list of indices and, each time it is iterated, returns a random permutation of them.

    In your case, you have one list of indices for training and one for validation; both are indices of elements in the same underlying dataset.

    Let's assume they look like this:

    train_indices = [0, 2, 3, 4, 5, 6, 9, 10, 12, 13, 15]
    val_indices = [1, 7, 8, 11, 14]
    

    During each pass, SubsetRandomSampler returns every index from its list exactly once, in random order; after all of them have been returned, the next pass (a new call to __iter__) draws a fresh permutation.

    So SubsetRandomSampler might return something like this for val_indices (analogously for train_indices):

    val_indices = [1, 8, 11, 7, 14]  # Epoch 1
    val_indices = [11, 7, 8, 14, 1]  # Epoch 2
    val_indices = [7, 1, 14, 8, 11]  # Epoch 3
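
    The per-epoch reshuffling can be checked directly. This sketch reuses the example index lists above; the seed is arbitrary and only makes the demo repeatable, so the printed orders will differ from the illustration:

```python
import torch
from torch.utils.data import SubsetRandomSampler

train_indices = [0, 2, 3, 4, 5, 6, 9, 10, 12, 13, 15]
val_indices = [1, 7, 8, 11, 14]

# The two splits must not share any index.
assert set(train_indices).isdisjoint(val_indices)

torch.manual_seed(0)  # only for reproducibility of this demo
sampler = SubsetRandomSampler(val_indices)

# Each list(sampler) call runs __iter__ once, i.e. one "epoch".
epochs = [list(sampler) for _ in range(3)]
for epoch, order in enumerate(epochs, start=1):
    print(f"Epoch {epoch}: {order}")

# Every epoch is a permutation of the same validation indices.
assert all(sorted(order) == sorted(val_indices) for order in epochs)
```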
    

    Each of those numbers is an index into your original dataset. Note that validation is shuffled this way, and so is training, without using shuffle=True. The two index lists do not overlap, so the data is split correctly.

    Additional info

    • if shuffle=True is specified, DataLoader uses torch.utils.data.RandomSampler under the hood (see the source code). This in turn is equivalent to using torch.utils.data.SubsetRandomSampler with all indices (np.arange(len(dataset))) specified.
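    • the pre-shuffle np.random.shuffle(img_inds) is still needed to make the train/validation split itself random (ImageFolder stores samples grouped by class, so slicing unshuffled indices would put only the last classes into validation), but you don't need to reshuffle between epochs, as the indices are shuffled during each pass anyway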
    • avoid numpy when torch provides the same functionality: there is torch.arange (and torch.randperm for shuffling), so mixing both libraries is almost never necessary.
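
    A short sketch of the points above, using a hypothetical toy TensorDataset in place of the ImageFolder dataset: it checks that shuffle=True builds a RandomSampler, and performs the 80/20 split with torch only (one random permutation, then slicing):

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, SubsetRandomSampler, TensorDataset

# Hypothetical toy dataset standing in for the ImageFolder dataset.
dataset = TensorDataset(torch.randn(20, 3), torch.arange(20))

# shuffle=True makes DataLoader create a RandomSampler internally.
shuffled_loader = DataLoader(dataset, batch_size=5, shuffle=True)
print(type(shuffled_loader.sampler).__name__)  # RandomSampler

# Torch-only 80/20 split: shuffle once, then slice (no numpy needed).
perm = torch.randperm(len(dataset))
split = int(0.8 * len(dataset))
train_inds, val_inds = perm[:split].tolist(), perm[split:].tolist()

# A SubsetRandomSampler over all indices behaves like shuffle=True:
# every index once per pass, in a new random order each time.
all_indices_sampler = SubsetRandomSampler(torch.arange(len(dataset)).tolist())
```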

    Inference

    Single image

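    Just pass it through your network and get the output. Note that dataset[i] returns an (image, label) pair and the network expects a batch, so add a batch dimension, e.g.:

    module.eval()
    with torch.no_grad():
        image, _ = dataset[5380]
        output = module(image.unsqueeze(0))  # input shape [1, C, H, W]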
    

    The first line puts the model in evaluation mode (this changes the behaviour of some layers, such as dropout and batch normalization); the context manager turns off gradient tracking (as it's not needed for predictions). Both are almost always used when checking a neural network's output.
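
    Both effects can be observed on a hypothetical tiny model containing dropout (not the face recognition network): in evaluation mode repeated predictions are identical, and under torch.no_grad() the output does not track gradients:

```python
import torch
from torch import nn

# Hypothetical tiny model; Dropout behaves differently in train and eval mode.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(2, 4)

model.eval()  # dropout becomes a no-op
with torch.no_grad():  # no computation graph is recorded
    out1 = model(x)
    out2 = model(x)

print(torch.equal(out1, out2))  # True: eval mode is deterministic here
print(out1.requires_grad)  # False: gradients are not tracked
```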

    Checking validation dataset

    Something along these lines; notice that the same ideas apply as for a single image:

    module.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            output = module(images)
            # Assumes one logit per class (multi-class, e.g. one class per person):
            # the predicted class is the index of the largest logit.
            # For a single-logit binary classifier use (output.squeeze(1) > 0.0).long() instead.
            predictions = output.argmax(dim=1)
            # .item() gets a Python number from a one-element torch.tensor
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    print("Overall accuracy: {}".format(correct / total))
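
    A self-contained sketch of the same loop as a reusable function, assuming a multi-class classifier that returns one logit per class; evaluate is a hypothetical helper, not part of PyTorch or facenet-pytorch. Counting correct predictions over all samples, rather than averaging per-batch accuracies, keeps the result exact when the last batch is smaller: with the batches of 3 and 1 below, averaging batch accuracies would give 0.5 instead of 0.75.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader):
    """Return accuracy over all samples in `loader` for a multi-class classifier."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            preds = model(images).argmax(dim=1)  # predicted class per sample
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total


# Hypothetical check: nn.Identity passes hand-made "logits" straight through.
logits = torch.tensor([[2.0, 0.1, 0.1],   # predicts 0, label 0 -> correct
                       [0.1, 2.0, 0.1],   # predicts 1, label 1 -> correct
                       [0.1, 0.1, 2.0],   # predicts 2, label 2 -> correct
                       [0.1, 2.0, 0.1]])  # predicts 1, label 0 -> wrong
labels = torch.tensor([0, 1, 2, 0])
loader = DataLoader(TensorDataset(logits, labels), batch_size=3)

print(evaluate(nn.Identity(), loader))  # 0.75
```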
    

    Other cases

    Please see some beginner guides or tutorials to understand these concepts; StackOverflow is meant for concrete, focused questions rather than re-doing that work. Thanks.