Tags: tensorflow, deep-learning, neural-network, nvidia-deepstream

How can I convert NHWC to NCHW in Python for DeepStream?


I have a TensorFlow Keras model stored in .pb format, and I am converting it from .pb to .onnx using tf2onnx:

!python -m tf2onnx.convert --saved-model model.pb --output model.onnx 

Now, after converting, I see that my input layer is in NHWC format, and I need it in NCHW. To achieve that, I am using:

!python -m tf2onnx.convert --saved-model model.pb --output model_3.onnx --inputs-as-nchw input0:0

This still gives me an NHWC input. I have to consume the model in NVIDIA DeepStream, which only accepts the NCHW format.

I also found the question "Convert between NHWC and NCHW in TensorFlow", which talks about transposing the input layer, but unfortunately that approach is not working either:

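import tensorflow as tf
# placeholders only work in graph mode under TF2
tf.compat.v1.disable_eager_execution()
# input batch in NHWC layout: (batch, height, width, channels)
images_nhwc = tf.compat.v1.placeholder(tf.float32, [1, 200, 300, 3])
# move the channel axis to position 1: NHWC -> NCHW
out = tf.transpose(images_nhwc, [0, 3, 1, 2])
print(out.get_shape())
# `model` is my existing Keras model
model.build(out.get_shape())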
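For reference, the permutation [0, 3, 1, 2] in the snippet above is what moves the channel axis from last to second place. The sketch below uses NumPy only as a stand-in (tf.transpose applies its perm argument the same way) to check that the shape and a single pixel land where NCHW expects them, and that [0, 2, 3, 1] undoes the change. As far as the snippet shows, the transpose is applied to a separate placeholder rather than wired into the model's own input, which may be why it had no effect on the exported model.

```python
import numpy as np

NHWC_TO_NCHW = (0, 3, 1, 2)  # same permutation as tf.transpose(x, [0, 3, 1, 2])
NCHW_TO_NHWC = (0, 2, 3, 1)  # its inverse

# A dummy batch with the question's shape: (batch, height, width, channels).
batch_nhwc = np.arange(1 * 200 * 300 * 3, dtype=np.float32).reshape(1, 200, 300, 3)
batch_nchw = np.transpose(batch_nhwc, NHWC_TO_NCHW)

print(batch_nchw.shape)  # (1, 3, 200, 300)

# Pixel (h=10, w=20), channel 2 sits at [n, h, w, c] in NHWC and [n, c, h, w] in NCHW.
assert batch_nhwc[0, 10, 20, 2] == batch_nchw[0, 2, 10, 20]

# Applying the inverse permutation restores the original layout.
assert np.array_equal(np.transpose(batch_nchw, NCHW_TO_NHWC), batch_nhwc)
```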

It would be really helpful if someone could explain how to convert the model input from NHWC to NCHW.


Solution

  • I found the solution. I needed the latest code for tf2onnx.convert.from_keras, so I installed tf2onnx from its main branch:

    !pip install --force-reinstall  git+https://github.com/onnx/tensorflow-onnx.git@main
    !pip show tf2onnx
    !pip freeze | grep tf2onnx
    

    Once that was done, I was able to use the latest functionality, i.e. the updated code at https://github.com/onnx/tensorflow-onnx/tree/e896723e410a59a600d1a73657f9965a3cbf2c3b

    Below is the code I used to convert my model to .onnx while also switching its input from NHWC to NCHW.

    import tf2onnx

    # `model` is the Keras model being converted
    # give the list of *inputs* which should be converted and returned *as nchw*
    _INPUT = model.input.name

    model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model, inputs_as_nchw=[_INPUT])
    

    The biggest catch in the code above was that inputs_as_nchw expects a list ([_INPUT]) rather than a single name; I was able to find this information in the tf2onnx test cases.
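Since the fix depended on both a recent tf2onnx and passing a list, a small pre-flight check can catch either mistake before conversion. nchw_kwargs below is a hypothetical helper, not part of tf2onnx: it uses only the standard inspect module to see whether the converter declares an inputs_as_nchw parameter, and it wraps a single name in a list because a bare Python string is itself iterable over its characters.

```python
import inspect
from typing import Callable, Iterable, Union


def nchw_kwargs(convert_fn: Callable, names: Union[str, Iterable[str]]) -> dict:
    """Build the inputs_as_nchw keyword argument for a converter call.

    Hypothetical helper: normalizes a single input name into a list and
    checks that convert_fn actually declares an inputs_as_nchw parameter.
    """
    # A bare str would iterate as characters ("input_1" -> "i", "n", ...),
    # so wrap it explicitly instead of calling list() on it.
    if isinstance(names, str):
        names = [names]
    names = list(names)

    params = inspect.signature(convert_fn).parameters
    if "inputs_as_nchw" not in params:
        raise TypeError(
            f"{getattr(convert_fn, '__name__', convert_fn)!r} has no inputs_as_nchw "
            "parameter; the installed converter may be too old"
        )
    return {"inputs_as_nchw": names}
```

Assuming model and tf2onnx are already loaded as in the answer, it would be called as tf2onnx.convert.from_keras(model, **nchw_kwargs(tf2onnx.convert.from_keras, model.input.name)).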