Tags: tensorflow, filter, tf.data.dataset

How to efficiently filter a specific number of entries per label and concatenate them into a single tf.data.Dataset?


I have a huge TFRecord file with more than 4M entries. It is a very unbalanced dataset: some labels have many more entries than others. I want to keep a limited number of entries for some of these labels in order to get a balanced dataset. Below is my attempt, but it takes more than 24 hours to filter 1k entries from each label (33 different labels).

import tensorflow as tf

tf.compat.as_str(
    bytes_or_text='str', encoding='utf-8'
)
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except:
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False 
dataset = tf.data.TFRecordDataset('/test.tfrecord')
dataset = dataset.with_options(ignore_order)
features, feature_lists = detect_schema(dataset)

# Decode the TFRecord serialized data
def decode_data(serialized):
    X, y = tf.io.parse_single_sequence_example(
        serialized,
        context_features=features,
        sequence_features=feature_lists)
    return X['title'], y['subject']

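# parse_single_sequence_example is a TensorFlow op, so decode_data can be mapped directly;
# wrapping it in tf.py_function forces it to run as Python code, which limits parallelism
dataset = dataset.map(decode_data, num_parallel_calls=tf.data.AUTOTUNE)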

#Filtering and concatenating the samples

def balanced_dataset(dataset, labels_list, sample_size=1000):
    datasets_list = []
    for label in labels_list:
        # Keep only the records whose labels include `label`
        label_dataset = dataset.filter(lambda x, y: tf.reduce_any(tf.equal(y, tf.constant(label, dtype=tf.int64))))
        # Append a limited sample of this label
        datasets_list.append(label_dataset.take(sample_size))
    concat_dataset = datasets_list[0]
    #concatenating the datasets
    for dset in datasets_list[1:]:
        concat_dataset = concat_dataset.concatenate(dset)
    return concat_dataset  

balanced_data = balanced_dataset(dataset, labels_list=list(decod_dic.values()), sample_size=1000)
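A likely reason the attempt is slow: `concatenate` reads each of the 33 `filter(...).take(sample_size)` datasets in turn, and each one starts again from the first record of the same `TFRecordDataset`. The file is therefore scanned (and every record decoded) up to once per label, and a rare label can require scanning most of the 4M records. The plain-Python cost model below (hypothetical helper names, no TensorFlow) counts the records read for the 100/200/15 split used in the solution's example, with 15 samples per label, and compares it with a single pass that keeps per-label counts and stops as soon as every label is full.

```python
from typing import Hashable, List, Sequence


def reads_filter_per_label(labels: Sequence[Hashable], wanted: List[Hashable], k: int) -> int:
    """Records read when each label gets its own filter(...).take(k) pass and the
    passes are concatenated: every pass restarts from the first record."""
    total = 0
    for target in wanted:
        found = 0
        for label in labels:
            total += 1
            if label == target:
                found += 1
                if found == k:
                    break  # take(k) is satisfied; move on to the next label
    return total


def reads_single_pass(labels: Sequence[Hashable], wanted: List[Hashable], k: int) -> int:
    """Records read by one pass that counts records per label and stops
    as soon as every wanted label has k records."""
    counts = {label: 0 for label in wanted}
    remaining = len(wanted)
    for i, label in enumerate(labels, start=1):
        if label in counts and counts[label] < k:
            counts[label] += 1
            if counts[label] == k:
                remaining -= 1
                if remaining == 0:
                    return i  # every label is full: stop reading
    return len(labels)  # some label never reached k: the whole input was read


labels = [0] * 100 + [1] * 200 + [2] * 15
print(reads_filter_per_label(labels, [0, 1, 2], 15))  # 445
print(reads_single_pass(labels, [0, 1, 2], 15))       # 315
```

On this toy split the per-label design reads 445 records against 315 for a single pass; with 33 labels and rare classes spread through a 4M-record file, the gap could be much larger.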

Solution

  • One way to solve this is to use the group_by_window method, with window_size set to the sample size of each class (1k in your case).

    ds = ds.group_by_window(
        # Use the label as the grouping key
        key_func=lambda _, l: l,
        # Turn each window into a single batch of sample_size elements
        reduce_func=lambda _, window: window.batch(sample_size),
        # Each window collects sample_size elements of one class
        window_size=sample_size)
    

    This forms single-class batches of size sample_size. The problem is that a class with more than sample_size entries produces several batches (and its last batch may be smaller), whereas you only need one batch per class.

    To solve this, we attach a per-class counter to each batch and keep only the batches with count == 0, which gives the first batch of every class.

    Let's define an example:

    import numpy as np
    import tensorflow as tf

    # int64 labels, because group_by_window's key_func must return a tf.int64 scalar
    labels = np.array(sum([[label]*repeat for label, repeat in zip([0, 1, 2], [100, 200, 15])], []), dtype=np.int64)
    features = np.arange(len(labels))
    np.unique(labels, return_counts=True)
    #(array([0, 1, 2]), array([100, 200,  15]))
    # 3 labels chosen for simplicity; their counts are shown above
    sample_size = 15  # sample 15 elements from each class
    

    We create a dataset from these inputs:

    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    

    In the window function above, we modify reduce_func to attach a counter, so each batch has 3 elements (X_batch, y_batch, label_counter):

    def reduce_func(key, window):
        # Read how many windows of this class were already seen, then increment it
        count = table.lookup(key)
        table.insert(key, count + 1)
        # Batch the window and tag the batch with the counter value
        return window.batch(sample_size).map(lambda a, b: (a, b, count))
    # Group by window
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.group_by_window(
        # Use the label as the grouping key
        key_func=lambda _, l: l,
        # Batch each window and tag it with its counter
        reduce_func=reduce_func,
        # Each window collects sample_size elements of one class
        window_size=sample_size)
    

    reduce_func reads and updates each class's counter in a mutable lookup table. The table must be created before group_by_window is called, because reduce_func refers to it when the dataset is built. It is initialized as follows:

    n_classes = 3
    keys = tf.range(0,n_classes, dtype=tf.int64)
    vals = tf.zeros_like(keys, dtype=tf.int64)
    table = tf.lookup.experimental.MutableHashTable(key_dtype=tf.int64, 
                                                    value_dtype=tf.int64, 
                                                    default_value=-1)
    table.insert(keys, vals)
    

    Now we keep only the batches where count == 0 and drop the counter to get (X, y) batch pairs:

    ds = ds.filter(lambda x, y, count: count==0)
    ds = ds.map(lambda x, y, count: (x, y))
    

    Output:

    for x, y in ds:
        print(x.numpy(), y.numpy())
    [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
    [100 101 102 103 104 105 106 107 108 109 110 111 112 113 114] [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
    [300 301 302 303 304 305 306 307 308 309 310 311 312 313 314] [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
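Assembled into one script in execution order, the snippets above might look like the following sketch: the table is created first, then the pipeline is built and iterated once. Two parts are assumptions rather than part of the answer. First, it relies on a TensorFlow 2.x release where `group_by_window` is a `Dataset` method (older releases expose `tf.data.experimental.group_by_window` instead). Second, the final `.take(n_classes)` is an added, optional step that stops reading the input once every class has produced its batch, which could save time on a 4M-record file. Early stopping only helps when every class fills a whole window, because windows smaller than `sample_size` are emitted only when the input is exhausted.

```python
import numpy as np
import tensorflow as tf

n_classes = 3
sample_size = 15

# Toy data from the example: 100 x class 0, 200 x class 1, 15 x class 2.
labels = np.repeat(np.array([0, 1, 2], dtype=np.int64), [100, 200, 15])
features = np.arange(len(labels), dtype=np.int64)

# 1. Per-class counters, created before group_by_window builds the dataset.
keys = tf.range(0, n_classes, dtype=tf.int64)
table = tf.lookup.experimental.MutableHashTable(
    key_dtype=tf.int64, value_dtype=tf.int64, default_value=-1)
table.insert(keys, tf.zeros_like(keys))


def reduce_func(key, window):
    # Number of windows of this class seen so far; then increment it.
    count = table.lookup(key)
    table.insert(key, count + 1)
    # One batch per window, tagged with the counter value.
    return window.batch(sample_size).map(lambda x, y: (x, y, count))


ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.group_by_window(
    key_func=lambda _, label: label,  # must return a scalar tf.int64
    reduce_func=reduce_func,
    window_size=sample_size)
# 2. Keep only the first batch of each class and drop the counter.
ds = ds.filter(lambda x, y, count: count == 0)
ds = ds.map(lambda x, y, count: (x, y))
# 3. Optional (not in the answer): stop once every class has produced its batch.
ds = ds.take(n_classes)

# Iterate exactly once: the counters in `table` are not reset between passes.
batches = [(x.numpy().tolist(), y.numpy().tolist()) for x, y in ds]
for xs, ys in batches:
    print(xs, ys)
```

Because the counters live in `table` rather than in the dataset, a second iteration over `ds` (for example, a second epoch) would find no batch with count == 0; resetting the counters with `table.insert(keys, tf.zeros_like(keys))` before each pass is one way around that. For the TFRecord in the question, `key_func` also needs a scalar `tf.int64` label, so a `subject` feature stored as a sequence would first have to be reduced to a single label.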