`py_function` causes `ragged_batch()` to stop working in ``

I'm working on an object detection project and use an input pipeline to load local data. Object detection needs not only the images but also their annotations, and the annotations make it even harder because their size varies (each image has a different number of boxes). I've tried several approaches, but none of them works. Below are my attempts; I've run out of ideas. Any help is much appreciated!

Parse XML

My local data is in Pascal VOC format. First, I used `.from_tensor_slices()` to get the annotation file paths, parsed each file to get the image path, and finally called `.ragged_batch()`. But inside `.map(load)`, the path string is automatically converted into a symbolic `Tensor("args_0:0", shape=(), dtype=string)`, which many libraries, such as the XML parser ElementTree, cannot use. So I used `tf.py_function()` to convert it back into a Python string, and then ran into what looks like a TensorFlow bug:

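import xml.etree.ElementTree as ET
import tensorflow as tf

# classNames (list of class names) and imageFolder (image directory) are defined elsewhere
annotation_files = [...]  # paths to the Pascal VOC .xml annotation files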

def load(annotationFile):
    # load annotation (boxes, class ids)
    def _loadAnnotation(annotationFile):
        thisBoxes = []
        thisClassIDs = []
        annotationFile = annotationFile.numpy().decode("utf-8")
        root = ET.parse(annotationFile).getroot()
        for obj in root.findall("object"):
            # load bounding boxes
            bndbox = obj.find("bndbox")
            xmin = int(bndbox.find("xmin").text)
            ymin = int(bndbox.find("ymin").text)
            xmax = int(bndbox.find("xmax").text)
            ymax = int(bndbox.find("ymax").text)
            thisBoxes.append([xmin, ymin, xmax, ymax])
            # load class IDs
            className = obj.find("name").text
            classID = classNames.index(className)
            thisClassIDs.append(classID)
        # image file path
        imageFile = imageFolder + "/" + root.find('filename').text
        return (imageFile, tf.cast(thisBoxes, dtype=tf.float32), tf.cast(thisClassIDs, dtype=tf.float32))
    imageFile, thisBoxes, thisClassIDs = tf.py_function(_loadAnnotation, [annotationFile], [tf.string, tf.float32, tf.float32])
    # load image
    image = tf.io.read_file(imageFile)
    image = tf.image.decode_jpeg(image, channels=3)

    # package annotation (boxes, class ids) to dictionary
    bounding_boxes = {
        "boxes": tf.cast(thisBoxes, dtype=tf.float32),
        "classes": tf.cast(thisClassIDs, dtype=tf.float32)
    }
    return {"images": tf.cast(image, dtype=tf.float32), "bounding_boxes": bounding_boxes}

dataset = tf.data.Dataset.from_tensor_slices(annotation_files)
dataset = dataset.map(load)
dataset = dataset.ragged_batch(4)
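The lost shape can be seen directly in the dataset's `element_spec`. The sketch below is a hedged illustration that uses a hypothetical `make_boxes` helper in place of the XML parsing. As far as I can tell from the TensorFlow source, `ragged_batch` only turns components whose rank is statically known into ragged tensors; a component of unknown rank stays dense, so batching elements with different numbers of boxes fails.

```python
import tensorflow as tf

def make_boxes(n):
    # Hypothetical stand-in for _loadAnnotation: returns n dummy boxes.
    return tf.zeros([n, 4], tf.float32)

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])

# Without any fix, the py_function output has no static shape or rank.
raw = ds.map(lambda n: tf.py_function(make_boxes, [n], tf.float32))
print(raw.element_spec.shape)  # <unknown>

def with_shape(n):
    boxes = tf.py_function(make_boxes, [n], tf.float32)
    # Rank 2: a variable number of boxes, each with 4 coordinates.
    boxes.set_shape([None, 4])
    return boxes

fixed = ds.map(with_shape)
print(fixed.element_spec.shape)  # (None, 4)

# With a known rank, ragged_batch can build a RaggedTensor.
batch = next(iter(fixed.ragged_batch(3)))
print(batch.row_lengths().numpy())  # [1 2 3]
```

Using `[None, 4]` instead of `[None, None]` is a design choice: it records that every box has exactly four coordinates, so only the first axis is ragged.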


Then I tried packaging each record into a single file with pickle, to avoid parsing XML with ET. Unfortunately, pickle also needs a Python string (the file path), so I hit the same problem as in my first attempt, and it doesn't work either.
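In both attempts the obstacle is the same. A minimal sketch, assuming TensorFlow 2.x (the file names are made up and only used as strings): inside `Dataset.map` the function is traced into a graph, so its argument is symbolic, while the Python function wrapped by `tf.py_function` receives an eager tensor whose `.numpy()` returns bytes.

```python
import tensorflow as tf

# Hypothetical file names; they are never opened.
paths = tf.data.Dataset.from_tensor_slices(["a.xml", "bb.xml"])

def inspect(path):
    # Traced into a graph: this prints a symbolic tensor such as
    # Tensor("args_0:0", shape=(), dtype=string); path.numpy() would fail here.
    print("inside map:", path)

    def _to_python(p):
        # Runs eagerly, so .numpy() gives bytes usable by ET, pickle, etc.
        name = p.numpy().decode("utf-8")
        return len(name)

    return tf.py_function(_to_python, [path], tf.int32)

lengths = [int(x) for x in paths.map(inspect)]
print(lengths)  # [5, 6]
```

For reading files alone, graph ops such as `tf.io.read_file` accept the symbolic path directly; it is only pure-Python code like ElementTree or pickle that needs `py_function`.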


After that, I tried to store the data in TFRecord files and load them back. But a problem came up when writing the TFRecord: `TypeError: Value must be iterable`. While searching, I found this discussion. It seems I must flatten my tensors (reshape them to 1-D) before writing and reshape them back to N dimensions when reading. But because the number of bounding boxes per image is unknown, which is exactly why I want to use ragged_batch, it seems impossible for me to flatten them.

def serializeTFRecord(data):
    image = data["images"]
    classes = data["bounding_boxes"]["classes"]
    boxes = data["bounding_boxes"]["boxes"]

    feature = {
        "images": tf.train.Feature(float_list=tf.train.FloatList(value=image)),
        "bounding_boxes": {
            "classes": tf.train.Feature(float_list=tf.train.FloatList(value=classes)),
            "boxes": tf.train.Feature(float_list=tf.train.FloatList(value=boxes))
        }
    }
    exampleProto = tf.train.Example(features=tf.train.Features(feature=feature))
    return exampleProto.SerializeToString()
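For what it's worth, variable-length boxes can still be flattened, because only the number of boxes varies while each box always has 4 coordinates. The sketch below is one possible approach, not the only one; the key names like `"bounding_boxes/boxes"` are my own choice. `tf.train.Features` maps string keys to `Feature` values, so the nested `"bounding_boxes"` dict has to be flattened into separate keys, and each `FloatList` needs a flat list of floats. The image is stored with `tf.io.serialize_tensor`, which keeps its dtype and shape in one bytes value.

```python
import os
import tempfile
import tensorflow as tf

def serialize_example(image, boxes, classes):
    feature = {
        "images": tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[tf.io.serialize_tensor(image).numpy()])),
        # Flatten [num_boxes, 4] to a 1-D list of floats.
        "bounding_boxes/boxes": tf.train.Feature(float_list=tf.train.FloatList(
            value=tf.reshape(boxes, [-1]).numpy().tolist())),
        "bounding_boxes/classes": tf.train.Feature(float_list=tf.train.FloatList(
            value=tf.reshape(classes, [-1]).numpy().tolist())),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

feature_spec = {
    "images": tf.io.FixedLenFeature([], tf.string),
    "bounding_boxes/boxes": tf.io.VarLenFeature(tf.float32),
    "bounding_boxes/classes": tf.io.VarLenFeature(tf.float32),
}

def parse_example(record):
    parsed = tf.io.parse_single_example(record, feature_spec)
    image = tf.io.parse_tensor(parsed["images"], out_type=tf.float32)
    image.set_shape([None, None, 3])
    # Every box has 4 coordinates, so reshape(-1, 4) restores the layout.
    boxes = tf.reshape(tf.sparse.to_dense(parsed["bounding_boxes/boxes"]), [-1, 4])
    classes = tf.sparse.to_dense(parsed["bounding_boxes/classes"])
    return {"images": image,
            "bounding_boxes": {"boxes": boxes, "classes": classes}}

# Usage: write two records with 1 and 2 boxes, then read them back.
path = os.path.join(tempfile.mkdtemp(), "sample.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    writer.write(serialize_example(
        tf.random.uniform([8, 8, 3]),
        tf.constant([[1., 2., 3., 4.]]),
        tf.constant([0.])))
    writer.write(serialize_example(
        tf.random.uniform([8, 8, 3]),
        tf.constant([[1., 2., 3., 4.], [5., 6., 7., 8.]]),
        tf.constant([0., 2.])))

ds = tf.data.TFRecordDataset(path).map(parse_example).ragged_batch(2)
batch = next(iter(ds))
print(batch["bounding_boxes"]["boxes"].row_lengths().numpy())  # [1 2]
```

Storing the original JPEG bytes in a `BytesList` and decoding them at read time is a common, more compact alternative to serializing float pixels.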


  • After a lot of trial and error, I eventually found a way to fix this py_function bug. According to this GitHub discussion, the tensors returned by py_function lose their shape and rank information. So the easiest way to bring ragged_batch back to life is to set that information manually with tensor.set_shape([None, None]), where the number of None entries equals the tensor's rank (its number of axes). Below is a small demo showing that it works.

    import tensorflow as tf
    def processing(data):
        def _processing(data):
            arr = [range(data), range(data)]
            arr = tf.cast(arr, tf.float32)
            print(f"Inside py_function arr shape: {arr.shape}")
            return arr
        arr = tf.py_function(_processing, [data], tf.float32)
        print(f"Outside py_function arr shape: {arr.shape}")
        arr.set_shape([None, None])
        return arr
    sizes = [1, 2, 3, 1]
    ds = tf.data.Dataset.from_tensor_slices(sizes)
    ds = ds.map(processing)
    ds = ds.ragged_batch(4)
    for data in ds:
        print(data)
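Applied to the original `load` function, the same `set_shape` idea might look like the sketch below. To run on its own, it writes two hypothetical sample images and minimal VOC-style XML files to a temporary directory; `classNames` and `imageFolder` are placeholders for the question's variables. It assumes each box has 4 coordinates, hence `[None, 4]`.

```python
import os
import tempfile
import xml.etree.ElementTree as ET
import tensorflow as tf

# Hypothetical placeholders for the question's classNames and imageFolder.
classNames = ["cat", "dog"]
imageFolder = tempfile.mkdtemp()

def write_sample(name, objects):
    # Write a tiny JPEG plus a minimal Pascal VOC-style XML annotation.
    jpeg = tf.io.encode_jpeg(tf.zeros([16, 16, 3], tf.uint8))
    tf.io.write_file(os.path.join(imageFolder, name + ".jpg"), jpeg)
    objs = "".join(
        f"<object><name>{c}</name><bndbox><xmin>{b[0]}</xmin><ymin>{b[1]}</ymin>"
        f"<xmax>{b[2]}</xmax><ymax>{b[3]}</ymax></bndbox></object>"
        for c, b in objects)
    path = os.path.join(imageFolder, name + ".xml")
    with open(path, "w") as f:
        f.write(f"<annotation><filename>{name}.jpg</filename>{objs}</annotation>")
    return path

annotation_files = [
    write_sample("a", [("cat", (1, 2, 3, 4))]),
    write_sample("b", [("cat", (0, 0, 5, 5)), ("dog", (2, 2, 9, 9))]),
]

def load(annotationFile):
    def _loadAnnotation(annotationFile):
        boxes, classIDs = [], []
        root = ET.parse(annotationFile.numpy().decode("utf-8")).getroot()
        for obj in root.findall("object"):
            bndbox = obj.find("bndbox")
            boxes.append([int(bndbox.find(k).text)
                          for k in ("xmin", "ymin", "xmax", "ymax")])
            classIDs.append(classNames.index(obj.find("name").text))
        imageFile = imageFolder + "/" + root.find("filename").text
        # reshape keeps a [0, 4] shape even when an image has no objects
        return (imageFile,
                tf.reshape(tf.constant(boxes, tf.float32), [-1, 4]),
                tf.constant(classIDs, tf.float32))

    imageFile, boxes, classIDs = tf.py_function(
        _loadAnnotation, [annotationFile], [tf.string, tf.float32, tf.float32])
    # py_function drops static shapes; restore the ranks ragged_batch needs.
    imageFile.set_shape([])
    boxes.set_shape([None, 4])
    classIDs.set_shape([None])
    image = tf.io.decode_jpeg(tf.io.read_file(imageFile), channels=3)
    return {"images": tf.cast(image, tf.float32),
            "bounding_boxes": {"boxes": boxes, "classes": classIDs}}

dataset = tf.data.Dataset.from_tensor_slices(annotation_files).map(load)
batch = next(iter(dataset.ragged_batch(4)))
print(batch["bounding_boxes"]["boxes"].row_lengths().numpy())  # [1 2]
```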