Tags: python-3.x, pytorch, neural-network, tensor

PyTorch: does all training data require gradients if I do mini-batching?


Let's say I have a function that creates mini-batches using bootstrapping from a set of positively and negatively labeled elements:

def get_bootstrap_batch(positive_examples, negative_examples, idx, shuffle=True):
    # Truncate the length if we are at the last batch to avoid over indexing
    seq_len = min(BATCH_SIZE, len(positive_examples) - 1 - idx)

    # Grab a batch from the positives with index idx
    positives = positive_examples[idx:idx + seq_len]

    # Get a random starting position from the indexes of the negative set for bootstrapping
    slice_pos = int(torch.randint(len(negative_examples) - BOOTSTRAP_SIZE, (1,)))

    # Get a bootstrap sized slice from negatives from slice_pos
    negatives = negative_examples[slice_pos:slice_pos+BOOTSTRAP_SIZE, :, :]

    # Concat the obtained data, these require gradient
    data = torch.concat([positives, negatives])
    target = torch.concat([
        torch.full((positives.shape[0],), 1, dtype=torch.float),
        torch.full((negatives.shape[0],), 0, dtype=torch.float),
    ])

    # Shuffle the bootstrap
    if shuffle:
        p = torch.randperm(data.size(0))  # permute the whole concatenated batch, not just the positives
        data = data[p, :, :]
        target = target[p]

    return data, target.to(device)

And I have a training cycle like this:

net = some_model_class(...)
positive_set_train = torch.tensor(..., requires_grad=True)
negative_set_train = torch.tensor(..., requires_grad=True)
for batch, i in enumerate(range(0, positive_set_train.size(0) - 1, BATCH_SIZE)):
    
    net.train()  # turn on train mode

    data, targets = get_bootstrap_batch(positive_set_train, negative_set_train, i)
    optimizer.zero_grad()  # zero the gradient buffers
    output = net(data)
    
    loss = criterion(output, targets)
    train_loss += loss.item() * data.size(0)
    loss.backward()

    optimizer.step()  # Does the update

My question is: can I turn off requires_grad for positive_set_train and negative_set_train to save memory? Am I understanding correctly that, inside the bootstrap function, a new tensor called data is created which has requires_grad=True by default, and that it is the only thing required for training the network?


Solution

  • requires_grad=True is only for tensors for which you want to compute gradients, typically the values you want to update by gradient descent. In your case you need the gradients of the parameters in net, but not of positive_examples, negative_examples or data.

    By default, all the tensors you create and all the tensors derived from operations on them have requires_grad=False. However, the parameters of a neural network have requires_grad=True.

    Example:

    x = torch.tensor([0.6])
    print(x.requires_grad)  # False
    y = torch.tensor([0.4])
    xy = torch.cat([x[None, :], y[None, :]])
    print(xy.requires_grad)  # False
    
    net = torch.nn.Linear(2, 1, bias=False)
    weight = net.weight
    print(weight.requires_grad)  # True
    

    We can go further: a tensor that is not a parameter but is derived from an operation involving a tensor with requires_grad=True has requires_grad=True itself (even if the other tensors used in the operation do not).

    Example:

    z = net(xy.T)
    print(z.requires_grad) # True
    loss = torch.mean((z - torch.tensor([0.5]))**2)  # MSE
    print(loss.requires_grad) # True
    

    OK, but what does tensor.requires_grad=True mean precisely (for any tensor tensor)?

    It means that after calling loss.backward(), if tensor is used in the computation of loss, the gradient d loss / d tensor will be computed. Then, if tensor is not itself the result of an operation (typically if it is a parameter), this gradient will be stored in tensor.grad. We call these tensors the "leaves".

    Example:

    loss.backward()
    # Intermediate tensor:
    print(z.is_leaf)  # False
    print(z.grad)  # UserWarning: `z` is not a leaf
    # Leaf that requires grad (parameter):
    print(weight.is_leaf)  # True
    print(weight.grad)  # You get two values: one for each parameter. You can do gradient descent!
    # Leaf that doesn't requires grad (raw data):
    print(x.is_leaf)  # True
    print(x.grad)  # None (no warning because it is a leaf, but no value because it doesn't require gradients)
    

    Finally, the default behavior is exactly what you want in most cases when training a neural network: only the parameters keep their gradients after backpropagation, and PyTorch computes only the minimum number of gradients needed for the chain rule.
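
    To make the memory point concrete, here is a minimal self-contained sketch (mirroring the small Linear example above) contrasting raw data created with the default requires_grad=False against the same data created with requires_grad=True: in the second case PyTorch also computes and stores a gradient for the data itself, which is exactly the extra work and memory you want to avoid.

    import torch

    net = torch.nn.Linear(2, 1, bias=False)

    # Default: raw data does not require grad.
    a = torch.tensor([[0.6, 0.4]])
    loss = torch.mean((net(a) - 0.5) ** 2)
    loss.backward()
    print(a.grad)           # None: nothing is computed or stored for the data
    print(net.weight.grad)  # two values: the gradients for the parameters

    # Same data, but created with requires_grad=True (as in your training script).
    net.zero_grad()
    b = torch.tensor([[0.6, 0.4]], requires_grad=True)
    loss = torch.mean((net(b) - 0.5) ** 2)
    loss.backward()
    print(b.grad)           # a gradient is now computed and stored for the data too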

    To put it in a nutshell: you don't need to change anything from the default behavior if you want to optimize all the parameters of your neural network, and only those.
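
    Applied to your snippet, that means the only change is in how you create the raw data (a sketch keeping your placeholders; the bootstrap function and the training loop stay exactly as they are):

    net = some_model_class(...)
    positive_set_train = torch.tensor(...)  # default requires_grad=False
    negative_set_train = torch.tensor(...)  # default requires_grad=False

    # data and target built in get_bootstrap_batch will then also have
    # requires_grad=False, while net's parameters still receive their
    # gradients from loss.backward().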