Search code examples
tensorflow · keras · nibabel

Loading large sets of data for TensorFlow deep learning


I'm loading data consisting of thousands of MRI images. I'm doing it like this using nibabel to obtain the 3D data arrays from the MRI files:

def get_voxels(path):
    """Load a NIfTI file and return its voxel data as an ndarray.

    Args:
        path: Filesystem path to a NIfTI (.nii / .nii.gz) file.

    Returns:
        The volume's voxel data as a floating-point NumPy array
        (whatever ``get_fdata()`` yields — float64 by default).
    """
    img = nib.load(path)
    # get_fdata() already materializes an in-memory array; ``img`` (and its
    # internal cache) is released when this function returns, so the extra
    # .copy() the original made only doubled peak memory per scan — a real
    # cost when loading thousands of MRI volumes.
    return img.get_fdata()


# Load the shuffled manifest of scan paths and pass/fail labels.
df = pd.read_csv("/home/paths_updated_shuffled_4.csv").reset_index()

# Materialize every volume and its label in memory.
images = np.array([get_voxels(p) for p in df['path']])
labels = np.array([y for y in df['pass']])

# 80% train; the remainder is split evenly between validation and test.
n = len(df.index)
train_n = int(0.8 * n)
validation_n = (n - train_n) // 2
validation_end = train_n + validation_n

train_images, train_labels = images[:train_n], labels[:train_n]
validation_images = images[train_n:validation_end]
validation_labels = labels[train_n:validation_end]
test_images, test_labels = images[validation_end:], labels[validation_end:]

# Wrap each split as a tf.data pipeline.
train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
validation_ds = tf.data.Dataset.from_tensor_slices((validation_images, validation_labels))
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

As you can see, I'm using tf.data.Dataset.from_tensor_slices. However, I'm running out of memory because of the large number of large files.

Is there a better way to do this in TensorFlow or Keras?


Solution

  • Do as described in 3D image classification from CT scans by Hasib Zunair.

    import nibabel as nib
    import pandas as pd
    import numpy as np
    
    def process_scan(path):
        """Load, normalize, and resize one volume.

        Pipeline: read -> normalize -> resize. The helper functions
        (``read_nifti_file``, ``normalize``, ``resize_volume``) come from
        the referenced tutorial.
        """
        raw = read_nifti_file(path)      # read scan from disk
        scaled = normalize(raw)          # rescale intensities
        return resize_volume(scaled)     # resize width, height, and depth
    
    
    # Load the manifest and split the scan paths by label:
    # pass == 1 -> normal, pass == 0 -> abnormal.
    df = pd.read_csv("/home/paths_updated_shuffled_4.csv")
    n = len(df.index)
    normal_scan_paths = df.loc[df['pass'] == 1, 'path'].tolist()
    abnormal_scan_paths = df.loc[df['pass'] == 0, 'path'].tolist()

    print(f"Passing MRI scans: {len(normal_scan_paths)}")
    print(f"Failing MRI scans: {len(abnormal_scan_paths)}")

    # Loading data and preprocessing:
    # each scan is resized across height, width, and depth and rescaled.
    abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
    normal_scans = np.array([process_scan(path) for path in normal_scan_paths])