Tags: python, tensorflow, keras

Why can't the map function recognize these 3 inputs?


Why does this fail? The encode_batch function needs 3 inputs, and I give it 3, but it doesn't work. The print output also shows that the dataset has 3 components for it to map: <PaddedBatchDataset shapes: ((2, None, None, 3), (2, None, 4), (2, None)), types: (tf.float32, tf.float32, tf.int32)>

Here is the code:

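import tensorflow as tf
import tensorflow_datasets as tfds

# preprocess_data, batch_size, LabelEncoder and AnchorBox are defined as in the Keras RetinaNet example
label_encoder = LabelEncoder()
(train_dataset, val_dataset), dataset_info = tfds.load(
    "coco/2017", split=["train", "validation"], with_info=True, data_dir="data"
)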
print(train_dataset)
autotune = tf.data.experimental.AUTOTUNE
train_dataset = train_dataset.map(preprocess_data, num_parallel_calls=autotune)
print(train_dataset)
train_dataset = train_dataset.shuffle(8 * batch_size)
print(train_dataset)
train_dataset = train_dataset.padded_batch(
    batch_size=batch_size, padding_values=(0.0, 1e-8, -1), drop_remainder=True
)
print(train_dataset)
print(label_encoder.encode_batch)
train_dataset = train_dataset.map(label_encoder.encode_batch(), num_parallel_calls=autotune)
train_dataset = train_dataset.apply(tf.data.experimental.ignore_errors())
train_dataset = train_dataset.prefetch(autotune)

val_dataset = val_dataset.map(preprocess_data, num_parallel_calls=autotune)
val_dataset = val_dataset.padded_batch(
    batch_size=1, padding_values=(0.0, 1e-8, -1), drop_remainder=True
)
val_dataset = val_dataset.map(label_encoder.encode_batch(), num_parallel_calls=autotune)
val_dataset = val_dataset.apply(tf.data.experimental.ignore_errors())
val_dataset = val_dataset.prefetch(autotune)

Output:

<PrefetchDataset shapes: {image: (None, None, 3), image/filename: (), image/id: (), objects: {area: (None,), bbox: (None, 4), id: (None,), is_crowd: (None,), label: (None,)}}, types: {image: tf.uint8, image/filename: tf.string, image/id: tf.int64, objects: {area: tf.int64, bbox: tf.float32, id: tf.int64, is_crowd: tf.bool, label: tf.int64}}>
<ParallelMapDataset shapes: ((None, None, 3), (None, 4), (None,)), types: (tf.float32, tf.float32, tf.int32)>
<ShuffleDataset shapes: ((None, None, 3), (None, 4), (None,)), types: (tf.float32, tf.float32, tf.int32)>
<PaddedBatchDataset shapes: ((2, None, None, 3), (2, None, 4), (2, None)), types: (tf.float32, tf.float32, tf.int32)>
<bound method LabelEncoder.encode_batch of <__main__.LabelEncoder object at 0x000001B846D797C8>>

It then raises this error:

encode_batch() missing 3 required positional arguments: 'batch_images', 'gt_boxes', and 'cls_ids'

on this line:

train_dataset = train_dataset.map(label_encoder.encode_batch(), num_parallel_calls=autotune)

Here is the LabelEncoder class:

class LabelEncoder:
    def __init__(self):
        self._anchor_box = AnchorBox()
        self._box_variance = tf.convert_to_tensor([0.1, 0.1, 0.2, 0.2], dtype=tf.float32)
    def encode_batch(self, batch_images, gt_boxes, cls_ids):
        """Creates box and classification targets for a batch"""
        images_shape = tf.shape(batch_images)
        batch_size = images_shape[0]
        labels = tf.TensorArray(dtype=tf.float32, size=batch_size, dynamic_size=True)
        for i in range(batch_size):
            label = self._encode_sample(images_shape, gt_boxes[i], cls_ids[i])
            labels = labels.write(i, label)
        batch_images = tf.keras.applications.resnet.preprocess_input(batch_images)
        return batch_images, labels.stack()

This code was copied from the Keras RetinaNet example: https://keras.io/examples/vision/retinanet/


Versions: Python 3.7, TensorFlow 2.3


Solution

  • The argument you pass to map must be a callable (the function itself), not the result of calling it.

    The problem is on this line:

    train_dataset = train_dataset.map(label_encoder.encode_batch(), num_parallel_calls=autotune)
    

    You are calling label_encoder.encode_batch() yourself, with no arguments, before map ever runs. The method requires 3 positional arguments, so Python raises the TypeError right there.

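    Instead of calling the method yourself, pass the method itself to map. map then calls it once per element, and because each element of the padded-batch dataset is a 3-tuple, tf.data unpacks that tuple into the three positional arguments batch_images, gt_boxes and cls_ids.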

    Like this:

    train_dataset = train_dataset.map(label_encoder.encode_batch, num_parallel_calls=autotune)
    

    (Note that there are no parentheses after encode_batch.)

    The same change is needed on the line that maps val_dataset:

    val_dataset = val_dataset.map(label_encoder.encode_batch, num_parallel_calls=autotune)
    

    In short, pass the function to map instead of calling it yourself.
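The difference between passing a function and calling it can be reproduced without TensorFlow. This is an illustrative sketch only: DummyLabelEncoder and toy_map are hypothetical stand-ins, and toy_map mimics just the one behavior of Dataset.map that matters here (a tuple element is unpacked into positional arguments).

```python
class DummyLabelEncoder:
    """Hypothetical stand-in with the same signature as encode_batch in the question."""

    def encode_batch(self, batch_images, gt_boxes, cls_ids):
        # The real method builds training targets; here we just echo the inputs.
        return (batch_images, gt_boxes, cls_ids)


def toy_map(elements, fn):
    """Hypothetical stand-in for Dataset.map: each tuple element is
    unpacked into positional arguments before fn is called."""
    results = []
    for element in elements:
        if isinstance(element, tuple):
            results.append(fn(*element))
        else:
            results.append(fn(element))
    return results


label_encoder = DummyLabelEncoder()
batches = [("images_0", "boxes_0", "ids_0"), ("images_1", "boxes_1", "ids_1")]

# Wrong: encode_batch() is evaluated immediately with no arguments,
# so the TypeError is raised before toy_map is even entered.
error_message = ""
try:
    toy_map(batches, label_encoder.encode_batch())
except TypeError as err:
    error_message = str(err)
print(error_message)

# Right: pass the bound method; toy_map supplies the 3 arguments per element.
mapped = toy_map(batches, label_encoder.encode_batch)
print(mapped)
# → [('images_0', 'boxes_0', 'ids_0'), ('images_1', 'boxes_1', 'ids_1')]
```

The real pipeline follows the same rule: train_dataset.map(label_encoder.encode_batch, ...) hands the bound method to tf.data, which calls it with the three batched tensors of each element.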