For a tensorflow dataset iterator (tf.data.Iterator), what is the best way of skipping over the first X batches, but ONLY in the first iteration, and not for subsequent iterations when repeat() is specified)?
I tried the following but it didn't work:
import tensorflow as tf
import pandas as pd
from pyspark.sql import SparkSession
spark = SparkSession.builder.master('local[*]').config("spark.jars",'some/path/spark-tensorflow-connector_2.11-1.10.0.jar').getOrCreate()
df = pd.DataFrame({'x': range(10), 'y': [i*2 for i in range(10)]})
df = spark.createDataFrame(df)
df.write.format('tfrecords').option('recordType', 'Example').mode("overwrite").save('testdata')
def parse_function(proto):
feature_description = {
'x': tf.FixedLenFeature([], tf.int64),
'y': tf.FixedLenFeature([], tf.int64)
}
parsed_features = tf.parse_single_example(proto, feature_description)
x = parsed_features['x']
y = parsed_features['y']
return {'x': x, 'y': y}
def load_data(filename_pattern, parse_function, batch_size=200, skip_batches=0):
files = tf.data.Dataset.list_files(file_pattern=filename_pattern, shuffle=False)
dataset = tf.data.TFRecordDataset(files)
dataset = dataset.repeat()
dataset = dataset.map(parse_function)
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
dataset = dataset.prefetch(2)
# Create an iterator
iterator = dataset.make_one_shot_iterator()
data = iterator.get_next()
with tf.Session() as sess:
for i in range(skip_batches):
sess.run(data)
return data
# skip first three batches
data = load_data('testdata/part-*', parse_function, batch_size=2, skip_batches=3)
sess = tf.Session()
for i in range(3):
print(sess.run(data))
Expected/desired:
{'y': array([12, 14]), 'x': array([6, 7])}
{'y': array([16, 18]), 'x': array([8, 9])}
{'y': array([0, 2]), 'x': array([0, 1])}
Actual:
{'y': array([0, 2]), 'x': array([0, 1])}
{'y': array([4, 6]), 'x': array([2, 3])}
{'y': array([8, 10]), 'x': array([4, 5])}
Thanks in advance for any help!
Instead of going through the tf.Dataset.iterator()
, why don't you skip the first X batches?
Let's say you want 10 Batches and each has 32 elements, means total 320 elements. So you can skip these using tf.Dataset.skip(320)
(skip) which gives you dataset with first 10 Batches skipped.