
Find input shape from onnx file


How can I find the input size of an onnx model? I would eventually like to script it from python.

With tensorflow I can recover the graph definition, find input candidate nodes from it and then obtain their size. Can I do something similar with ONNX (or even simpler)?


Solution

  • Please do NOT use input as a variable name, because it shadows Python's built-in input() function.

    If you need the name, data type, or other properties of a protobuf object, the first idea that comes to mind is the google.protobuf.json_format.MessageToDict() method. For example:

    import onnx
    from google.protobuf.json_format import MessageToDict
    
    model = onnx.load("path/to/model.onnx")
    for _input in model.graph.input:
        print(MessageToDict(_input))
    
    

    This gives output like:

    {'name': '0', 'type': {'tensorType': {'elemType': 2, 'shape': {'dim': [{'dimValue': '4'}, {'dimValue': '3'}, {'dimValue': '384'}, {'dimValue': '640'}]}}}}
    

    model.graph.input is a repeated protobuf field (a RepeatedCompositeContainer in the Python protobuf API), so iterate over it with a for loop even when the model has only one input.

    Then you need to get the shape information from the dim field.

    model = onnx.load("path/to/model.onnx")
    for _input in model.graph.input:
        m_dict = MessageToDict(_input)
        dim_info = m_dict.get("type").get("tensorType").get("shape").get("dim")  # ugly but we have to live with this when using dict
        input_shape = [d.get("dimValue") for d in dim_info]  # ['4', '3', '384', '640'] (strings, not ints)
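
    The chained .get() calls above return strings and raise an AttributeError as soon as one key is missing (for example, an input that is not a tensor). A dimension can also be symbolic, in which case the dict has 'dimParam' instead of 'dimValue'. Below is a minimal sketch of a more defensive version; shape_from_dict is a hypothetical helper name of my own, not part of onnx or protobuf, and it only assumes the dict layout shown in the sample output.

```python
def shape_from_dict(value_info):
    """Return the shape of one MessageToDict(...) result as a list.

    Fixed dimensions become ints, symbolic dimensions keep their name (str),
    and dimensions that carry no information become None.
    """
    # Fall back to empty containers so that a missing key does not raise.
    tensor_type = value_info.get("type", {}).get("tensorType", {})
    dims = tensor_type.get("shape", {}).get("dim", [])

    shape = []
    for d in dims:
        if "dimValue" in d:
            shape.append(int(d["dimValue"]))  # serialized as a string, so convert
        elif "dimParam" in d:
            shape.append(d["dimParam"])       # symbolic dimension name
        else:
            shape.append(None)                # dimension with no value or name
    return shape


# The sample dict from the answer above.
example = {'name': '0', 'type': {'tensorType': {'elemType': 2, 'shape': {'dim': [
    {'dimValue': '4'}, {'dimValue': '3'}, {'dimValue': '384'}, {'dimValue': '640'}]}}}}
print(shape_from_dict(example))  # [4, 3, 384, 640]
```

    Note that this sketch returns an empty list both for a scalar and for an input with no shape information at all; if you need to tell these apart, check for the 'shape' key yourself.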
    

    If you only need the dim field, use the message object directly instead.

    model = onnx.load("path/to/model.onnx")
    for _input in model.graph.input:
        dim = _input.type.tensor_type.shape.dim
        input_shape = [MessageToDict(d).get("dimValue") for d in dim] # ['4', '3', '384', '640']
        # if you prefer the Python naming style, use the line below
        # input_shape = [MessageToDict(d, preserving_proto_field_name=True).get("dim_value") for d in dim]
    

    One line version:

    model = onnx.load("path/to/model.onnx")
    input_shapes = [[d.dim_value for d in _input.type.tensor_type.shape.dim] for _input in model.graph.input]
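
    Two caveats with the one-liner: dim_value reads as 0 for any dimension that is symbolic (dim_param is set instead) or not set at all, because protobuf returns the default for unset fields, and some models (typically older exports) also list their weights (graph.initializer) in graph.input. Here is a rough sketch that handles both. The helper names (dims_to_shape, real_inputs, input_shapes, load_input_shapes) are my own and not part of onnx, and treating dim_value == 0 as "unknown" is a simplifying assumption.

```python
def dims_to_shape(dims):
    """Turn the repeated `dim` field of a tensor shape into a plain list."""
    shape = []
    for d in dims:
        if d.dim_param:              # symbolic dimension, e.g. "batch"
            shape.append(d.dim_param)
        elif d.dim_value > 0:        # fixed dimension
            shape.append(d.dim_value)
        else:                        # neither set (assumption: 0 means unknown)
            shape.append(None)
    return shape


def real_inputs(graph):
    """Return the graph inputs that are not also listed as initializers."""
    initializer_names = {init.name for init in graph.initializer}
    return [i for i in graph.input if i.name not in initializer_names]


def input_shapes(model):
    """Map each real input name to its shape."""
    return {
        i.name: dims_to_shape(i.type.tensor_type.shape.dim)
        for i in real_inputs(model.graph)
    }


def load_input_shapes(path):
    """Load an ONNX file and return its input shapes."""
    import onnx  # imported lazily so the helpers above work without onnx installed
    return input_shapes(onnx.load(path))
```

    Usage would then be load_input_shapes("path/to/model.onnx"), which returns a dict keyed by input name.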
    

    Refs:

    https://github.com/googleapis/python-vision/issues/70

    AttributeError: 'google.protobuf.pyext._message.RepeatedCompositeCo' object has no attribute 'append'