python tensorflow keras confusion-matrix

How to get y_true for confusion_matrix from a CustomDataGenerator


I want to build a confusion matrix, but I always get this error message:

ValueError
Found input variables with inconsistent numbers of samples: [0, 62]
  File "C:\Labbb\inceptionResnetV2\InceptionResnetV2_1.py", line 216, in <module>
    sns.heatmap(confusion_matrix(y_true, y_pred),
ValueError: Found input variables with inconsistent numbers of samples: [0, 62]

How can I get y_true from the CustomDataGenerator?

I tried appending the labels to y_true in get_data and returning it from a get_y_true method, but that does not work.

Here is the CustomDataGenerator code.

class CustomDataGenerator(Sequence):
    def __init__(self, image_folders, label_folders, dir, dim=(512,512),  batch_size=1,n_classes=7,n_channels=8,shuffle=True):
        self.image_folders = image_folders
        ...
        self.image_paths = []
        self.label_paths = []
        self.y_true = []
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.image_paths) / self.batch_size))  

    def __getitem__(self, index):
        batch_image_paths = self.image_paths[index * self.batch_size: (index + 1) * self.batch_size]
        batch_label_paths = self.label_paths[index * self.batch_size: (index + 1) * self.batch_size]
        batch = zip(batch_image_paths, batch_label_paths)
        return self.get_data(batch)

    def on_epoch_end(self):
        self.image_paths = []
        self.label_paths = []
        for folder in self.image_folders:
            image_folder_path = os.path.join(self.dir, folder)
            image_files = os.listdir(image_folder_path)
            for file_name in image_files:
                self.image_paths.append(os.path.join(image_folder_path, file_name))
        for folder in self.label_folders:
            label_folder_path = os.path.join(self.dir, folder)
            label_files = os.listdir(label_folder_path)
            for file_name in label_files:
                self.label_paths.append(os.path.join(label_folder_path, file_name))
        if self.shuffle:
            np.random.shuffle(self.image_paths)
            np.random.shuffle(self.label_paths)

    def get_data(self, batch):
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size, self.n_classes))
        y_true = []

        for i, (image_path, label_path) in enumerate(batch):
            image = np.load(image_path)
            with open(label_path, 'r') as f:
                line = f.readline().strip()
                filepath, label = line.rsplit(' ', 1)
                label = int(label)
                y_true.append(label)
            label_one_hot = to_categorical(label, num_classes=self.n_classes)

            X[i,] = image
            y[i,] = label_one_hot

        return X, y
    
    def get_y_true(self):
        return self.y_true

Here is how I get y_true and y_pred and build the confusion matrix:

train_datagen = CustomDataGenerator(image_folders, label_folders, train_dir, **params, shuffle = True)
val_datagen = CustomDataGenerator(image_folders, label_folders, valid_dir, **params, shuffle = True)

y_true = CustomDataGenerator.get_y_true(val_datagen)
Y_pred = model.predict(val_datagen)
y_pred = np.argmax(Y_pred, axis=1) 
sns.heatmap(confusion_matrix(y_true, y_pred),annot=True, fmt="d", cmap='Greens',ax = ax)

Solution

  • There are a few things I would like to comment on.
    As for your initial question, y_true is empty: __init__() sets self.y_true = [] and nothing ever fills it. get_data(..) does build a y_true, but that is a local variable, not self.y_true, so it is discarded when the method returns. The error message shows the same thing: in [0, 62], the 0 is the number of samples in y_true and the 62 is the number of predictions in y_pred.

    A few tips on code quality here. on_epoch_end(..) does too much. You do not need to rebuild the image paths every epoch. Do the initialization in a separate method, and only do the shuffling in on_epoch_end().
    You should also be careful with the dir parameter in __init__(). dir is a built-in function in Python, and you should not shadow built-ins unless you know what you're doing. That's why it's highlighted orange in your code here. In this specific code it does no harm, but be aware of it.
    Instead of calling y_true = CustomDataGenerator.get_y_true(val_datagen) you could write y_true = val_datagen.get_y_true(). It works the same and is (in my opinion) clearer. Frankly, I had never seen your notation before.

    Last point, your example is not reproducible. I tried to run your code, but you seem to have omitted some parts of it, so I got errors and had to guess how to fix them. It really helps when you post the whole (relevant) code and comment on it.
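
To make the points about y_true and on_epoch_end() concrete, here is a minimal sketch, not your actual generator. It rests on several assumptions: PairedGenerator is a hypothetical plain-Python stand-in for a keras Sequence (it only mimics the __len__/__getitem__ protocol), the samples and labels are small in-memory arrays instead of .npy and label files, and the labels are known up front. The index is built once in __init__(), on_epoch_end() only shuffles it, and a single shared permutation keeps every sample paired with its label. get_y_true() then returns the labels in the order the batches are served.

```python
import numpy as np


class PairedGenerator:
    """Hypothetical stand-in for a keras Sequence (same __len__/__getitem__ protocol)."""

    def __init__(self, samples, labels, n_classes, batch_size=2, shuffle=True, seed=0):
        self.samples = np.asarray(samples)
        self.labels = np.asarray(labels)
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        # Build the index once; on_epoch_end() only reorders it.
        self.order = np.arange(len(self.samples))
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.samples) / self.batch_size))

    def __getitem__(self, index):
        idx = self.order[index * self.batch_size:(index + 1) * self.batch_size]
        X = self.samples[idx]
        y = np.eye(self.n_classes)[self.labels[idx]]  # one-hot encode the labels
        return X, y

    def on_epoch_end(self):
        # One shared permutation keeps each sample paired with its label.
        if self.shuffle:
            self.rng.shuffle(self.order)

    def get_y_true(self):
        # Integer labels in the order the batches are currently served.
        return self.labels[self.order]


samples = np.arange(5, dtype=float).reshape(5, 1)  # dummy "images"
labels = [0, 2, 1, 2, 0]
val_gen = PairedGenerator(samples, labels, n_classes=3, shuffle=False)

y_true = val_gen.get_y_true()
# Cross-check: labels recovered from the served batches match get_y_true().
served = np.concatenate(
    [np.argmax(val_gen[i][1], axis=1) for i in range(len(val_gen))]
)
```

For a confusion matrix, y_true and the predictions have to be in the same order, so it is probably safest to create the validation generator with shuffle=False before calling model.predict on it. Note also that the code in the question shuffles image_paths and label_paths independently, which would break the pairing between an image and its label.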