Search code examples
python tensorflow keras conv-neural-network mobilenet

Tensorflow: Error when trying transfer learning: Invalid JPEG data or crop window


I am trying to shape my own custom image dataset into the correct input shape for the pretrained MobileNet model on Tensorflow using their tutorial here. My code:

# Training hyper-parameters (both spellings of the batch size are kept).
batch_size = BATCH_SIZE = 256
epochs = 15
SHUFFLE_BUFFER_SIZE = 1000

# Spatial size every image is resized to before it is batched.
IMG_HEIGHT = IMG_WIDTH = 160

# Dataset roots; the glob pattern expects one sub-directory per class.
traindir = pathlib.Path('/train')
valdir = pathlib.Path('/validation')
list_ds = tf.data.Dataset.list_files(str(traindir / '*/*'))
val_list_ds = tf.data.Dataset.list_files(str(valdir / '*/*'))

# Class names are the entry names found directly under the validation root,
# with the licence file filtered out.
CLASS_NAMES = np.array(
    [name
     for name in (entry.name for entry in valdir.glob('*'))
     if name != "LICENSE.txt"]
)
def get_label(file_path):
  """Derive the label for `file_path` from the directory that contains it.

  The path is split on the OS path separator and the second-to-last
  component (the class directory) is compared against `CLASS_NAMES`.
  Presumably this yields a boolean mask with one entry per class name --
  confirm against the model's expected label format.
  """
  path_components = tf.strings.split(file_path, os.path.sep)
  class_dir = path_components[-2]
  return class_dir == CLASS_NAMES

def decode_img(img):
  """Decode compressed JPEG bytes into a resized float32 image.

  The bytes are decoded to a 3-channel uint8 tensor, converted to float32
  (which rescales the values to the [0,1] range) and resized to
  IMG_HEIGHT x IMG_WIDTH.

  NOTE: decoding goes through `tf.image.decode_jpeg`, so bytes that are not
  valid JPEG data fail with "Invalid JPEG data or crop window".
  """
  decoded = tf.image.decode_jpeg(img, channels=3)
  as_float = tf.image.convert_image_dtype(decoded, tf.float32)
  target_size = [IMG_HEIGHT, IMG_WIDTH]
  return tf.image.resize(as_float, target_size)

def process_path(file_path):
  """Turn a file path into an `(image, label)` pair.

  The file contents are read as a raw string and passed to `decode_img`;
  the label is derived from the path by `get_label`.
  """
  raw_bytes = tf.io.read_file(file_path)
  return decode_img(raw_bytes), get_label(file_path)
# Decode and label every listed file; `num_parallel_calls` lets several
# images be loaded/processed in parallel.
labeled_ds = list_ds.map(map_func=process_path, num_parallel_calls=5)
labeled_val_ds = val_list_ds.map(map_func=process_path, num_parallel_calls=5)

# Only the training stream is shuffled; both streams are batched.
train_batches = labeled_ds.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE).batch(
    batch_size=BATCH_SIZE)
validation_batches = labeled_val_ds.batch(batch_size=BATCH_SIZE)

# Pull a single training batch so that its shape can be inspected.
for image_batch, label_batch in train_batches.take(count=1):
   pass

image_batch.shape

After this, I continue with the TF tutorial on transfer learning here. However, I ran into the following problem; I suspect that either a JPEG image is corrupted or something is wrong with (or missing from) the iterator:

Epoch 1/10
 21/330 [>.............................] - ETA: 14:02 - loss: 3.9893 - accuracy: 0.0326
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-87-11afdc6d5aef> in <module>
      1 history = model.fit(train_batches,
      2                     epochs=initial_epochs,
----> 3                     validation_data=validation_batches)

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
     64   def _method_wrapper(self, *args, **kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self, *args, **kwargs)
     67 
     68     # Running inside `run_distribute_coordinator` already.

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
    846                 batch_size=batch_size):
    847               callbacks.on_train_batch_begin(step)
--> 848               tmp_logs = train_function(iterator)
    849               # Catch OutOfRangeError for Datasets of unknown size.
    850               # This blocks until the batch has finished executing.

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    578         xla_context.Exit()
    579     else:
--> 580       result = self._call(*args, **kwds)
    581 
    582     if tracing_count == self._get_tracing_count():

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    609       # In this case we have created variables on the first call, so we run the
    610       # defunned version which is guaranteed to never create variables.
--> 611       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    612     elif self._stateful_fn is not None:
    613       # Release the lock early so that multiple threads can perform the call

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   2418     with self._lock:
   2419       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2420     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2421 
   2422   @property

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in _filtered_call(self, args, kwargs)
   1663          if isinstance(t, (ops.Tensor,
   1664                            resource_variable_ops.BaseResourceVariable))),
