Why does encode_batch fail? It needs 3 inputs, and as far as I can tell I give it 3, but it doesn't work. The print output also shows <PaddedBatchDataset shapes: ((2, None, None, 3), (2, None, 4), (2, None)), types: (tf.float32, tf.float32, tf.int32)>, i.e. 3 components for it to map over.
The code follows:
label_encoder = LabelEncoder()
(train_dataset, val_dataset), dataset_info = tfds.load(
    "coco/2017", split=["train", "validation"], with_info=True, data_dir="data"
)
print(train_dataset)
autotune = tf.data.experimental.AUTOTUNE
train_dataset = train_dataset.map(preprocess_data, num_parallel_calls=autotune)
print(train_dataset)
train_dataset = train_dataset.shuffle(8 * batch_size)
print(train_dataset)
train_dataset = train_dataset.padded_batch(
    batch_size=batch_size, padding_values=(0.0, 1e-8, -1), drop_remainder=True
)
print(train_dataset)
print(label_encoder.encode_batch)
train_dataset = train_dataset.map(label_encoder.encode_batch(), num_parallel_calls=autotune)
train_dataset = train_dataset.apply(tf.data.experimental.ignore_errors())
train_dataset = train_dataset.prefetch(autotune)
val_dataset = val_dataset.map(preprocess_data, num_parallel_calls=autotune)
val_dataset = val_dataset.padded_batch(
    batch_size=1, padding_values=(0.0, 1e-8, -1), drop_remainder=True
)
val_dataset = val_dataset.map(label_encoder.encode_batch(), num_parallel_calls=autotune)
val_dataset = val_dataset.apply(tf.data.experimental.ignore_errors())
val_dataset = val_dataset.prefetch(autotune)
Output:
<PrefetchDataset shapes: {image: (None, None, 3), image/filename: (), image/id: (), objects: {area: (None,), bbox: (None, 4), id: (None,), is_crowd: (None,), label: (None,)}}, types: {image: tf.uint8, image/filename: tf.string, image/id: tf.int64, objects: {area: tf.int64, bbox: tf.float32, id: tf.int64, is_crowd: tf.bool, label: tf.int64}}>
<ParallelMapDataset shapes: ((None, None, 3), (None, 4), (None,)), types: (tf.float32, tf.float32, tf.int32)>
<ShuffleDataset shapes: ((None, None, 3), (None, 4), (None,)), types: (tf.float32, tf.float32, tf.int32)>
<PaddedBatchDataset shapes: ((2, None, None, 3), (2, None, 4), (2, None)), types: (tf.float32, tf.float32, tf.int32)>
<bound method LabelEncoder.encode_batch of <__main__.LabelEncoder object at 0x000001B846D797C8>>
It then raises this error:
encode_batch() missing 3 required positional arguments: 'batch_images', 'gt_boxes', and 'cls_ids'
on the line:
train_dataset = train_dataset.map(label_encoder.encode_batch(), num_parallel_calls=autotune)
Here is the LabelEncoder class:
class LabelEncoder:
    def __init__(self):
        self._anchor_box = AnchorBox()
        self._box_variance = tf.convert_to_tensor([0.1, 0.1, 0.2, 0.2], dtype=tf.float32)

    def encode_batch(self, batch_images, gt_boxes, cls_ids):
        """Creates box and classification targets for a batch"""
        images_shape = tf.shape(batch_images)
        batch_size = images_shape[0]

        labels = tf.TensorArray(dtype=tf.float32, size=batch_size, dynamic_size=True)
        for i in range(batch_size):
            label = self._encode_sample(images_shape, gt_boxes[i], cls_ids[i])
            labels = labels.write(i, label)
        batch_images = tf.keras.applications.resnet.preprocess_input(batch_images)
        return batch_images, labels.stack()
This code was copied from the Keras RetinaNet example: https://keras.io/examples/vision/retinanet/
Python 3.7, TensorFlow 2.3.
When using the map function, what you pass to map must be the function itself (a callable), not the result of calling it.
The problem is on this line:
train_dataset = train_dataset.map(label_encoder.encode_batch(), num_parallel_calls=autotune)
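Here you are calling label_encoder.encode_batch() yourself, with no arguments, even though it requires 3. Python evaluates that call before map ever runs, so the error is raised immediately.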
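The same error can be reproduced without TensorFlow. The sketch below uses a hypothetical stand-in class that only copies the encode_batch signature: referencing encoder.encode_batch yields a callable and runs nothing, while writing encoder.encode_batch() calls it at once with no arguments and raises the TypeError from the question.

```python
class LabelEncoder:
    """Hypothetical stand-in: same encode_batch signature as the Keras example."""

    def encode_batch(self, batch_images, gt_boxes, cls_ids):
        # Toy body: just hand the three inputs back.
        return batch_images, gt_boxes, cls_ids


encoder = LabelEncoder()

# Without parentheses: a bound method object; nothing is executed yet.
print(callable(encoder.encode_batch))  # True

# With parentheses: Python calls the method right away, with no arguments.
try:
    encoder.encode_batch()
except TypeError as err:
    # Lists the three missing arguments (newer Pythons prefix the class name).
    print(err)
```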
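Instead of calling the method yourself, pass the method itself as the argument and let map supply the arguments. Because each element of the padded, batched dataset is a 3-tuple (images, boxes, class ids), map unpacks it and calls encode_batch(batch_images, gt_boxes, cls_ids) for every batch.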
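To see why passing the method is enough, here is a deliberately simplified, pure-Python model of the argument-passing rule that tf.data.Dataset.map follows: a tuple element is unpacked into positional arguments, while any other structure (such as a dict) is passed as a single argument. toy_map, encode_batch and the string "batches" are hypothetical illustrations, not TensorFlow code.

```python
def toy_map(elements, map_func):
    """Hypothetical, simplified model of how tf.data.Dataset.map calls map_func.

    Not TensorFlow's implementation: it only mimics the rule that tuple
    elements are unpacked into positional arguments.
    """
    results = []
    for element in elements:
        if isinstance(element, tuple):
            # (images, boxes, cls_ids) -> map_func(images, boxes, cls_ids)
            results.append(map_func(*element))
        else:
            # Any other structure (e.g. a dict) is passed as one argument.
            results.append(map_func(element))
    return results


def encode_batch(batch_images, gt_boxes, cls_ids):
    # Hypothetical stand-in returning (images, labels), like the real method.
    return batch_images, (gt_boxes, cls_ids)


batches = [("images_0", "boxes_0", "ids_0"), ("images_1", "boxes_1", "ids_1")]
# Pass the function object (encode_batch), not the result of encode_batch().
print(toy_map(batches, encode_batch))
# [('images_0', ('boxes_0', 'ids_0')), ('images_1', ('boxes_1', 'ids_1'))]
```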
Like this:
train_dataset = train_dataset.map(label_encoder.encode_batch, num_parallel_calls=autotune)
(Note that there are no parentheses after encode_batch.)
The same change needs to be made to the line involving val_dataset:
val_dataset = val_dataset.map(label_encoder.encode_batch, num_parallel_calls=autotune)
In short, the function passed to map should not be called directly; map calls it for you.
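A minimal end-to-end sketch of the fix, assuming TensorFlow 2.2 or later (where padded_batch no longer requires padded_shapes). ToyLabelEncoder and the zero/one tensors are hypothetical placeholders: the class keeps the encode_batch(batch_images, gt_boxes, cls_ids) signature but builds trivial labels instead of RetinaNet anchor targets, so it runs without the AnchorBox and _encode_sample code from the Keras example.

```python
import tensorflow as tf


class ToyLabelEncoder:
    """Hypothetical stand-in for the Keras example's LabelEncoder.

    Same encode_batch signature, but the "labels" are just the boxes with the
    class ids appended, not real RetinaNet targets.
    """

    def encode_batch(self, batch_images, gt_boxes, cls_ids):
        ids = tf.cast(cls_ids, tf.float32)[..., tf.newaxis]
        labels = tf.concat([gt_boxes, ids], axis=-1)
        return batch_images, labels


# Toy data shaped like the question's pipeline before batching:
# image (H, W, 3), boxes (N, 4), class ids (N,).
images = tf.zeros([4, 8, 8, 3], tf.float32)
boxes = tf.ones([4, 2, 4], tf.float32)
ids = tf.ones([4, 2], tf.int32)

label_encoder = ToyLabelEncoder()
dataset = tf.data.Dataset.from_tensor_slices((images, boxes, ids))
dataset = dataset.padded_batch(
    batch_size=2, padding_values=(0.0, 1e-8, -1), drop_remainder=True
)
# Pass the bound method itself; map unpacks each 3-tuple into its arguments.
dataset = dataset.map(label_encoder.encode_batch)

for batch_images, labels in dataset:
    # Expected shapes: images (2, 8, 8, 3), labels (2, 2, 5).
    print(batch_images.shape, labels.shape)
```

Writing lambda images, boxes, ids: label_encoder.encode_batch(images, boxes, ids) would be equivalent; passing the bound method directly is just shorter.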