Tags: python, tensorflow2.0, tensorflow-datasets, resampling

tf.data.experimental.sample_from_datasets not sampling as expected


The documentation seems bare-bones, and the example in the standard TF tutorial does not highlight a behavior I am seeing. Let's say you have an imbalanced dataset of 1s and 0s (positives and negatives), and you want to sample with weights [0.5, 0.5] so that you see the positives more often than their natural frequency. You would do this:

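import numpy as np
import tensorflow as tf

pos_ds = tf.data.Dataset.from_tensor_slices(np.ones(shape=(16, 1)))    # 16 positives
neg_ds = tf.data.Dataset.from_tensor_slices(np.zeros(shape=(128, 1)))  # 128 negatives

# At each step, pick pos_ds or neg_ds with probability 0.5
resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])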

To see how the positives and negatives are distributed as I iterate through the dataset:

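xs = []
for x in resampled_ds:              # each element is a float64 tensor of shape (1,)
  xs.append(int(x.numpy()[0]))

xs = np.array(xs)
print(xs)

np.bincount(xs)  # counts of 0s and 1s (shown as the cell output in a notebook)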

I see this:

[1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 1 1 1 0 1 0 0 1 0 0 0 0 1 1 0 0 1
 0 1 0 1 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

array([128,  16])

There are 128 negatives and 16 positives. If I use this as my train_ds, the class counts are exactly what they would be with no resampling at all. Worse, the two classes are no longer evenly mixed across the steps of an epoch: the positives all arrive early, and the tail is negatives only. My guess is that the 0.5/0.5 sampling happens at the beginning and, once it "runs out" of 1s, it simply samples the remaining 0s. It clearly does not sample the 1s with replacement. I think the 1s and 0s are only balanced 0.5/0.5 if you stop as soon as all the 1s have been sampled.
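To check this guess, here is a small pure-Python model (no TensorFlow) of weighted sampling that skips a source once it is empty. It is my assumption about the observed behavior, not TensorFlow's actual implementation, and `sample_until_exhausted` is a hypothetical helper.

```python
import random

def sample_until_exhausted(sources, weights, seed=0):
    """Toy model: pick a source by weight; once a source is empty, skip it.

    Weights are effectively renormalised over the sources that still have
    elements, and sampling ends only when every source is exhausted.
    """
    rng = random.Random(seed)
    iters = [iter(s) for s in sources]
    active = list(range(len(sources)))
    out = []
    while active:
        i = rng.choices(active, weights=[weights[j] for j in active])[0]
        try:
            out.append(next(iters[i]))
        except StopIteration:
            active.remove(i)  # this source is exhausted; never pick it again
    return out

xs = sample_until_exhausted([[1] * 16, [0] * 128], weights=[0.5, 0.5])
print(xs.count(0), xs.count(1))  # 128 16: every element is emitted exactly once
last_pos = max(i for i, x in enumerate(xs) if x == 1)
print(last_pos)  # position of the last positive; far below 144 because positives are drawn ~50% of the time until they run out
```

The model reproduces both symptoms: the totals are just the dataset sizes, and the positives are front-loaded. Newer TensorFlow releases expose this choice through a `stop_on_empty_dataset` argument on `sample_from_datasets`; its default (`False`) skips empty datasets as seen here, while `True` stops at the first empty one. Check the docs for your version, since TF 2.0 may not have it.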

This appears to be the intended behavior, but it isn't the only sensible one. I want to sample the positives multiple times within one epoch (i.e. sample with replacement), with approximately equal numbers of positives and negatives. Is there an option for this in the API? Also, I use data augmentation, so the repeated positives are not actually identical when trained on.


Solution

  • Actually, I found that the solution is right there in the TF tutorial imbalanced_data.ipynb (I completely missed this part in my own notebook):

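    # repeat() makes each source infinite, so the positives never run out.
    # BUFFER_SIZE: shuffle-buffer size (a constant you define, as the tutorial does)
    pos_ds = pos_ds.shuffle(BUFFER_SIZE).repeat()
    neg_ds = neg_ds.shuffle(BUFFER_SIZE).repeat()
    
    # resampled_ds is now infinite: bound it with .take(n) or steps_per_epoch
    resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])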

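    The tutorial also suggests a heuristic for setting resampled_steps_per_epoch, which is needed because the resampled dataset is now infinite and model.fit has to be told where an epoch ends.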

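    However, shuffle() followed by repeat() is still not equivalent to true sampling with replacement for the minority class: each positive appears exactly once per pass before any positive is repeated. Calling repeat() before shuffle() may come closer, since elements from consecutive passes can then mix inside the shuffle buffer. I may update this after trying both ways.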
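The difference between the two orderings and true sampling with replacement can be sketched without TensorFlow. The generators below are my own toy models, not tf.data internals: `shuffle_then_repeat` assumes `BUFFER_SIZE` is at least the number of positives (so each pass is a full permutation), `shuffle_buffer` imitates a shuffle buffer by emitting a random buffered element and refilling its slot, and the positive ids are hypothetical stand-ins.

```python
import itertools
import random

def shuffle_then_repeat(items, rng):
    """Like ds.shuffle(B).repeat() with B >= len(items): a fresh permutation per pass."""
    while True:
        perm = list(items)
        rng.shuffle(perm)
        yield from perm

def shuffle_buffer(stream, buffer_size, rng):
    """Toy shuffle buffer: emit a random buffered element, refill its slot from the stream."""
    it = iter(stream)
    buf = list(itertools.islice(it, buffer_size))
    while buf:
        j = rng.randrange(len(buf))
        try:
            item, buf[j] = buf[j], next(it)
        except StopIteration:
            item = buf.pop(j)  # input exhausted: drain the buffer
        yield item

def repeat_then_shuffle(items, buffer_size, rng):
    """Like ds.repeat().shuffle(buffer_size): consecutive passes mix inside the buffer."""
    return shuffle_buffer(itertools.cycle(items), buffer_size, rng)

def with_replacement(items, rng):
    """True sampling with replacement: every draw is independent and uniform."""
    while True:
        yield rng.choice(items)

def distinct_in_first_pass(stream, n):
    """Number of different items among the first n draws."""
    return len(set(itertools.islice(stream, n)))

positives = list(range(16))  # hypothetical ids for the 16 positive examples
rng = random.Random(0)
print(distinct_in_first_pass(shuffle_then_repeat(positives, rng), 16))  # always 16
print(distinct_in_first_pass(repeat_then_shuffle(positives, 16, rng), 16))
print(distinct_in_first_pass(with_replacement(positives, rng), 16))  # usually fewer than 16
```

All three keep the minority source from running out, which is what fixes the 128/16 counts; they differ only in how evenly each positive is reused within a stretch of training.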
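As I recall it, the tutorial's heuristic is `resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)`: a balanced batch is about half negatives, so this many batches show roughly as many negatives as the dataset holds. Below is a stdlib sketch with this question's 128 negatives; `BATCH_SIZE = 32` is a hypothetical value, not from the post.

```python
import math

def resampled_steps_per_epoch(neg, batch_size):
    """Batches needed so that, at a 50/50 mix, about `neg` negatives are seen."""
    # Each balanced batch holds ~batch_size / 2 negatives, so we need
    # neg / (batch_size / 2) = 2 * neg / batch_size batches, rounded up.
    return math.ceil(2.0 * neg / batch_size)

BATCH_SIZE = 32  # hypothetical batch size
print(resampled_steps_per_epoch(128, BATCH_SIZE))  # 8
```

With these numbers an epoch draws 8 * 32 = 256 examples, about 128 of them positive, so each of the 16 positives is reused roughly 8 times per epoch; with data augmentation those repeats are not identical.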