-> 1665         self.captured_inputs)
   1666 
   1667   def _call_flat(self, args, captured_inputs, cancellation_manager=None):

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1744       # No tape is watching; skip to running the function.
   1745       return self._build_call_outputs(self._inference_function.call(
-> 1746           ctx, args, cancellation_manager=cancellation_manager))
   1747     forward_backward = self._select_forward_and_backward_functions(
   1748         args,

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    596               inputs=args,
    597               attrs=attrs,
--> 598               ctx=ctx)
    599         else:
    600           outputs = execute.execute_with_cancellation(

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  Invalid JPEG data or crop window, data size 34228
     [[{{node DecodeJpeg}}]]
     [[IteratorGetNext]]
  (1) Invalid argument:  Invalid JPEG data or crop window, data size 34228
     [[{{node DecodeJpeg}}]]
     [[IteratorGetNext]]
     [[IteratorGetNext/_4]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_30787]

Function call stack:
train_function -> train_function

Thank you for your time! Edit: After re-running the code a few times, it seems to produce the same error with different data sizes, such as 16384.

Edit: Yes, the problem is that some .jpeg files are actually .png files in disguise, or are simply corrupted. I highly recommend checking data integrity before training any model on the data.


Solution

  • I faced a similar problem. There is a problem with some of your training data. You can use the code below to check which JPEG images are corrupted and delete them.

    from struct import unpack
    from tqdm import tqdm
    import os
    
    
    # (marker value, readable name) pairs for the JPEG segment markers that
    # the checker refers to; each marker is a 16-bit value read big-endian.
    _MARKER_NAMES = (
        (0xFFD8, "Start of Image"),
        (0xFFE0, "Application Default Header"),
        (0xFFDB, "Quantization Table"),
        (0xFFC0, "Start of Frame"),
        (0xFFC4, "Define Huffman Table"),
        (0xFFDA, "Start of Scan"),
        (0xFFD9, "End of Image"),
    )
    marker_mapping = dict(_MARKER_NAMES)
    
    
    class JPEG:
        """Structural sanity check for a JPEG file.

        The whole file is read into memory on construction. `decode` then
        walks the segment markers and raises if the byte stream is not laid
        out like a JPEG. No pixel data is decompressed.
        """

        def __init__(self, image_file):
            """Read the raw bytes of `image_file` (a filesystem path)."""
            with open(image_file, 'rb') as f:
                self.img_data = f.read()

        def decode(self):
            """Walk the segment markers and raise if the data is malformed.

            Returns:
                None, once the End of Image marker has been reached.

            Raises:
                ValueError: the data does not start with the Start of Image
                    marker (e.g. a PNG saved with a .jpeg extension), the scan
                    data is not terminated by End of Image, or the data runs
                    out before End of Image is seen.
                struct.error: the data is cut off in the middle of a marker
                    or of a segment length field.
            """
            data = self.img_data
            # Every JPEG starts with SOI. Without this check the first bytes
            # of some other format are read as a marker plus a bogus length,
            # which can skip past the end of the data and look like success.
            if data[0:2] != b"\xff\xd8":
                raise ValueError("not a JPEG: missing Start of Image marker")
            while data:
                marker, = unpack(">H", data[0:2])
                # Debugging aid: print(marker_mapping.get(marker))
                if marker == 0xffd8:
                    data = data[2:]
                elif marker == 0xffd9:
                    return
                elif marker == 0xffda:
                    # Entropy-coded scan data follows and carries no length
                    # field, so only the final two bytes are checked. Testing
                    # them directly (instead of looping on `data[-2:]`) also
                    # avoids spinning forever on a file that ends at SOS.
                    if data[-2:] != b"\xff\xd9":
                        raise ValueError(
                            "corrupt JPEG: scan data is not terminated by "
                            "an End of Image marker")
                    return
                else:
                    # The big-endian length counts its own two bytes but not
                    # the two marker bytes that precede it.
                    lenchunk, = unpack(">H", data[2:4])
                    data = data[2+lenchunk:]
            raise ValueError("truncated JPEG: End of Image marker not found")
    
    
    # Names of the files that failed the structural JPEG check.
    bads = []

    # NOTE(review): `images` (file names) and `root_img` (their directory) are
    # not defined in this snippet -- set them before running it.
    for img in tqdm(images):
      # `os.path` is used directly because only `import os` is in scope; the
      # `osp` alias was never imported.
      candidate = JPEG(os.path.join(root_img, img))
      try:
        candidate.decode()
      except Exception:
        # Any decoding failure marks the file as bad. `Exception` (rather than
        # a bare `except`) keeps Ctrl-C from flagging a good file for deletion.
        bads.append(img)


    # Deletion is permanent: inspect `bads` first if in doubt.
    for name in bads:
      os.remove(os.path.join(root_img, name))
    

    I used Yasoob's script to decode the JPEG images.