Tags: python, tensorflow, keras, proto, tensorrt

TypeError: graph_def must be a GraphDef proto


I converted a Keras model.h5 to frozen_graph.pb in order to optimize it and run it on a Jetson. But optimizing the frozen_graph.pb throws an error:

    raise TypeError('graph_def must be a GraphDef proto.')
    TypeError: graph_def must be a GraphDef proto.

Code :

import tensorflow.contrib.tensorrt as trt

frozen_graph = './model/frozen_model.pb'
output_names = ['conv2d_59','conv2d_67','conv2d_75']

trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=output_names,
    max_batch_size=1,
    max_workspace_size_bytes=1 << 25,
    precision_mode='FP16',
    minimum_segment_size=50
)

graph_io.write_graph(trt_graph, "./model/",
                     "trt_graph.pb", as_text=False)

NOTE: the imports `tensorflow.contrib.tensorrt` (as trt) and `graph_io` also gave me some problems.

Reference link: https://www.dlology.com/blog/how-to-run-keras-model-on-jetson-nano/

Error log:

> Traceback (most recent call last):
>   File "to_tensorrt.py", line 12, in <module>
>     minimum_segment_size=50
>   File "/home/christie/yolo_keras/yolo-keras/lib/python3.6/site-packages/tensorflow/contrib/tensorrt/python/trt_convert.py", line 51, in create_inference_graph
>     session_config=session_config)
>   File "/home/christie/yolo_keras/yolo-keras/lib/python3.6/site-packages/tensorflow/python/compiler/tensorrt/trt_convert.py", line 1146, in create_inference_graph
>     converted_graph_def = trt_converter.convert()
>   File "/home/christie/yolo_keras/yolo-keras/lib/python3.6/site-packages/tensorflow/python/compiler/tensorrt/trt_convert.py", line 298, in convert
>     self._convert_graph_def()
>   File "/home/christie/yolo_keras/yolo-keras/lib/python3.6/site-packages/tensorflow/python/compiler/tensorrt/trt_convert.py", line 221, in _convert_graph_def
>     importer.import_graph_def(self._input_graph_def, name="")
>   File "/home/christie/yolo_keras/yolo-keras/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
>     return func(*args, **kwargs)
>   File "/home/christie/yolo_keras/yolo-keras/lib/python3.6/site-packages/tensorflow/python/framework/importer.py", line 394, in import_graph_def
>     graph_def = _ProcessGraphDefParam(graph_def, op_dict)
>   File "/home/christie/yolo_keras/yolo-keras/lib/python3.6/site-packages/tensorflow/python/framework/importer.py", line 96, in _ProcessGraphDefParam
>     raise TypeError('graph_def must be a GraphDef proto.')
> TypeError: graph_def must be a GraphDef proto.

Solution

  • You are passing a file path (a string) where `create_inference_graph` expects a `GraphDef` proto. You have to first parse the file content into a `GraphDef` and then pass that as the parameter:

    import tensorflow as tf
    import tensorflow.contrib.tensorrt as trt
    from tensorflow.python.framework import graph_io
    
    frozen_graph = './model/frozen_model.pb'
    output_names = ['conv2d_59','conv2d_67','conv2d_75']
    
    # Read graph def (binary format)
    with open(frozen_graph, 'rb') as f:
        frozen_graph_gd = tf.GraphDef()
        frozen_graph_gd.ParseFromString(f.read())
    
    # If frozen graph is in text format load it like this
    # import google.protobuf.text_format
    # with open(frozen_graph, 'r') as f:
    #     frozen_graph_gd = google.protobuf.text_format.Parse(f.read(), tf.GraphDef())
    
    trt_graph = trt.create_inference_graph(
        input_graph_def=frozen_graph_gd,  # Pass the parsed graph def here
        outputs=output_names,
        max_batch_size=1,
        max_workspace_size_bytes=1 << 25,
        precision_mode='FP16',
        minimum_segment_size=50
    )
    
    graph_io.write_graph(trt_graph, "./model/",
                         "trt_graph.pb", as_text=False)