I am trying to create a function that returns a batch of data (list) every time I call it.
It should be able to repeat for any number of training steps and restart from the beginning after having iterated over the whole dataset (after each epoch).
```python
def generate_batch(X, batch_size):
    for i in range(0, len(X), batch_size):
        batch = X[i:i+batch_size]
        yield batch

X = [
    [1, 2],
    [4, 0],
    [5, 1],
    [9, 99],
    [9, 1],
    [1, 1]]

num_training_steps = 3  # three steps, matching the output below
for step in range(num_training_steps):
    x_batch = generate_batch(X, batch_size=2)
    print(list(x_batch))
```
When I print the output, I get the whole dataset (all of X, already split into batches) on every step instead of a single batch:

```
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
```
What is the problem? Is this the right way to use `yield`?
First of all, if you want to restart from the beginning once the data is exhausted, you need to wrap the body of the generator function in an infinite loop, like this:
```python
def generate_batch(X, batch_size):
    while True:
        for i in range(0, len(X), batch_size):
            batch = X[i:i+batch_size]
            yield batch
```
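A quick way to check the wrap-around behavior is to pull a few batches by hand. This is a sketch using the data from the question with `batch_size=4` (an arbitrary choice for illustration) so that the length of `X` is not a multiple of the batch size:

```python
def generate_batch(X, batch_size):
    # Loop forever: when the inner for loop finishes one pass (epoch),
    # the while loop starts the next pass from index 0.
    while True:
        for i in range(0, len(X), batch_size):
            yield X[i:i + batch_size]

X = [[1, 2], [4, 0], [5, 1], [9, 99], [9, 1], [1, 1]]
gen = generate_batch(X, batch_size=4)

first = next(gen)   # first 4 rows
second = next(gen)  # remaining 2 rows: the last batch of an epoch can be shorter
third = next(gen)   # second epoch starts again at row 0
print(first, second, third, sep="\n")
```

Note that the last batch of each epoch is shorter when `len(X)` is not divisible by `batch_size`, and that with an empty `X` the `while True` loop never yields, so `next(gen)` would hang.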
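Then, when you do:

```python
x_batch = generate_batch(X, batch_size=2)
```

`x_batch` is a generator, not a batch. You need to iterate over it or call `next()` on it to get the data one batch at a time. `list(x_batch)` iterates over the whole generator and collects every batch into a single list, which is not what you want (and with the infinite-loop version, `list()` would never return). Also, calling `generate_batch()` inside the training loop creates a brand-new generator on every step, so each step starts again from the first batch; create the generator once, before the loop.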
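The difference between draining a generator, recreating it, and reusing it can be seen side by side. This sketch uses the finite (single-pass) generator from the question; the variable names are just for illustration:

```python
def generate_batch(X, batch_size):
    # Finite version from the question: one pass over X, then exhausted.
    for i in range(0, len(X), batch_size):
        yield X[i:i + batch_size]

X = [[1, 2], [4, 0], [5, 1], [9, 99], [9, 1], [1, 1]]

# list() drains the generator, so you get every batch at once.
all_batches = list(generate_batch(X, batch_size=2))

# A new generator per step always starts at the first batch.
restarted = [next(generate_batch(X, batch_size=2)) for _ in range(3)]

# One generator reused across steps advances through the data.
gen = generate_batch(X, batch_size=2)
advancing = [next(gen) for _ in range(3)]
leftover = next(gen, None)  # the finite generator is now exhausted

print(all_batches, restarted, advancing, leftover, sep="\n")
```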
What you want is:

```python
gen = generate_batch(X, batch_size=2)
for step in range(num_training_steps):
    x_batch = next(gen)
    print(x_batch)
```
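If you prefer a plain `for` loop over explicit `next()` calls, `itertools.islice` can cap an infinite generator at a fixed number of items. This is a sketch of that alternative; `num_training_steps = 5` is a hypothetical value chosen so the loop crosses an epoch boundary:

```python
from itertools import islice

def generate_batch(X, batch_size):
    # Infinite version: restarts from the beginning after each epoch.
    while True:
        for i in range(0, len(X), batch_size):
            yield X[i:i + batch_size]

X = [[1, 2], [4, 0], [5, 1], [9, 99], [9, 1], [1, 1]]
num_training_steps = 5  # hypothetical value for illustration

seen = []
# islice stops after num_training_steps batches, so iterating is safe.
for step, x_batch in enumerate(islice(generate_batch(X, batch_size=2), num_training_steps)):
    print(step, x_batch)
    seen.append(x_batch)
```

With 3 batches per epoch, step 3 is the first batch of the second epoch.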
Alternatively, if you need a callable function:

```python
gen = generate_batch(X, batch_size=2)
get_batch = gen.__next__  # bound method: each call returns the next batch
for step in range(num_training_steps):
    x_batch = get_batch()
    print(x_batch)
```
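Since the question asks for "a function that returns a batch every time I call it", the generator can also be hidden behind a small factory. This is a sketch, not the only way: `make_batch_fn` is a hypothetical helper name, and `functools.partial(next, gen)` is used as an equivalent of `gen.__next__` that avoids touching the dunder method directly:

```python
from functools import partial

def generate_batch(X, batch_size):
    # Infinite version: restarts from the beginning after each epoch.
    while True:
        for i in range(0, len(X), batch_size):
            yield X[i:i + batch_size]

def make_batch_fn(X, batch_size):
    """Hypothetical helper: return a zero-argument callable that returns the next batch."""
    gen = generate_batch(X, batch_size)
    # Calling the result is the same as calling next(gen).
    return partial(next, gen)

X = [[1, 2], [4, 0], [5, 1], [9, 99], [9, 1], [1, 1]]
get_batch = make_batch_fn(X, batch_size=2)
batches = [get_batch() for _ in range(4)]  # 4th call wraps to the first batch
print(batches)
```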
Also, you probably want to give the function a more descriptive name, e.g. `create_batch_generator()`, since calling it returns a generator rather than a batch.