
Data input format to Plugin


I have a TensorFlow model, shown in the picture, and I am trying to convert it to TensorRT.

The upsample node needs a plugin for the conversion to TensorRT.

In the TensorFlow implementation, say the input has shape 1x3x4x19 and is upsampled to a 1x12x14x19 tensor.

I implemented the same thing in a TensorRT plugin, and the process flow I assumed is as follows.

A 1x3x4x19 tensor in TensorFlow is

[[[...19channel data ...],[...19channel data ...],[...19channel data ...],[...19channel data ...]],
 [[...19channel data ...],[...19channel data ...],[...19channel data ...],[...19channel data ...]],
 [[...19channel data ...],[...19channel data ...],[...19channel data ...],[...19channel data ...]]]

which is first flattened into

[...19channel data ..., ...19channel data ..., ...19channel data ..., etc.,...19channel data ...]

The flattened data length is 228 (3 x 4 x 19).

Data with 19 channels is difficult to visualize, so 3-channel data is used as the example for the flattened data.

[[[1,2,3],[4,5,6],[7,8,9],[10,11,12]],
 [[1,2,3],[4,5,6],[7,8,9],[10,11,12]],
 [[1,2,3],[4,5,6],[7,8,9],[10,11,12]]]

The flattened array for the 3-channel data is

[1,2,3,4,5,6,7,8,9,10,11,12,1,2,3,4,5,6,7,8,9,10,11,12,1,2,3,4,5,6,7,8,9,10,11,12]
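The flattening described above can be sketched in plain Python. This is only an illustration of the NHWC ordering (channels vary fastest, then width, then height) with the batch dimension of 1 dropped; `flatten_nhwc` is a hypothetical helper and not part of TensorFlow or TensorRT.

```python
def flatten_nhwc(tensor_hwc):
    """Flatten a nested [H][W][C] list row by row, with channels varying fastest.

    Hypothetical helper for illustration only.
    """
    return [value for row in tensor_hwc for pixel in row for value in pixel]


# The 3-channel example from the question: H=3, W=4, C=3 (batch of 1 omitted).
pixel_row = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
example = [pixel_row, pixel_row, pixel_row]

flat_nhwc = flatten_nhwc(example)
print(flat_nhwc)
```

In this layout, element (h, w, c) sits at offset (h * W + w) * C + c, which is the ordering the CUDA code in the plugin was written to expect.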

That flattened data is the input to the upsample plugin. The CUDA code in my plugin expects the flattened data in the order shown above.

But the plugin output is weird and does not match TensorFlow's upsample output. I checked by feeding the Openpose1 data (shown in the picture) to TensorFlow's upsample operation.

Is this the correct data format for a plugin's input in a TensorRT engine? If not, how is the input data fed to the plugin?

The plugin output looks as if the input were flattened vertically, like

[1,1,1,2,2,2,3,3,3...etc.]

Solution

  • After some time, I solved the problem. The plugin input is in NCHW format. For the 3-channel example above, it looks like 1,4,7,10,1,4,7,10,1,4,7,10,2,5,8,11,2,5,8,11,2,5,8,11,3,6,9,12,3,6,9,12,3,6,9,12.

    So the plugin has to work on data in that layout, and reformat it back to NHWC if the output must match TensorFlow's NHWC format.
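The index arithmetic behind that reformatting can be sketched in plain Python. This is a minimal illustration, assuming a batch size of 1 and a contiguous flat buffer; `nchw_to_nhwc` is a hypothetical helper, not a TensorRT or TensorFlow function, and a real plugin would do the equivalent indexing inside its CUDA kernel.

```python
def nchw_to_nhwc(flat, channels, height, width):
    """Reorder a flat NCHW buffer (batch of 1) into NHWC order.

    Hypothetical helper for illustration only.
    """
    out = [0] * (channels * height * width)
    for c in range(channels):
        for h in range(height):
            for w in range(width):
                src = (c * height + h) * width + w    # offset in NCHW
                dst = (h * width + w) * channels + c  # offset in NHWC
                out[dst] = flat[src]
    return out


# The NCHW data the plugin receives for the 3-channel example (C=3, H=3, W=4).
plugin_input = [1, 4, 7, 10] * 3 + [2, 5, 8, 11] * 3 + [3, 6, 9, 12] * 3

restored = nchw_to_nhwc(plugin_input, channels=3, height=3, width=4)
print(restored)
```

Alternatively, the kernel can read the NCHW buffer directly using the `src` offset formula, which avoids a separate reordering pass.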