
Increasing instances of a class with Data Augmentation


I am working with some classes of the Charades dataset (https://prior.allenai.org/projects/charades) to detect indoor actions.

The structure of my dataset is as follows:

[Image: folder structure of the dataset]

Where:

  • c025, c137 and c142 are action classes;
  • XR436 contains the frames obtained by splitting a video in which a user performs action c025; the same holds for X3803, ... There are 250 such folders in total.
  • RI495 contains the frames obtained by splitting a video in which a user performs action c137; the same holds for DI402, ... There are 40 such folders in total.
  • TUCK3 contains the frames obtained by splitting a video in which a user performs action c142; the same holds for the remaining folders. There are 260 such folders in total.

As you can see, class c137 is heavily underrepresented compared with classes c025 and c142. Thus, I would like to increase the number of instances of this class using data augmentation. The idea is to create twin folders with certain transformations applied: for example, an A4DID folder as a twin of RI495 with histogram equalization applied to each frame, an A4456 folder as a twin of RI495 in grayscale, an ARTI3 folder as a twin of DI402 with the frames rotated, etc. The pattern of transformations may or may not be the same for every folder; I am only interested in increasing the number of instances.
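The twin-folder idea above can be sketched offline with Pillow. This is a minimal sketch under several assumptions: frames are stored as `*.jpg` files directly inside each video folder, the three transformations (`ImageOps.equalize`, grayscale, a 10-degree rotation) stand in for the ones listed, and twins get a hypothetical `<video>_<transform>` name instead of a random ID such as A4DID. Every frame of a video receives the same deterministic transformation, so each twin is still a coherent clip.

```python
from pathlib import Path

from PIL import Image, ImageOps

# Illustrative twin transformations: one deterministic function per twin folder.
TWIN_TRANSFORMS = {
    "equalize": ImageOps.equalize,                              # histogram equalization
    "gray": lambda im: ImageOps.grayscale(im).convert("RGB"),  # grayscale, kept 3-channel
    "rot10": lambda im: im.rotate(10),                          # 10 degrees counter-clockwise
}


def make_twins(class_dir, frame_glob="*.jpg"):
    """Create one twin folder per (video folder, transformation) pair in class_dir.

    Twins use a hypothetical '<video>_<transform>' naming scheme and are not
    used as sources, so re-running only overwrites the existing twins.
    """
    class_dir = Path(class_dir)
    suffixes = tuple(f"_{name}" for name in TWIN_TRANSFORMS)
    video_dirs = [d for d in sorted(class_dir.iterdir())
                  if d.is_dir() and not d.name.endswith(suffixes)]
    created = []
    for video_dir in video_dirs:
        for name, transform in TWIN_TRANSFORMS.items():
            twin_dir = class_dir / f"{video_dir.name}_{name}"
            twin_dir.mkdir(exist_ok=True)
            # The same transformation is applied to every frame of the video.
            for frame_path in sorted(video_dir.glob(frame_glob)):
                with Image.open(frame_path) as im:
                    transform(im.convert("RGB")).save(twin_dir / frame_path.name)
            created.append(twin_dir)
    return created
```

Applied to the 40 c137 folders, this would add 120 twin folders. Each twin is a fixed copy, though, so the variety of augmentations is limited to the transformations listed, unlike transforms sampled randomly at load time.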

I am using PyTorch and I tried torchvision.transforms together with DataLoader from torch.utils.data, but I have not achieved the result I am looking for. Any idea on how to proceed?

PS: Undersampling classes c025 and c142 is not an option, because the classifier cannot learn well from so few examples.

Thank you in advance


Solution

  • A few thoughts:

    1. Standard practice is to apply transforms dynamically: each time a data example is loaded, a composed (sequential) set of transform operations is applied with randomly sampled parameters. Thus, each time the datum is loaded, the resulting inputs x are different. This can be achieved by defining a stack of transforms that is applied to each data example as it is loaded in a PyTorch Dataset object (see here). This provides data augmentation without storing augmented copies on disk.

    2. Class imbalance is a somewhat different issue, and is generally addressed by either a.) oversampling (acceptable if you use the transform approach above, because the oversampled examples will have different transforms applied) or b.) over-weighting these examples in the loss calculation. Of course, neither approach can account for the risk of receiving an out-of-distribution test example, a risk that grows the fewer and less diverse examples you have for a given class. The former can be achieved by defining a custom Sampler object that yields examples from your dataset in a class-balanced manner. The latter can be achieved by passing class weights to the loss function (many PyTorch loss functions, such as CrossEntropyLoss, already accept a weight argument).
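Both options in point 2 can be sketched with built-in PyTorch pieces: `WeightedRandomSampler` in place of a hand-written class-balanced `Sampler`, and the `weight` argument of `CrossEntropyLoss`. The inverse-frequency weighting `n_total / (n_classes * n_c)` is one common heuristic, assumed here; the labels are built from the folder counts in the question (250, 40, 260), and a random `TensorDataset` stands in for the real dataset.

```python
from collections import Counter

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler


def inverse_frequency_weights(labels):
    """Class weights n_total / (n_classes * n_c); assumes labels are 0..K-1."""
    counts = Counter(labels)
    n_classes, total = len(counts), len(labels)
    return torch.tensor([total / (n_classes * counts[c]) for c in range(n_classes)])


# One label per video folder, using the counts from the question:
# 0 = c025 (250 folders), 1 = c137 (40 folders), 2 = c142 (260 folders).
labels = [0] * 250 + [1] * 40 + [2] * 260
class_w = inverse_frequency_weights(labels)

# Stand-in dataset; in practice this would be the frame-folder dataset.
dataset = TensorDataset(torch.randn(len(labels), 4), torch.tensor(labels))

# a.) Oversampling: draw examples with replacement, weighted by their class weight,
# so each class is expected to make up about a third of every epoch.
sample_w = class_w[torch.tensor(labels)]
sampler = WeightedRandomSampler(sample_w, num_samples=len(labels), replacement=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)  # no shuffle=True with a sampler

# b.) Over-weighting in the loss instead: errors on c137 cost about 6x more than on c025.
criterion = nn.CrossEntropyLoss(weight=class_w)
```

With these weights every class carries the same total sampling weight, so each c137 video is drawn roughly four to five times per epoch. Using the sampler and the loss weights together would correct for the imbalance twice, so one of the two is usually enough.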