I have a PyTorch `DataLoader` and want to retrieve the `Dataset` object that the loader wraps. If this is possible, how? Or does a `Dataset` object only exist for the built-in datasets that ship with PyTorch/torchvision?
The end goal is to easily plug data that I currently have as a `DataLoader` into code that is set up to work with a `Dataset` (e.g. CIFAR10).
The original code looks like this:
from torchvision import transforms, datasets
from typing import *
import torch
import os
from torch.utils.data import Dataset
def get_dataset(dataset, split):
    """Return the data object registered under the name ``dataset``.

    Args:
        dataset: Name of the dataset to load; only "CIFAR10" is handled here.
        split: Split identifier forwarded unchanged to the loader function.

    Returns:
        Whatever the selected loader function returns for ``split``.

    Raises:
        ValueError: If ``dataset`` is not a recognised name.
    """
    if dataset == "CIFAR10":
        return _cifar10(split)
    # Fail loudly instead of falling through and implicitly returning None,
    # which would only surface later as a confusing error on len(None).
    raise ValueError(f"Unknown dataset: {dataset!r}")
def _cifar10(split: str) -> Dataset:
if split == "train":
return datasets.CIFAR10("./dataset_cache", train=True, download=True)
# The split is compared against the string "train", so it must be passed as a
# string; a bare `train` would be an undefined name.
dataset = get_dataset("CIFAR10", "train")
for i in range(len(dataset)):
    ...
I have tried loading my whole dataset at once, by wrapping it in a `DataLoader` whose batch size equals the total number of images:
from torchvision import transforms, datasets
from typing import *
import torch
import os
from torch.utils.data import Dataset
def get_dataset(dataset, split):
    """Return the data object registered under the name ``dataset``.

    Args:
        dataset: Name of the dataset to load: "CIFAR10" or "mydataset".
        split: Split identifier forwarded unchanged to the loader function.

    Returns:
        Whatever the selected loader function returns for ``split``.

    Raises:
        ValueError: If ``dataset`` is not a recognised name.
    """
    if dataset == "CIFAR10":
        return _cifar10(split)
    elif dataset == "mydataset":
        return _mydataset(split)
    # Fail loudly instead of falling through and implicitly returning None,
    # which would only surface later as a confusing error on len(None).
    raise ValueError(f"Unknown dataset: {dataset!r}")
# Build the data object for the custom image-folder data of the given split.
# NOTE(review): indentation was lost in this listing, so the exact nesting of
# the statements below (e.g. whether `return mydataset` sits inside the `if`)
# cannot be confirmed from here.
def _mydataset(split: str) -> Dataset:
# Entries found directly under <database_directory>/<split>.
# NOTE(review): this uses `database_directory`, while the ImageFolder below
# uses `dataset_directory`; both are defined outside this listing and are
# presumably meant to be the same path. Confirm.
files = [file for file in os.listdir(database_directory + '/' + split)]
# Add up the number of entries inside each of those sub-directories.
total_num_images = 0
for file in files:
number_images = len([name for name in os.listdir(database_directory +
'/' + split + '/' + file)])
total_num_images += number_images
if split == "train":
# batch_size is set to the count computed above, with the intent that a single
# batch holds the whole folder.
# NOTE: the object built here is a DataLoader, not a Dataset (despite the
# return annotation), so it cannot be indexed; the wrapped ImageFolder is
# reachable through its `dataset` attribute.
mydataset = torch.utils.data.DataLoader(
datasets.ImageFolder(dataset_directory + '/train'),batch_size=total_num_images)
return mydataset
# The split is compared against the string "train", so it must be passed as a
# string; a bare `train` would be an undefined name.
dataset = get_dataset("mydataset", "train")
for i in range(len(dataset)):
    ...
But this raises the error `TypeError: 'DataLoader' object is not subscriptable`.
You can access the `dataset` attribute of a `torch.utils.data.DataLoader`
to get its underlying `torch.utils.data.Dataset` object,
for example `loader.dataset[i]`.
This attribute can be seen in the `DataLoader` source code.