Unable to use a YOLOv7 model exported to TensorFlow Lite in ObjectDetector


I have a YOLOv7 model trained on my custom dataset. I exported the model to TensorFlow Lite successfully and was able to use it for inference in Python. But when I try to use the same model on Android, in the TensorFlow Lite object detection example project, it throws this error:

java.lang.IllegalArgumentException: Error occurred when initializing ObjectDetector: The input tensor should have dimensions 1 x height x width x 3. Got 1 x 3 x 640 x 640.

Is it possible to change the input shape expected by the ObjectDetector class, or to export the YOLOv7 or YOLOv5 model with the corresponding input shape?

I tried to tweak the export process to change the input shape of the ONNX model, which is the intermediate model when exporting from PyTorch to TensorFlow Lite, but it throws this error:

ONNX export failure: Given groups=1, weight of size [32, 3, 3, 3], expected input[1, 640, 640, 3] to have 3 channels, but got 640 channels instead

Update: I used onnx2tf to export a .tflite model with an NHWC input shape. Now the Android project throws this error:

java.lang.RuntimeException: Error occurred when initializing ObjectDetector: Input tensor has type kTfLiteFloat32: it requires specifying NormalizationOptions metadata to preprocess input images.

I couldn't find a way to add normalization options to metadata using this doc. Any solutions?


Solution

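    1. The ONNX export error: The first convolution has weights of size [32, 3, 3, 3], so it expects a channels-first (NCHW) input. Changing only the declared input shape to [1, 640, 640, 3] makes it read 640 as the channel count. Rearrange the dimensions with permute rather than changing the shape alone. Then, when converting from ONNX to TensorFlow, ensure that the NHWC format is retained.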

      nhwc_tensor = input_tensor.permute(0, 2, 3, 1)
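
One way to apply this at export time is to wrap the model so that it accepts NHWC input and permutes back to NCHW internally. The sketch below is only an illustration: `NHWCWrapper` is a hypothetical name, and a single `Conv2d` layer stands in for the real YOLOv7 network so that the snippet runs on its own.

```python
import torch
import torch.nn as nn


class NHWCWrapper(nn.Module):
    """Hypothetical wrapper: takes NHWC input, feeds NCHW to the wrapped model."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NHWC (N, H, W, C) -> NCHW (N, C, H, W), the layout the conv weights expect
        return self.model(x.permute(0, 3, 1, 2))


# Stand-in for the first YOLOv7 layer (weights of size [32, 3, 3, 3]).
# Assumption: replace this with the actual loaded model.
stand_in = nn.Conv2d(3, 32, kernel_size=3, padding=1)
wrapped = NHWCWrapper(stand_in).eval()

# Same shape that triggered the export failure when fed to the bare model
nhwc_input = torch.zeros(1, 640, 640, 3)
with torch.no_grad():
    out = wrapped(nhwc_input)

out_shape = tuple(out.shape)
print(out_shape)
```

Passing `wrapped` instead of the bare model to `torch.onnx.export` should then produce an ONNX graph whose input is 1 x 640 x 640 x 3, though I have not verified this against the YOLOv7 export script itself.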

    2. TensorFlow Lite input preprocessing: The error "Input tensor has type kTfLiteFloat32: it requires specifying NormalizationOptions metadata to preprocess input images" means the model takes float input, but its metadata does not say how the input image should be normalized before it is passed to the model.

    Add normalization metadata to your .tflite model:

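    from tflite_support.metadata_writers import object_detector
    from tflite_support.metadata_writers import writer_utils

    # Applied per pixel as (value - mean) / std; these values map [0, 255] to [-1, 1].
    # Use the values that match how the model was trained.
    NORMALIZATION_MEAN = [127.5]
    NORMALIZATION_STD = [127.5]

    # create_for_inference also requires label files; "labels.txt" is a
    # placeholder (one class name per line).
    writer = object_detector.MetadataWriter.create_for_inference(
        writer_utils.load_file("your_model.tflite"),
        input_norm_mean=NORMALIZATION_MEAN,
        input_norm_std=NORMALIZATION_STD,
        label_file_paths=["labels.txt"])

    # Write the model with embedded metadata to a new file (placeholder name)
    writer_utils.save_file(writer.populate(), "your_model_with_metadata.tflite")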
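
To choose the mean and std values, it helps to see what they do to a pixel. As far as I know, the normalization described by the metadata is `(value - mean) / std` per channel. The sketch below is plain Python with a hypothetical `normalize` helper and does not call the TensorFlow Lite API. A mean and std of 127.5 map uint8 pixels to [-1, 1], while a mean of 0 and std of 255 map them to [0, 1]. I believe the stock YOLOv7 scripts scale pixels to [0, 1], so check which range your model was trained with.

```python
def normalize(pixel: float, mean: float, std: float) -> float:
    """Per-channel normalization as described by NormalizationOptions."""
    return (pixel - mean) / std


# mean = std = 127.5 maps uint8 pixel values into [-1, 1]
symmetric = [normalize(p, 127.5, 127.5) for p in (0, 127.5, 255)]

# mean = 0, std = 255 maps uint8 pixel values into [0, 1]
unit = [normalize(p, 0.0, 255.0) for p in (0, 255)]

print(symmetric)  # [-1.0, 0.0, 1.0]
print(unit)       # [0.0, 1.0]
```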