Tags: tensorflow, gpu, ram, tensorflow-datasets

How to use predict_on_batch to avoid an out-of-GPU-memory error with a data generator


I have a Keras model that consists of two parts (left and right). For practical reasons, both parts work largely independently, but they exchange some latent data that the model generates in intermediate steps. I want to compress this latent data using autoencoders, so I added autoencoders to the model as additional submodels. To train these autoencoders, I create submodels of the main model using calls similar to

    submodel = tf.keras.Model(inputs=[model.input], outputs=[model.get_layer(submodel_name).output])

This works fine. However, to obtain the latent data I need for training the autoencoders with an L2 loss, I currently predict the entire dataset at once:

    train = submodel.predict(train_ds)        

However, because the original dataset is not small and the submodel's output dimension is rather large, I run out of GPU memory when this line runs for one of my submodels. The entire process of training an autoencoder is as follows:
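For a rough sense of scale (the numbers below are hypothetical, not taken from the actual model): `predict` returns all per-batch outputs concatenated into one array, so its size is the number of examples times the product of the output dimensions times the bytes per value (4 for float32). The helper `prediction_size_bytes` is only an illustration.

```python
import math


def prediction_size_bytes(num_examples, output_shape, bytes_per_value=4):
    """Size of the single array predict() returns (all batch outputs stacked).

    bytes_per_value=4 assumes float32 outputs.
    """
    return num_examples * math.prod(output_shape) * bytes_per_value


# Hypothetical example: 10,000 windows, each producing a (256, 512) latent tensor
size = prediction_size_bytes(10_000, (256, 512))
print(f"{size / 2**30:.2f} GiB")  # prints 4.88 GiB
```

The validation set adds to this, and `AE.fit(x=train, ...)` may convert the arrays into tensors again, so peak memory can be a multiple of this figure.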

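    submodel = tf.keras.Model(inputs=[model.input], outputs=[model.get_layer(submodel_name).output])
    AE_name = 'AE_' + submodel_name
    AE = model.get_layer(AE_name)

    # Predicting the full datasets materializes all latent outputs at once (this is where the GPU runs out of memory)
    train = submodel.predict(train_ds)
    valid = submodel.predict(valid_ds)

    # Switch the autoencoder from bypass (1:1 mapping) to actual compression before training
    AE.bypass = False
    AE.compile(loss='mse', run_eagerly=False)

    # An autoencoder reconstructs its input, so the inputs double as targets
    AE.fit(x=train, y=train, validation_data=(valid, valid),
           epochs=epochs, verbose=0, callbacks=[callbacks[0]])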

Initially, while generating the data, the autoencoders are set to bypass mode (a 1:1 mapping), so that I get the correct latent data of the original model without autoencoders.

How can I split the prediction and the fit into smaller steps so that less GPU RAM is needed? My problem is that I am rather unfamiliar with the data generator class (built on tf.data) that produces the training and validation data, and my attempts so far have failed.
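A minimal sketch of splitting prediction and training into batches with `predict_on_batch`, `train_on_batch` and `test_on_batch`. It assumes that `train_ds`/`valid_ds` yield `(inputs, targets)` tuples as Keras datasets usually do (`_decode` isn't shown, so this is an assumption), that the submodel has a single output, that `AE` was already set to `bypass = False` and compiled with a loss only (as above), and that the autoencoder being trained does not lie on the path from `model.input` to the submodel's output, so training it does not change the latent data. `_inputs` and `train_autoencoder_batchwise` are hypothetical names. In TF 2.x Keras, `predict_on_batch` runs the same compiled prediction step as `predict`, so the latent batches should match what `predict` produces.

```python
import numpy as np


def _inputs(element):
    # Keras-style datasets usually yield (inputs, targets) tuples; keep only the inputs
    return element[0] if isinstance(element, tuple) else element


def train_autoencoder_batchwise(submodel, autoencoder, train_ds, valid_ds,
                                epochs=1, batch_size=None):
    """Train `autoencoder` on `submodel` outputs one batch at a time.

    Only one batch of latent data is held in memory, instead of the whole
    array that submodel.predict(train_ds) would return.
    """
    if batch_size is not None:
        # fetch() emits batches of a single example; regroup them into larger batches
        train_ds = train_ds.unbatch().batch(batch_size)
        valid_ds = valid_ds.unbatch().batch(batch_size)

    val_history = []
    for epoch in range(epochs):
        for element in train_ds:
            latent = submodel.predict_on_batch(_inputs(element))  # NumPy array, this batch only
            autoencoder.train_on_batch(latent, latent)             # input == target (MSE loss)

        val_losses = []
        for element in valid_ds:
            latent = submodel.predict_on_batch(_inputs(element))
            # Loss-only compile, so test_on_batch returns a single scalar
            val_losses.append(autoencoder.test_on_batch(latent, latent))
        val_history.append(float(np.mean(val_losses)))
        print(f"epoch {epoch + 1}: val_loss = {val_history[-1]:.6f}")
    return val_history
```

Usage would be along the lines of `train_autoencoder_batchwise(submodel, AE, train_ds, valid_ds, epochs=epochs, batch_size=32)`, where 32 is an arbitrary choice: larger batches are faster but need more memory per step. Callbacks such as `callbacks[0]` are not invoked in this manual loop.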

For context, the data generator that creates train_ds and valid_ds uses the following functions to encode and (subsequently) fetch the data:

    def fetch(self):
        dataset = tf.data.TFRecordDataset(self.tfr).map(self._decode,
                                                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
        if self.mode == "train":
            dataset = dataset.shuffle(2000, reshuffle_each_iteration=False) 
            # batch size 1; the batched variant was: dataset.batch(self.batch_size, drop_remainder=True)
            train_dataset = dataset.batch(1, drop_remainder=True)
            train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
            return train_dataset
        
        if self.mode == "valid":
            valid_dataset = dataset.batch(1, drop_remainder=False)
            valid_dataset = valid_dataset.prefetch(tf.data.experimental.AUTOTUNE) 
            return valid_dataset
        
        else:
            dataset = dataset.batch(1, drop_remainder=True)
            dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
            return dataset

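    def _encode(self, mode):
        # note: the `mode` argument is unused; the method reads self.mode
        with tf.io.TFRecordWriter(self.tfr) as writer:  # flushes and closes the file on exit
            if self.mode != "test":
                # glob.glob returns files in arbitrary order; sorting pairs the i-th mixed file
                # with the i-th target file, assuming matching names sort in the same order
                mix_filenames = sorted(glob.glob(os.path.join(self.wav_dir, "*mixed*.wav")))
                target_filenames = sorted(glob.glob(os.path.join(self.wav_dir, "*target*.wav")))
                sys.stdout.flush()

                for mix_filename, target_filename in tqdm(zip(mix_filenames, target_filenames),
                                                          total=len(mix_filenames)):
                    # recent librosa versions accept sr only as a keyword; mono=False keeps both channels
                    mix, _ = librosa.load(mix_filename, sr=self.sample_rate, mono=False)
                    clean, _ = librosa.load(target_filename, sr=self.sample_rate, mono=False)

                    def write(a, b):
                        # one Example per window: left/right channels of the noisy and clean signals
                        example = tf.train.Example(
                            features=tf.train.Features(
                                feature={
                                    "noisy_left" : self._float_list_feature(mix[0, a:b]),
                                    "noisy_right": self._float_list_feature(mix[1, a:b]),
                                    "clean_left" : self._float_list_feature(clean[0, a:b]),
                                    "clean_right": self._float_list_feature(clean[1, a:b])}))
                        writer.write(example.SerializeToString())

                    now_length = mix.shape[-1]
                    target_length = int(self.duration * self.sample_rate)

                    if now_length < target_length:
                        continue  # file is shorter than one window

                    # non-overlapping windows (stride == window length); "+ 1" keeps the last full window
                    stride = int(self.duration * self.sample_rate)
                    for i in range(0, now_length - target_length + 1, stride):
                        write(i, i + target_length)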
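
The windowing in `_encode` can be checked in isolation with a small helper. `window_bounds` is hypothetical (not part of the original class) and mirrors the scheme above, where the stride equals the window length. With a 16 kHz sample rate and a 1 s duration, a 3 s file yields three windows; an exclusive upper bound of `now_length - target_length` without the `+ 1` would yield only two, and a file exactly one window long would yield none.

```python
def window_bounds(now_length, duration, sample_rate):
    """(start, end) sample indices of the non-overlapping windows written per file.

    Mirrors _encode: window length == stride == int(duration * sample_rate).
    """
    target_length = int(duration * sample_rate)
    if now_length < target_length:
        return []  # file shorter than one window: skipped
    stride = target_length
    # "+ 1" makes the bound inclusive, so a window ending exactly at now_length is kept
    return [(i, i + target_length)
            for i in range(0, now_length - target_length + 1, stride)]


print(window_bounds(48_000, 1.0, 16_000))  # [(0, 16000), (16000, 32000), (32000, 48000)]
```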

I tried iterating over train_ds and calling the model directly (not predict) on each element, the idea being to train the autoencoder on each individual prediction for one epoch. However, I noticed that model(single_element) differs very slightly (on the order of 10⁻⁵) from the output of model.predict. The reason might be a normalization layer in the model (I checked that no dropout is active). Because I do not want to risk missing such model details (I got the code from a colleague), I'd prefer to avoid this approach and instead predict on batches, but I have not been able to get that to work.
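A quick way to check whether that gap comes from layer behavior or only from floating-point rounding is to compare both paths batch by batch. Calling a Keras model outside `fit` should already default to inference mode, but passing `training=False` explicitly also covers layers (for example a custom normalization layer) whose default is training behavior. The helper `max_call_vs_predict_diff` and the toy model are hypothetical stand-ins and assume a single-output model.

```python
import numpy as np
import tensorflow as tf


def max_call_vs_predict_diff(model, ds):
    """Largest absolute difference between model(x, training=False) and predict_on_batch(x)."""
    worst = 0.0
    for element in ds:
        x = element[0] if isinstance(element, tuple) else element  # drop targets if present
        called = np.asarray(model(x, training=False))  # eager call, inference mode forced
        predicted = model.predict_on_batch(x)           # compiled predict step
        worst = max(worst, float(np.max(np.abs(called - predicted))))
    return worst


# Hypothetical stand-in for the real submodel: single output behind a normalization layer
inp = tf.keras.Input(shape=(8,))
out = tf.keras.layers.Dense(4)(tf.keras.layers.BatchNormalization()(inp))
toy = tf.keras.Model(inp, out)
ds = tf.data.Dataset.from_tensor_slices(np.random.rand(16, 8).astype("float32")).batch(4)
print(max_call_vs_predict_diff(toy, ds))
```

If the real submodel still shows differences well above rounding level with `training=False`, some layer behaves differently in the two paths; otherwise the outputs can be treated as equal with a tolerance such as `np.allclose(a, b, atol=1e-5)`.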


Solution

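  • I resolved the issue by running the RAM-intensive task on the CPU instead of the GPU. This is considerably slower, but in my case not prohibitively so, and I can spare the time. I switched to the CPU using

    with tf.device('/cpu:0'):  # places the ops created in this block on the CPU (host RAM)
        train = submodel.predict(train_ds)
        valid = submodel.predict(valid_ds)
        AE.fit(x=train, y=train, validation_data=(valid, valid), epochs=epochs, verbose=0, callbacks=[callbacks[0]])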
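
If the CPU route ever becomes too slow, an alternative sketch keeps the `AE.fit(...)` call (including its callbacks) but feeds it a dataset that predicts the latent data one batch at a time with `predict_on_batch`, so the full array is never materialized. `latent_dataset` is a hypothetical helper; it assumes a single float32 submodel output, `(inputs, targets)` tuples in the source dataset, and, as above, that the autoencoder being trained is not on the path to the submodel's output. Because each epoch re-runs the submodel, this trades extra compute for lower memory use.

```python
import tensorflow as tf


def latent_dataset(submodel, ds):
    """Dataset of (latent, latent) pairs, predicted one batch at a time."""
    # Assumes a single float32 output; take its shape from the symbolic output tensor
    out_shape = tuple(tf.nest.flatten(submodel.output)[0].shape[1:])
    spec = tf.TensorSpec(shape=(None,) + out_shape, dtype=tf.float32)

    def gen():
        # tf.data calls gen() again for every pass, so each epoch re-predicts the data
        for element in ds:
            x = element[0] if isinstance(element, tuple) else element  # drop targets if present
            latent = tf.nest.flatten(submodel.predict_on_batch(x))[0]
            yield latent, latent  # autoencoder: target == input

    return tf.data.Dataset.from_generator(gen, output_signature=(spec, spec))
```

It would then be used as `AE.fit(latent_dataset(submodel, train_ds), validation_data=latent_dataset(submodel, valid_ds), epochs=epochs, verbose=0, callbacks=[callbacks[0]])`.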