I'm pretty new to doing machine learning, and I stumbled upon Keras's Sequence page and was trying to understand the code block here:
from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math
import tensorflow as tf  # needed for tf.keras.utils.Sequence below

# Here, `x_set` is a list of paths to the images
# and `y_set` are the associated classes.

class CIFAR10Sequence(tf.keras.utils.Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        low = idx * self.batch_size
        # Cap upper bound at array length; the last batch may be smaller
        # if the total number of items is not a multiple of batch size.
        high = min(low + self.batch_size, len(self.x))
        batch_x = self.x[low:high]
        batch_y = self.y[low:high]

        return np.array([
            resize(imread(file_name), (200, 200))
               for file_name in batch_x]), np.array(batch_y)
I got stuck at __getitem__ and have no clue how low and high work. Can anybody help explain what's happening in __getitem__? Why is low = idx * self.batch_size? Why is high = min(low + self.batch_size, len(self.x))?
The variable idx is the batch index: __getitem__(self, idx) returns batch number idx from your CIFAR10Sequence, not a single sample. When Keras consumes the sequence (for example during fit), it requests batches by index, with idx ranging from 0 to len(sequence) - 1.
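To see how the index drives batching, here is a minimal sketch that uses a plain Python class in place of a real Keras Sequence, so it runs without TensorFlow or image files. ToySequence and its items argument are hypothetical names made up for this illustration; only the low/high arithmetic is taken from the question's code.

```python
import math


class ToySequence:
    """Hypothetical stand-in for a Keras Sequence (no TensorFlow needed)."""

    def __init__(self, items, batch_size):
        self.items = items
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return math.ceil(len(self.items) / self.batch_size)

    def __getitem__(self, idx):
        # Same index arithmetic as in the question's __getitem__.
        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.items))
        return self.items[low:high]


seq = ToySequence(list(range(10)), batch_size=4)

# Request every batch by its index, 0 .. len(seq) - 1.
batches = [seq[idx] for idx in range(len(seq))]
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Each idx maps to one contiguous slice of the data, and the last slice is shorter because 10 is not a multiple of 4.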
Now, let's look at the code inside __getitem__(self, idx):
low = idx * self.batch_size
high = min(low + self.batch_size, len(self.x))
batch_x = self.x[low:high]
batch_y = self.y[low:high]
These lines select one batch from the dataset: the items at positions low up to, but not including, high. For example, suppose your dataset contains 420 items. If idx = 42 and self.batch_size = 8, then low = 42×8 = 336 and high = 336+8 = 344, so __getitem__ returns items 336 through 343.
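The arithmetic for that example can be checked with a few lines of Python. batch_bounds is a hypothetical helper written only for this illustration; it repeats the two lines from __getitem__ with the dataset size passed in as a number.

```python
def batch_bounds(idx, batch_size, n_items):
    """Return the (low, high) slice bounds for batch number idx."""
    low = idx * batch_size
    high = min(low + batch_size, n_items)
    return low, high


low, high = batch_bounds(idx=42, batch_size=8, n_items=420)
print(low, high)  # 336 344

# Slices exclude the upper bound, so the batch covers positions 336..343.
positions = list(range(low, high))
print(len(positions))  # 8
```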
Now, let's see why min() is used. Consider the last batch: with 420 items and a batch size of 8, __len__ returns ceil(420/8) = 53, so the valid indices are 0 to 52. For idx = 52, low = 52×8 = 416, but low + self.batch_size = 424 is greater than the total number of items (420). min(low + self.batch_size, len(self.x)) therefore caps high at 420, so the last batch holds the 4 remaining items and the slice never reaches past the end of the dataset. (Python slicing would clamp an out-of-range upper bound on a list or NumPy array anyway, so min() mainly makes the intent explicit.)
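As a rough check of the boundary case, the sketch below works through the final batch for a 420-item dataset with a batch size of 8. The dataset is a plain list of integers standing in for the image paths, which is an assumption made so the example runs on its own.

```python
import math

n_items, batch_size = 420, 8
data = list(range(n_items))  # stand-in for the list of image paths

n_batches = math.ceil(n_items / batch_size)  # 53 batches, indices 0..52
last_idx = n_batches - 1

low = last_idx * batch_size    # 416
uncapped = low + batch_size    # 424, which is past the end of the data
high = min(uncapped, n_items)  # capped at 420

last_batch = data[low:high]
print(len(last_batch))  # 4

# List slicing clamps the upper bound silently, so the uncapped slice
# returns the same items; min() makes the cap explicit.
same_as_uncapped = data[low:uncapped] == last_batch
print(same_as_uncapped)  # True

# Every item lands in exactly one batch.
sizes = [min(i * batch_size + batch_size, n_items) - i * batch_size
         for i in range(n_batches)]
print(sum(sizes))  # 420
```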