How to load pickle files with TensorFlow's tf.data API


My data is stored on disk in multiple pickle files. I want to use TensorFlow's tf.data.Dataset to load it into my training pipeline. My code looks like this:

def _parse_file(path):
    image, label = *load pickle file*
    return image, label
paths = glob.glob('*.pkl')
print(len(paths))
dataset = tf.data.Dataset.from_tensor_slices(paths)
dataset = dataset.map(_parse_file)
iterator = dataset.make_one_shot_iterator()

The problem is that I don't know how to implement the _parse_file function. Its argument, path, is a tensor rather than a Python string. I tried

def _parse_file(path):
    with tf.Session() as s:
        p = s.run(path)
        image, label = pickle.load(open(p, 'rb'))
    return image, label

and got this error message:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'arg0' with dtype string
     [[Node: arg0 = Placeholder[dtype=DT_STRING, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

After some searching online, I still have no idea how to do it. I would be grateful for any hints.


Solution

  • I solved this myself: the pickle-loading code has to be wrapped in tf.py_func, as described in this doc.
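
For anyone landing here, this is a minimal sketch of what the tf.py_func approach might look like. It is not the exact code from the linked doc. It assumes each pickle file holds an (image, label) tuple, that float32 and int64 are the right dtypes, and the names _load_pickle and make_dataset are made up for illustration. TF 2.x no longer has tf.py_func at the top level, so the sketch falls back to tf.numpy_function, which takes the same (func, inp, Tout) arguments.

```python
import pickle

import numpy as np
import tensorflow as tf

# tf.py_func is the TF 1.x name. TF 2.x removed it from the top-level
# namespace; tf.numpy_function accepts the same (func, inp, Tout) arguments.
py_func = getattr(tf, "py_func", None) or tf.numpy_function


def _load_pickle(path):
    """Ordinary Python: receives the path as bytes and returns NumPy arrays."""
    with open(path.decode("utf-8"), "rb") as f:
        image, label = pickle.load(f)  # assumed file layout: (image, label)
    # Assumed dtypes; they must match the Tout list given to py_func below.
    return np.asarray(image, dtype=np.float32), np.asarray(label, dtype=np.int64)


def _parse_file(path):
    # `path` is a string tensor here. py_func passes its run-time value to
    # _load_pickle, so no tf.Session is needed inside the map function.
    image, label = py_func(_load_pickle, [path], [tf.float32, tf.int64])
    return image, label


def make_dataset(paths):
    """Build a dataset that yields (image, label) pairs from pickle paths."""
    dataset = tf.data.Dataset.from_tensor_slices(paths)
    return dataset.map(_parse_file)
```

With paths = glob.glob('*.pkl') as in the question, make_dataset(paths) can be consumed with make_one_shot_iterator() in TF 1.x or iterated directly in TF 2.x. Because the loader runs as plain Python, TensorFlow cannot infer static output shapes, so calling image.set_shape(...) inside _parse_file may be needed before batching if the image shape is fixed and known.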