python tensorflow machine-learning keras

What does idx * self.batch_size do?


I'm pretty new to doing machine learning, and I stumbled upon Keras's Sequence page and was trying to understand the code block here:

from skimage.io import imread
from skimage.transform import resize
import tensorflow as tf
import numpy as np
import math

# Here, `x_set` is a list of paths to the images
# and `y_set` contains the associated classes.

class CIFAR10Sequence(tf.keras.utils.Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        low = idx * self.batch_size
        # Cap upper bound at array length; the last batch may be smaller
        # if the total number of items is not a multiple of batch size.
        high = min(low + self.batch_size, len(self.x))
        batch_x = self.x[low:high]
        batch_y = self.y[low:high]

        # Load and resize every image in the batch, then stack into one array
        images = [resize(imread(file_name), (200, 200))
                  for file_name in batch_x]
        return np.array(images), np.array(batch_y)

I got stuck at __getitem__ and have no clue how low and high work. Can anybody explain what's happening in __getitem__? Why is low = idx * self.batch_size? Why is high = min(low + self.batch_size, len(self.x))?


Solution

  • The variable idx is the index of a batch, not of a single sample, and __getitem__(self, idx) returns the idx-th batch of your CIFAR10Sequence. During training, Keras requests batches from the Sequence by index, with idx running from 0 to len(sequence) - 1.

    Here is the relevant code inside __getitem__:

    low = idx * self.batch_size
    high = min(low + self.batch_size, len(self.x))
    batch_x = self.x[low:high]
    batch_y = self.y[low:high]
    

    These lines select one batch from the dataset: the samples in the half-open interval from low up to, but not including, high. For example, suppose your dataset contains 420 samples and self.batch_size = 8. If idx = 42, then low = 42×8 = 336 and high = 42×8+8 = 344, so __getitem__ returns samples 336 through 343. Multiplying idx by the batch size simply skips over all the samples that belong to the earlier batches.

    Now, let's see why min() is used. With 420 samples and a batch size of 8, __len__ returns ceil(420 / 8) = 53, so the last valid index is idx = 52. For that batch, low = 52×8 = 416 and low + self.batch_size = 424, which is past the end of the dataset (420). Taking min(low + self.batch_size, len(self.x)) caps high at 420, so the last batch contains only the 4 remaining samples. For Python lists and NumPy arrays, a slice that runs past the end is silently truncated rather than raising an error, so min() mainly makes the cap explicit and keeps high equal to the true end of the batch.
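To check the arithmetic without loading any images, here is a minimal sketch that reproduces only the index calculation. The helper name batch_bounds is made up for illustration and is not part of Keras; the numbers (420 samples, batch size 8) are the ones from the example above.

```python
import math


def batch_bounds(idx, batch_size, n_items):
    """Return the (low, high) slice bounds for batch number idx.

    Hypothetical helper: it mirrors the first two lines of __getitem__.
    """
    low = idx * batch_size
    # Cap the upper bound so the last batch stops at the end of the data.
    high = min(low + batch_size, n_items)
    return low, high


n_items, batch_size = 420, 8
n_batches = math.ceil(n_items / batch_size)  # same formula as __len__

# Print the first two batches, a middle one, and the last (shorter) one.
for idx in (0, 1, 42, n_batches - 1):
    low, high = batch_bounds(idx, batch_size, n_items)
    print(f"idx={idx:2d} -> items[{low}:{high}] ({high - low} items)")
```

Every batch except the last covers exactly batch_size items, and the last one (idx = 52 here) covers the 4 items that are left over.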