Convolutional Neural Network in PyTorch with custom data


I am trying to create a CNN in PyTorch to classify three dancers from skeleton data. The dataset is split into 3000 pieces, each with 50 frames of 72 joint positions. I want to treat each piece like an image and therefore use a CNN for classification, but I am not sure how to use the DataLoader. In this link there is an example of how to train a CNN for classification, but its DataLoader uses a pre-configured dataset, and I am not sure how to format my custom data for the "trainset" argument in the command:

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=1)

My input data has 3000 data points, each with 50 frames of 72 joint positions. My labels are a vector of length 3000, where each entry is 0, 1, or 2 for the three different dancers. I hope someone can help.


Solution

  • You need to create a class that inherits from torch.utils.data.Dataset. In short, the class needs an __init__ method that sets up the necessary attributes, a __getitem__ method that returns the data point (optionally with its label) at a given index, and a __len__ method that returns the number of data points in your dataset.

    Here is a template of how to create a customized dataset.

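    from torch.utils import data

    class MyDataset(data.Dataset):
        def __init__(self, root):
            self.root = root
            # Placeholder: load your data from `root` here (e.g. into a list or tensor).
            self.dset = []

        def __getitem__(self, index):
            # Return the data point (optionally with its label) at position `index`.
            return self.dset[index]

        def __len__(self):
            # Number of data points in the dataset.
            return len(self.dset)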

    You may also like to refer to this real implementation, where I customize a dataset that uses GTA5 images for a semantic segmentation task.

    Lastly, use it just as you would a pre-configured dataset. For example:

    trainset = MyDataset(root)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=1)
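
For the data described in the question, the template above could be filled in roughly as follows. This is only a sketch under stated assumptions: the class name SkeletonDataset is hypothetical, the samples are assumed to be already loaded as a tensor of shape (3000, 50, 72) with integer labels of shape (3000,), and random tensors stand in for the real skeleton data. Each sample gets an extra leading channel dimension, so a batch has the (batch, channels, height, width) layout that 2D convolution layers expect.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SkeletonDataset(Dataset):
    """Hypothetical dataset wrapping pre-loaded skeleton clips and dancer labels."""

    def __init__(self, samples, labels):
        # samples: float tensor of shape (N, 50, 72)
        # labels:  integer tensor of shape (N,) with values 0, 1, or 2
        assert len(samples) == len(labels)
        self.samples = samples
        self.labels = labels

    def __getitem__(self, index):
        # Add a channel dimension so each clip looks like a
        # one-channel image of shape (1, 50, 72).
        return self.samples[index].unsqueeze(0), self.labels[index]

    def __len__(self):
        return len(self.samples)


# Random stand-in data with the shapes from the question (not real skeleton data).
samples = torch.randn(3000, 50, 72)
labels = torch.randint(0, 3, (3000,))

trainset = SkeletonDataset(samples, labels)
# num_workers=0 loads batches in the main process, which keeps the sketch portable.
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0)

batch_x, batch_y = next(iter(trainloader))
print(batch_x.shape, batch_y.shape)  # torch.Size([32, 1, 50, 72]) torch.Size([32])
```

The sketch sets shuffle=True because training data is usually shuffled each epoch; set it back to False if you need a fixed order.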