
tf.data.Dataset: call a function once per batch (prefetch)


I am looking for help thinking through this.

I have a function (not a generator) that can return any number of samples. Let's say that all the data I want to train on (1000 samples) can't fit into memory at once, so I want to call this function 10 times, each time getting a smaller number of samples that does fit into memory.

This is a dummy example for simplicity.

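import numpy as np

def get_samples(num_samples: int, random_seed=0):
    # Seed NumPy's global generator so the same seed always gives the same samples
    np.random.seed(random_seed)
    x = np.random.randint(0, 100, num_samples)
    y = np.random.randint(0, 2, num_samples)
    # One (x, y) pair per row: shape (num_samples, 2)
    return np.array(list(zip(x, y)))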

Again, let's say get_samples(1000, 0) won't fit into memory.

So in theory I am looking for something like this:

batch_size = 100
total_num_samples = 1000
batches = []
for i in range(total_num_samples//batch_size):
    batches.append(get_samples(batch_size, i))

But this still loads everything into memory.

Again, this function is a dummy representation; the real one is already defined and is not a generator.
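A minimal sketch of one way around this, assuming only plain Python: wrap the existing function in a generator so each batch is created on demand instead of being appended to a list, which keeps only one batch in memory at a time. lazy_batches is a hypothetical helper name, and get_samples is the dummy function from the question repeated so the snippet runs on its own.

```python
import numpy as np

def get_samples(num_samples: int, random_seed=0):
    # Dummy stand-in for the real (non-generator) function
    np.random.seed(random_seed)
    x = np.random.randint(0, 100, num_samples)
    y = np.random.randint(0, 2, num_samples)
    return np.array(list(zip(x, y)))

def lazy_batches(total_num_samples, batch_size):
    # Hypothetical wrapper: calls get_samples once per batch, only when the
    # caller asks for the next batch, so earlier batches can be freed
    for i in range(total_num_samples // batch_size):
        yield get_samples(batch_size, i)

batch_size = 100
total_num_samples = 1000
for batch in lazy_batches(total_num_samples, batch_size):
    print(batch.shape)  # (100, 2), printed once per batch
```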

In tf.data terms, I was hoping that:

batch 0 of the dataset would equal the output of get_samples(100, 0)
batch 1 of the dataset would equal the output of get_samples(100, 1)
batch 2 of the dataset would equal the output of get_samples(100, 2)
...
batch 9 of the dataset would equal the output of get_samples(100, 9)
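One way to get exactly that mapping is to have the generator yield whole batches instead of single samples, so tf.data.Dataset.from_generator calls the expensive function once per batch. This is a sketch, not a confirmed answer: batch_generator is a hypothetical name, and the TensorSpec assumes get_samples returns an integer array of shape (batch_size, 2). The int() casts are there because from_generator hands args to the generator as NumPy values.

```python
import numpy as np
import tensorflow as tf

def get_samples(num_samples: int, random_seed=0):
    # Dummy stand-in for the expensive, non-generator function
    np.random.seed(random_seed)
    x = np.random.randint(0, 100, num_samples)
    y = np.random.randint(0, 2, num_samples)
    return np.array(list(zip(x, y)))

def batch_generator(total_num_samples, batch_size):
    # Hypothetical generator: each yielded element is an entire batch,
    # produced by a single get_samples call seeded with the batch index
    total_num_samples, batch_size = int(total_num_samples), int(batch_size)
    for i in range(total_num_samples // batch_size):
        yield get_samples(batch_size, i)

total_num_samples = 1000
batch_size = 100

dataset = tf.data.Dataset.from_generator(
    batch_generator,
    args=[total_num_samples, batch_size],
    # Values are cast to int64 by from_generator if NumPy returned another int type
    output_signature=tf.TensorSpec(shape=(None, 2), dtype=tf.int64),
)

for i, batch in enumerate(dataset):
    # Element i of the dataset corresponds to get_samples(100, i)
    print(i, batch.shape)
```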

I understand that I can build a tf.data.Dataset from a generator (and I think you can set up a generator per batch). But the function I have returns more than a single sample per call, and its setup is too expensive to run for every single sample.

I wanted to use tf.data.Dataset.prefetch() to run the get_samples function once for every batch. And, of course, it would call get_samples with the same parameters on every epoch.
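An alternative sketch that avoids generators altogether: start from tf.data.Dataset.range over batch indices, call the function once per index through tf.numpy_function, and chain .prefetch() so the next batch is prepared while the current one is consumed. Because each element is recomputed from its index, every epoch calls get_samples with the same seeds. BATCH_SIZE, NUM_BATCHES, load_batch and to_tensor are hypothetical names chosen for this example.

```python
import numpy as np
import tensorflow as tf

def get_samples(num_samples: int, random_seed=0):
    # Dummy stand-in for the real function
    np.random.seed(random_seed)
    x = np.random.randint(0, 100, num_samples)
    y = np.random.randint(0, 2, num_samples)
    return np.array(list(zip(x, y)))

BATCH_SIZE = 100
NUM_BATCHES = 10

def load_batch(batch_index):
    # Hypothetical helper: runs as plain Python; the batch index doubles as
    # the seed. astype() guarantees the dtype matches Tout below.
    return get_samples(BATCH_SIZE, int(batch_index)).astype(np.int64)

def to_tensor(batch_index):
    batch = tf.numpy_function(load_batch, [batch_index], tf.int64)
    batch.set_shape([BATCH_SIZE, 2])  # numpy_function loses static shape info
    return batch

dataset = (
    tf.data.Dataset.range(NUM_BATCHES)  # 0, 1, ..., 9: one index per batch
    .map(to_tensor)                     # one get_samples call per batch (sequential)
    .prefetch(tf.data.AUTOTUNE)         # prepare upcoming batches in the background
)

for epoch in range(2):
    for batch in dataset:
        pass  # a training step on `batch` would go here
```

The map is deliberately left sequential: get_samples reseeds NumPy's global random state, so running several calls in parallel (num_parallel_calls) could interleave seeding and drawing across threads and break reproducibility.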

Sorry if the explanation is convoluted; I'm trying my best to describe the problem.

Anyone have any ideas?


Solution

  • This is what I came up with:

    import numpy as np
    import tensorflow as tf

    def simple_static_synthesizer(batch_size, seed=1, verbose=True):
        """Create one batch of synthetic data: 3 integer features in [0, 100) and a 0/1 label per sample."""
        if verbose:
            print(f"Creating Synthetic Data with seed {seed}")
        rng = np.random.default_rng(seed)
        all_x = []
        all_y = []
        for i in range(batch_size):
            x = np.concatenate((rng.integers(0, 100, 1, dtype=int),
                                rng.integers(0, 100, 1, dtype=int),
                                rng.integers(0, 100, 1, dtype=int)))
            y = rng.integers(0, 2, 1, dtype=int)
            all_x.append(x)
            all_y.append(y)
        return all_x, all_y
    
    def my_generator(total_size, batch_size, seed=0, verbose=True):
        """Yield single (x, y) samples, calling the synthesizer once per batch."""
        for i in range(total_size):
            if i % batch_size == 0:
                # Start of a new batch: regenerate the data with the next seed
                x, y = simple_static_synthesizer(batch_size, seed, verbose)
                seed += 1
            yield x[i % batch_size], y[i % batch_size]

    my_gen = my_generator(10, 2, seed=1)

    # See values
    for x, y in my_gen:
        print(x, y)

    # Call again: this gives the same output as above
    my_gen = my_generator(10, 2, seed=1)
    for x, y in my_gen:
        print(x, y)
    
    # Dataset with small batches to check that it behaves correctly
    total_samples = 10
    batch_size = 2
    seed = 5
    
    dataset = tf.data.Dataset.from_generator(
        my_generator,
        args=[total_samples,batch_size,seed],
        output_signature=(
            tf.TensorSpec(shape=(3,), dtype=tf.uint8),
            tf.TensorSpec(shape=(1,), dtype=tf.uint8),
        )
    )
    for i, (x, y) in enumerate(dataset):
        print(x.numpy(), y.numpy())
        if i == 4:
            break  # 5 samples with batch_size=2 span the first 3 synthesizer calls
    

    Wish we could have notebook answers!
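The answer's dataset still yields one sample at a time. Below is a hedged sketch of regrouping it with .batch(batch_size) (plus .prefetch()) so that batch k lines up with the k-th synthesizer call. The functions are condensed copies of the ones in the answer (verbose flag dropped), the output_signature uses tf.int64 to match what rng.integers returns, and the int() casts are there because from_generator passes args to the generator as NumPy values.

```python
import numpy as np
import tensorflow as tf

def simple_static_synthesizer(batch_size, seed=1):
    # Condensed copy: per sample, 3 features in [0, 100) then 1 label in {0, 1}
    rng = np.random.default_rng(seed)
    all_x, all_y = [], []
    for _ in range(batch_size):
        all_x.append(np.concatenate([rng.integers(0, 100, 1) for _ in range(3)]))
        all_y.append(rng.integers(0, 2, 1))
    return all_x, all_y

def my_generator(total_size, batch_size, seed=0):
    total_size, batch_size, seed = int(total_size), int(batch_size), int(seed)
    for i in range(total_size):
        if i % batch_size == 0:
            # One synthesizer call per batch; the seed advances by one each time
            x, y = simple_static_synthesizer(batch_size, seed)
            seed += 1
        yield x[i % batch_size], y[i % batch_size]

total_samples, batch_size, seed = 10, 2, 5

dataset = (
    tf.data.Dataset.from_generator(
        my_generator,
        args=[total_samples, batch_size, seed],
        output_signature=(
            tf.TensorSpec(shape=(3,), dtype=tf.int64),
            tf.TensorSpec(shape=(1,), dtype=tf.int64),
        ),
    )
    .batch(batch_size)           # regroup samples so batch k == synthesizer call k
    .prefetch(tf.data.AUTOTUNE)  # run the generator ahead of the consumer
)

for x_batch, y_batch in dataset:
    print(x_batch.shape, y_batch.shape)  # (2, 3) (2, 1)
```

This grouping only lines up when total_samples is a multiple of batch_size; otherwise the last synthesizer call is cut short and the final batch is smaller.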