Tags: python, tensorflow, machine-learning, deep-learning, prefetch

Does using tf.data.Dataset prefetch make my model overfit?


I'm trying to train a simple LRCN model on a sequential image dataset in TensorFlow 2.5.0. At first, training went fine: both training and validation accuracy rose to 0.9x within the first 5 epochs, and both training and validation loss kept decreasing throughout training.

Then I tried to optimize the data pipeline using prefetch(). The dataset consists of sequential images (.png) whose file names and other information are stored in a .csv file, so I wrote the data generator below:

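import os

import numpy as np
from tensorflow.keras.utils import to_categorical

# filepath (folder containing the sample files) and info_list (accepted
# info1 values) are defined elsewhere in my script.

def setData(data):
    X, y = [], []

    # `data` is one CSV row turned into a one-column DataFrame,
    # so each field is read by its column name with .loc.
    name = data.loc['fileName'].values.tolist()[0]
    info1 = data.loc['info1'].values.tolist()[0]
    info2 = data.loc['info2'].values.tolist()[0]
    info3 = data.loc['info3'].values.tolist()[0]

    path = filepath + name
    if not os.path.isfile(path):
        # Raise instead of returning None, which tf.py_function cannot convert
        raise FileNotFoundError('No file for img: ' + path)

    try:
        img = np.load(path)
    except (OSError, ValueError) as e:
        # Stop with the file name instead of continuing with `img` undefined
        raise RuntimeError('Could not load ' + name) from e

    if info1 in info_list:
        X.append(img)
        # info2 == 'True' -> class 0, anything else -> class 1
        y.append(0 if info2 == 'True' else 1)

    # One sample = 3 frames of 128x128 pixels, 1 channel.
    # float32 matches Tout=[tf.float32, tf.float32] in the map() call.
    X = np.reshape(np.array(X), (3, 128, 128, 1)).astype(np.float32)
    y = to_categorical(y, num_classes=2)
    y = np.reshape(y, (2,)).astype(np.float32)

    return X, y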
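To make the shape and label conventions of setData() concrete, here is a NumPy-only sketch. make_sample() is a hypothetical stand-in, assuming each loaded file holds exactly 3 x 128 x 128 values; it reproduces the reshape to (3, 128, 128, 1) and the info2 to one-hot label mapping, with np.eye used in place of to_categorical.

```python
import numpy as np

def make_sample(frames: np.ndarray, info2: str):
    """Hypothetical stand-in for setData(): `frames` plays the role of the
    array loaded from one file (3 frames of 128x128, single channel assumed)."""
    # Same reshape as setData(): 3 time steps, 128x128 pixels, 1 channel.
    X = np.reshape(frames, (3, 128, 128, 1)).astype(np.float32)
    # Same label rule: 'True' -> class 0, anything else -> class 1,
    # one-hot encoded into a length-2 vector (what to_categorical produces).
    label = 0 if info2 == 'True' else 1
    y = np.eye(2, dtype=np.float32)[label]
    return X, y

frames = np.random.rand(3, 128, 128).astype(np.float32)
X, y = make_sample(frames, 'True')
print(X.shape, y)  # (3, 128, 128, 1) [1. 0.]
```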

Then I added a loading function that wraps the data generator:

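import pandas as pd

def generatedata(i):
    # Called through tf.py_function, so `i` is an eager tensor and
    # .numpy() gives the row index into traindata (a global DataFrame).
    i = i.numpy()
    X_batch, y_batch = setData(pd.DataFrame(traindata.iloc[i]))

    return X_batch, y_batch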
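The i.numpy() call works because tf.py_function runs its Python function eagerly, passing real tensors rather than symbolic ones. A minimal sketch of that pattern, assuming a small hypothetical stand-in for traindata with an integer label column:

```python
import numpy as np
import pandas as pd
import tensorflow as tf

# Hypothetical stand-in for `traindata`: one row per sample.
traindata = pd.DataFrame({'fileName': ['a.npy', 'b.npy', 'c.npy'],
                          'label': [0, 1, 0]})

def lookup(i):
    # Inside tf.py_function, `i` is an eager tf.Tensor, so .numpy()
    # returns a NumPy integer that DataFrame.iloc accepts.
    row = traindata.iloc[int(i.numpy())]
    return np.float32(row['label'])

ds = tf.data.Dataset.range(len(traindata))
ds = ds.map(lambda i: tf.py_function(func=lookup, inp=[i], Tout=tf.float32))
print([float(x) for x in ds])  # [0.0, 1.0, 0.0]
```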

Finally, I built the dataset pipeline using map():

import tensorflow as tf

# Dataset.range yields int64 indices, so every row of traindata is
# reachable (tf.uint8 can only represent 0-255).
trainDataset = tf.data.Dataset.range(len(traindata))
trainDataset = trainDataset.map(
    lambda i: tf.py_function(func=generatedata,
                             inp=[i],
                             Tout=[tf.float32, tf.float32]),
    num_parallel_calls=tf.data.AUTOTUNE)
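The map() call shown parallelizes loading, but prefetch() is a separate transformation that it does not invoke. A minimal sketch of a complete pipeline follows, assuming a hypothetical loader load_sample() in place of generatedata() and an arbitrary NUM_SAMPLES of 300. It shuffles the indices, restores the static shapes that tf.py_function drops, then batches and prefetches. Dataset.range yields int64 indices, so indices above 255 (the limit of tf.uint8) are representable.

```python
import numpy as np
import tensorflow as tf

NUM_SAMPLES = 300  # hypothetical dataset size, deliberately above 255

def load_sample(i):
    # Hypothetical loader standing in for generatedata(): returns a
    # (3, 128, 128, 1) clip and a one-hot label, both float32.
    i = int(i.numpy())
    X = np.full((3, 128, 128, 1), i, dtype=np.float32)
    y = np.eye(2, dtype=np.float32)[i % 2]
    return X, y

def tf_load(i):
    X, y = tf.py_function(func=load_sample, inp=[i],
                          Tout=[tf.float32, tf.float32])
    # py_function outputs have unknown static shape; set it so that
    # downstream Keras layers see a fully defined input shape.
    X.set_shape((3, 128, 128, 1))
    y.set_shape((2,))
    return X, y

ds = tf.data.Dataset.range(NUM_SAMPLES)            # int64 indices
ds = ds.shuffle(NUM_SAMPLES)                       # reshuffled every epoch
ds = ds.map(tf_load, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.batch(32)
ds = ds.prefetch(tf.data.AUTOTUNE)                 # overlap loading and training

xb, yb = next(iter(ds))
print(tuple(xb.shape), tuple(yb.shape))  # (32, 3, 128, 128, 1) (32, 2)
```

Shuffling the cheap integer indices before map() keeps the shuffle buffer small, since only the indices are buffered rather than the decoded clips.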

After applying these changes, training accuracy reaches 0.9x in the first epoch and 1.0 within the first 3 to 5 epochs, while validation accuracy stays around 0.6x and validation loss keeps growing past x.x.

I believed that prefetch only changes the data pipeline and does not affect model performance, so I'm not sure what causes this overfitting-like result. I followed every prefetch step described in the TensorFlow documentation, but since I'm not very familiar with TensorFlow, I may have made some mistakes.
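The expectation that pipeline optimizations leave the data unchanged can be checked directly. This minimal sketch, using a toy range dataset rather than the real images, compares a plain map() against a parallel map() and a prefetched map(). By default, parallel map() keeps the input order, and prefetch() only buffers elements ahead of the consumer, so all three should yield identical sequences.

```python
import tensorflow as tf

def double(x):
    return x * 2

base = tf.data.Dataset.range(1000)

sequential = [int(x) for x in base.map(double)]
# Parallel map: output order matches input order unless determinism is disabled.
parallel = [int(x) for x in base.map(double, num_parallel_calls=tf.data.AUTOTUNE)]
# prefetch() buffers upcoming elements; it does not modify or reorder them.
prefetched = [int(x) for x in base.map(double).prefetch(tf.data.AUTOTUNE)]

print(sequential == parallel == prefetched)  # True
```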

Is there anything I missed? Any advice would be greatly appreciated. Thanks in advance.


Solution

  • It turns out that py_function() made the tf.Graph stack on top of previous results, which led the model to overfit.

    I modified the pipeline to take the generator function directly, and it now works as it should. I had checked the TensorFlow documentation without becoming fully aware of this behavior, but I found this on the TensorFlow GitHub page.

    If you have the same problem, review the documentation of the library functions you use carefully.
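The modified code is not shown in the answer. A minimal sketch of passing a generator function directly to tf.data.Dataset.from_generator might look like the following, assuming a hypothetical sample_generator() that yields (clip, one-hot label) pairs; in practice it would loop over the rows of traindata and load each file. output_signature declares the element shapes up front, so no set_shape() call is needed.

```python
import numpy as np
import tensorflow as tf

NUM_SAMPLES = 8  # hypothetical dataset size

def sample_generator():
    # Hypothetical generator: yields one (clip, one-hot label) pair per
    # sample as float32 arrays, standing in for a loop over traindata rows.
    for i in range(NUM_SAMPLES):
        X = np.zeros((3, 128, 128, 1), dtype=np.float32)
        y = np.eye(2, dtype=np.float32)[i % 2]
        yield X, y

ds = tf.data.Dataset.from_generator(
    sample_generator,  # pass the callable itself, not sample_generator()
    output_signature=(
        tf.TensorSpec(shape=(3, 128, 128, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(2,), dtype=tf.float32),
    ),
)
ds = ds.batch(4).prefetch(tf.data.AUTOTUNE)

for xb, yb in ds.take(1):
    print(tuple(xb.shape), tuple(yb.shape))  # (4, 3, 128, 128, 1) (4, 2)
```

Because from_generator calls the generator function again at the start of each epoch, every epoch iterates over all samples from the beginning.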