I have my data in multiple pickle files stored on disk. I want to use TensorFlow's tf.data.Dataset to load the data into a training pipeline. My code is:
def _parse_file(path):
    image, label = *load pickle file*
    return image, label

paths = glob.glob('*.pkl')
print(len(paths))
dataset = tf.data.Dataset.from_tensor_slices(paths)
dataset = dataset.map(_parse_file)
iterator = dataset.make_one_shot_iterator()
The problem is that I don't know how to implement the _parse_file function. The argument to this function, path, is a tensor, not a Python string. I tried:
def _parse_file(path):
    with tf.Session() as s:
        p = s.run(path)
        image, label = pickle.load(open(p, 'rb'))
    return image, label
and got this error message:
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'arg0' with dtype string
[[Node: arg0 = Placeholder[dtype=DT_STRING, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
After searching the Internet, I still have no idea how to do this. I would be grateful for any hints.
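I have solved this myself. I should use tf.py_func, as in this doc. It wraps an ordinary Python function as a TensorFlow op, so the function receives the actual value of path at run time instead of a symbolic tensor.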
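For anyone landing here later, this is a minimal sketch of how the tf.py_func approach could look, not necessarily the exact code I ended up with. It assumes the TF 1.x API used in the question and that each pickle file holds an (image, label) pair. The names load_pickle, _py_load and build_dataset are made up for illustration, and the float32/int64 dtypes are assumptions that you should adjust to your data.

```python
import pickle


def load_pickle(path):
    """Plain Python loader (hypothetical helper): receives a real value, not a tensor."""
    # tf.py_func hands string tensors to Python as bytes, so decode first.
    if isinstance(path, bytes):
        path = path.decode("utf-8")
    with open(path, "rb") as f:
        image, label = pickle.load(f)
    return image, label


def build_dataset(paths):
    """Build a tf.data pipeline that loads each pickle file via tf.py_func."""
    # Imported here so load_pickle stays usable without TensorFlow installed.
    import numpy as np
    import tensorflow as tf

    def _py_load(path):
        image, label = load_pickle(path)
        # Assumed dtypes; they must match the Tout list passed to tf.py_func.
        return (np.asarray(image, dtype=np.float32),
                np.asarray(label, dtype=np.int64))

    def _parse_file(path):
        # tf.py_func returns a list of tensors; return a tuple so that
        # Dataset.map treats it as two components rather than one tensor.
        image, label = tf.py_func(_py_load, [path], [tf.float32, tf.int64])
        return image, label

    dataset = tf.data.Dataset.from_tensor_slices(paths)
    return dataset.map(_parse_file)
```

With this in place, `build_dataset(glob.glob('*.pkl'))` would replace the from_tensor_slices and map calls in the question. Tensors coming out of tf.py_func have unknown static shape, so you may need to call set_shape on them inside _parse_file. In TF 2.x the top-level tf.py_func is gone; tf.py_function and tf.numpy_function are the counterparts.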