I'm trying to convert a ViT-B/32 Vision Transformer model from the UNICOM repository to TensorRT on a Jetson Orin Nano. The model's `VisionTransformer` class and its source code are here.
I use the following code to convert the model to ONNX:
```python
import torch
import onnx
import onnxruntime
from unicom.vision_transformer import build_model

if __name__ == '__main__':
    model_name = "ViT-B/32"
    model_name_fp16 = "FP16-ViT-B-32"
    onnx_model_path = f"{model_name_fp16}.onnx"

    model = build_model(model_name)
    model.eval()
    model = model.to('cuda')

    torch_input = torch.randn(1, 3, 224, 224).to('cuda')
    onnx_program = torch.onnx.dynamo_export(model, torch_input)
    onnx_program.save(onnx_model_path)

    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model_path)
```
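A quick way to see what `dynamo_export` actually wrote to the file is to list the top-level graph nodes that fall outside the standard ONNX domains, together with any model-local functions stored in the model. This is a minimal sketch that only assumes the `onnx` package; the helper names (`nonstandard_nodes`, `local_function_names`, `main`) are hypothetical, and the set of "standard" domains (`""`, `ai.onnx`, `ai.onnx.ml`) is an assumption about what a stock parser understands.

```python
# Hypothetical inspection helpers: they only read .graph.node and .functions,
# which onnx.ModelProto provides.
STANDARD_DOMAINS = {"", "ai.onnx", "ai.onnx.ml"}


def nonstandard_nodes(model):
    """Return (node name, domain, op_type) for top-level nodes outside the standard domains."""
    return [
        (node.name, node.domain, node.op_type)
        for node in model.graph.node
        if node.domain not in STANDARD_DOMAINS
    ]


def local_function_names(model):
    """Return 'domain::name' for every model-local function stored in the model."""
    return [f"{fn.domain}::{fn.name}" for fn in model.functions]


def main(path: str = "FP16-ViT-B-32.onnx") -> None:
    import onnx  # imported lazily so the helpers above work on any duck-typed model

    model = onnx.load(path)
    for name, domain, op_type in nonstandard_nodes(model):
        print(f"non-standard node: {name} ({domain}::{op_type})")
    for fn in local_function_names(model):
        print(f"local function: {fn}")
```

Calling `main("FP16-ViT-B-32.onnx")` on the exported file prints any such nodes. Judging by the trtexec log for this model, a node in the `pkg.unicom` domain would be expected to show up.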
I then use the following command line to convert the ONNX model to a TensorRT engine:
```bash
/usr/src/tensorrt/bin/trtexec --onnx=FP16-ViT-B-32.onnx --saveEngine=FP16-ViT-B-32.trt --workspace=1024 --fp16
```
This results in the following error:

```
--workspace flag has been deprecated by --memPoolSize flag.
=== Model Options ===
Format: ONNX
Model: /home/jetson/HPS/Models/FeatureExtractor/UNICOM/ONNX/FP16-ViT-B-32.onnx
Output:
=== Build Options ===
Max batch: explicit batch
Memory Pools: workspace: 1024 MiB, dlaSRAM: default, dlaLocalDRAM: default, dlaGlobalDRAM: default
minTiming: 1
avgTiming: 8
Precision: FP32+FP16
LayerPrecisions:
Layer Device Types:
Calibration:
Refit: Disabled
Version Compatible: Disabled
ONNX Native InstanceNorm: Disabled
TensorRT runtime: full
Lean DLL Path:
Tempfile Controls: { in_memory: allow, temporary: allow }
Exclude Lean Runtime: Disabled
Sparsity: Disabled
Safe mode: Disabled
Build DLA standalone loadable: Disabled
Allow GPU fallback for DLA: Disabled
DirectIO mode: Disabled
Restricted mode: Disabled
Skip inference: Disabled
Save engine: /home/jetson/HPS/Models/FeatureExtractor/UNICOM/ONNX/FP16-ViT-B-32.trt
Load engine:
Profiling verbosity: 0
Tactic sources: Using default tactic sources
timingCacheMode: local
timingCacheFile:
Heuristic: Disabled
Preview Features: Use default preview flags.
MaxAuxStreams: -1
BuilderOptimizationLevel: -1
Input(s)s format: fp32:CHW
Output(s)s format: fp32:CHW
Input build shapes: model
Input calibration shapes: model
=== System Options ===
Device: 0
DLACore:
Plugins:
setPluginsToSerialize:
dynamicPlugins:
ignoreParsedPluginLibs: 0
=== Inference Options ===
Batch: Explicit
Input inference shapes: model
Iterations: 10
Duration: 3s (+ 200ms warm up)
Sleep time: 0ms
Idle time: 0ms
Inference Streams: 1
ExposeDMA: Disabled
Data transfers: Enabled
Spin-wait: Disabled
Multithreading: Disabled
CUDA Graph: Disabled
Separate profiling: Disabled
Time Deserialize: Disabled
Time Refit: Disabled
NVTX verbosity: 0
Persistent Cache Ratio: 0
Inputs:
=== Reporting Options ===
Verbose: Disabled
Averages: 10 inferences
Percentiles: 90,95,99
Dump refittable layers:Disabled
Dump output: Disabled
Profile: Disabled
Export timing to JSON file:
Export output to JSON file:
Export profile to JSON file:
=== Device Information ===
Selected Device: Orin
Compute Capability: 8.7
SMs: 8
Device Global Memory: 7620 MiB
Shared Memory per SM: 164 KiB
Memory Bus Width: 128 bits (ECC disabled)
Application Compute Clock Rate: 0.624 GHz
Application Memory Clock Rate: 0.624 GHz
Note: The application clock rates do not reflect the actual clock rates that the GPU is currently running at.
TensorRT version: 8.6.2
Loading standard plugins
[MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 33, GPU 4508 (MiB)
[MemUsageChange] Init builder kernel library: CPU +1154, GPU +1351, now: CPU 1223, GPU 5866 (MiB)
Start parsing network model.
----------------------------------------------------------------
Input filename: /home/jetson/HPS/Models/FeatureExtractor/UNICOM/ONNX/FP16-ViT-B-32.onnx
ONNX IR version: 0.0.8
Opset version: 1
Producer name: pytorch
Producer version: 2.3.0
Domain:
Model version: 0
Doc string:
----------------------------------------------------------------
No importer registered for op: unicom_vision_transformer_PatchEmbedding_patch_embed_1. Attempting to import as plugin.
Searching for plugin: unicom_vision_transformer_PatchEmbedding_patch_embed_1, plugin_version: 1, plugin_namespace:
3: getPluginCreator could not find plugin: unicom_vision_transformer_PatchEmbedding_patch_embed_1 version: 1
ModelImporter.cpp:768: While parsing node number 0 [unicom_vision_transformer_PatchEmbedding_patch_embed_1 -> "patch_embed_1"]:
ModelImporter.cpp:769: --- Begin node ---
ModelImporter.cpp:770: input: "l_x_"
input: "patch_embed.proj.weight"
input: "patch_embed.proj.bias"
output: "patch_embed_1"
name: "unicom_vision_transformer_PatchEmbedding_patch_embed_1_1"
op_type: "unicom_vision_transformer_PatchEmbedding_patch_embed_1"
doc_string: ""
domain: "pkg.unicom"
[E] ModelImporter.cpp:771: --- End node ---
[E] ModelImporter.cpp:773: ERROR: builtin_op_importers.cpp:5403 In function importFallbackPluginImporter:
[8] Assertion failed: creator && "Plugin not found, are the plugin name, version, and namespace correct?"
[E] Failed to parse onnx file
[I] Finished parsing network model. Parse time: 4.99544
[E] Parsing model failed
[E] Failed to create engine from model or file.
[E] Engine set up failed
```

With verbose logging enabled, the same build produces:

```
--workspace flag has been deprecated by --memPoolSize flag.
=== Model Options ===
Format: ONNX
Model: /home/jetson/HPS/Models/FeatureExtractor/UNICOM/ONNX/FP16-ViT-B-32.onnx
Output:
=== Build Options ===
Max batch: explicit batch
Memory Pools: workspace: 1024 MiB, dlaSRAM: default, dlaLocalDRAM: default, dlaGlobalDRAM: default
minTiming: 1
avgTiming: 8
Precision: FP32+FP16
LayerPrecisions:
Layer Device Types:
Calibration:
Refit: Disabled
Version Compatible: Disabled
ONNX Native InstanceNorm: Disabled
TensorRT runtime: full
Lean DLL Path:
Tempfile Controls: { in_memory: allow, temporary: allow }
Exclude Lean Runtime: Disabled
Sparsity: Disabled
Safe mode: Disabled
Build DLA standalone loadable: Disabled
Allow GPU fallback for DLA: Disabled
DirectIO mode: Disabled
Restricted mode: Disabled
Skip inference: Disabled
Save engine: /home/jetson/HPS/Models/FeatureExtractor/UNICOM/ONNX/FP16-ViT-B-32.trt
Load engine:
Profiling verbosity: 0
Tactic sources: Using default tactic sources
timingCacheMode: local
timingCacheFile:
Heuristic: Disabled
Preview Features: Use default preview flags.
MaxAuxStreams: -1
BuilderOptimizationLevel: -1
Input(s)s format: fp32:CHW
Output(s)s format: fp32:CHW
Input build shapes: model
Input calibration shapes: model
=== System Options ===
Device: 0
DLACore:
Plugins:
setPluginsToSerialize:
dynamicPlugins:
ignoreParsedPluginLibs: 0
=== Inference Options ===
Batch: Explicit
Input inference shapes: model
Iterations: 10
Duration: 3s (+ 200ms warm up)
Sleep time: 0ms
Idle time: 0ms
Inference Streams: 1
ExposeDMA: Disabled
Data transfers: Enabled
Spin-wait: Disabled
Multithreading: Disabled
CUDA Graph: Disabled
Separate profiling: Disabled
Time Deserialize: Disabled
Time Refit: Disabled
NVTX verbosity: 0
Persistent Cache Ratio: 0
Inputs:
=== Reporting Options ===
Verbose: Enabled
Averages: 10 inferences
Percentiles: 90,95,99
Dump refittable layers:Disabled
Dump output: Disabled
Profile: Disabled
Export timing to JSON file:
Export output to JSON file:
Export profile to JSON file:
=== Device Information ===
Selected Device: Orin
Compute Capability: 8.7
SMs: 8
Device Global Memory: 7620 MiB
Shared Memory per SM: 164 KiB
Memory Bus Width: 128 bits (ECC disabled)
Application Compute Clock Rate: 0.624 GHz
Application Memory Clock Rate: 0.624 GHz
Note: The application clock rates do not reflect the actual clock rates that the GPU is currently running at.
TensorRT version: 8.6.2
Loading standard plugins
Registered plugin - ::BatchedNMSDynamic_TRT version 1
Registered plugin - ::BatchedNMS_TRT version 1
Registered plugin - ::BatchTilePlugin_TRT version 1
Registered plugin - ::Clip_TRT version 1
Registered plugin - ::CoordConvAC version 1
Registered plugin - ::CropAndResizeDynamic version 1
Registered plugin - ::CropAndResize version 1
Registered plugin - ::DecodeBbox3DPlugin version 1
Registered plugin - ::DetectionLayer_TRT version 1
Registered plugin - ::EfficientNMS_Explicit_TF_TRT version 1
Registered plugin - ::EfficientNMS_Implicit_TF_TRT version 1
Registered plugin - ::EfficientNMS_ONNX_TRT version 1
Registered plugin - ::EfficientNMS_TRT version 1
Registered plugin - ::FlattenConcat_TRT version 1
Registered plugin - ::GenerateDetection_TRT version 1
Registered plugin - ::GridAnchor_TRT version 1
Registered plugin - ::GridAnchorRect_TRT version 1
Registered plugin - ::InstanceNormalization_TRT version 1
Registered plugin - ::InstanceNormalization_TRT version 2
Registered plugin - ::LReLU_TRT version 1
Registered plugin - ::ModulatedDeformConv2d version 1
Registered plugin - ::MultilevelCropAndResize_TRT version 1
Registered plugin - ::MultilevelProposeROI_TRT version 1
Registered plugin - ::MultiscaleDeformableAttnPlugin_TRT version 1
Registered plugin - ::NMSDynamic_TRT version 1
Registered plugin - ::NMS_TRT version 1
Registered plugin - ::Normalize_TRT version 1
Registered plugin - ::PillarScatterPlugin version 1
Registered plugin - ::PriorBox_TRT version 1
Registered plugin - ::ProposalDynamic version 1
Registered plugin - ::ProposalLayer_TRT version 1
Registered plugin - ::Proposal version 1
Registered plugin - ::PyramidROIAlign_TRT version 1
Registered plugin - ::Region_TRT version 1
Registered plugin - ::Reorg_TRT version 1
Registered plugin - ::ResizeNearest_TRT version 1
Registered plugin - ::ROIAlign_TRT version 1
Registered plugin - ::RPROI_TRT version 1
Registered plugin - ::ScatterND version 1
Registered plugin - ::SpecialSlice_TRT version 1
Registered plugin - ::Split version 1
Registered plugin - ::VoxelGeneratorPlugin version 1
[MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 33, GPU 5167 (MiB)
Trying to load shared library libnvinfer_builder_resource.so.8.6.2
Loaded shared library libnvinfer_builder_resource.so.8.6.2
[MemUsageChange] Init builder kernel library: CPU +1154, GPU +995, now: CPU 1223, GPU 6203 (MiB)
CUDA lazy loading is enabled.
Start parsing network model.
----------------------------------------------------------------
Input filename: /home/jetson/HPS/Models/FeatureExtractor/UNICOM/ONNX/FP16-ViT-B-32.onnx
ONNX IR version: 0.0.8
Opset version: 1
Producer name: pytorch
Producer version: 2.3.0
Domain:
Model version: 0
Doc string:
----------------------------------------------------------------
Plugin already registered - ::BatchedNMSDynamic_TRT version 1
Plugin already registered - ::BatchedNMS_TRT version 1
Plugin already registered - ::BatchTilePlugin_TRT version 1
Plugin already registered - ::Clip_TRT version 1
Plugin already registered - ::CoordConvAC version 1
Plugin already registered - ::CropAndResizeDynamic version 1
Plugin already registered - ::CropAndResize version 1
Plugin already registered - ::DecodeBbox3DPlugin version 1
Plugin already registered - ::DetectionLayer_TRT version 1
Plugin already registered - ::EfficientNMS_Explicit_TF_TRT version 1
Plugin already registered - ::EfficientNMS_Implicit_TF_TRT version 1
Plugin already registered - ::EfficientNMS_ONNX_TRT version 1
Plugin already registered - ::EfficientNMS_TRT version 1
Plugin already registered - ::FlattenConcat_TRT version 1
Plugin already registered - ::GenerateDetection_TRT version 1
Plugin already registered - ::GridAnchor_TRT version 1
Plugin already registered - ::GridAnchorRect_TRT version 1
Plugin already registered - ::InstanceNormalization_TRT version 1
Plugin already registered - ::InstanceNormalization_TRT version 2
Plugin already registered - ::LReLU_TRT version 1
Plugin already registered - ::ModulatedDeformConv2d version 1
Plugin already registered - ::MultilevelCropAndResize_TRT version 1
Plugin already registered - ::MultilevelProposeROI_TRT version 1
Plugin already registered - ::MultiscaleDeformableAttnPlugin_TRT version 1
Plugin already registered - ::NMSDynamic_TRT version 1
Plugin already registered - ::NMS_TRT version 1
Plugin already registered - ::Normalize_TRT version 1
Plugin already registered - ::PillarScatterPlugin version 1
Plugin already registered - ::PriorBox_TRT version 1
Plugin already registered - ::ProposalDynamic version 1
Plugin already registered - ::ProposalLayer_TRT version 1
Plugin already registered - ::Proposal version 1
Plugin already registered - ::PyramidROIAlign_TRT version 1
Plugin already registered - ::Region_TRT version 1
Plugin already registered - ::Reorg_TRT version 1
Plugin already registered - ::ResizeNearest_TRT version 1
Plugin already registered - ::ROIAlign_TRT version 1
Plugin already registered - ::RPROI_TRT version 1
Plugin already registered - ::ScatterND version 1
Plugin already registered - ::SpecialSlice_TRT version 1
Plugin already registered - ::Split version 1
Plugin already registered - ::VoxelGeneratorPlugin version 1
Adding network input: l_x_ with dtype: float32, dimensions: (1, 3, 224, 224)
Registering tensor: l_x_ for ONNX tensor: l_x_
Importing : patch_embed.proj.weight
Importing : patch_embed.proj.bias
Importing : pos_embed
Importing : blocks.0.norm1.weight
Importing : blocks.0.norm1.bias
Importing : blocks.0.attn.qkv.weight
Importing : blocks.0.attn.proj.weight
Importing : blocks.0.attn.proj.bias
Importing : blocks.0.norm2.weight
Importing : blocks.0.norm2.bias
Importing : blocks.0.mlp.fc1.weight
Importing : blocks.0.mlp.fc1.bias
Importing : blocks.0.mlp.fc2.weight
Importing : blocks.0.mlp.fc2.bias
Importing : blocks.1.norm1.weight
Importing : blocks.1.norm1.bias
Importing : blocks.1.attn.qkv.weight
Importing : blocks.1.attn.proj.weight
Importing : blocks.1.attn.proj.bias
Importing : blocks.1.norm2.weight
Importing : blocks.1.norm2.bias
Importing : blocks.1.mlp.fc1.weight
Importing : blocks.1.mlp.fc1.bias
Importing : blocks.1.mlp.fc2.weight
Importing : blocks.1.mlp.fc2.bias
Importing : blocks.2.norm1.weight
Importing : blocks.2.norm1.bias
Importing : blocks.2.attn.qkv.weight
Importing : blocks.2.attn.proj.weight
Importing : blocks.2.attn.proj.bias
Importing : blocks.2.norm2.weight
Importing : blocks.2.norm2.bias
Importing : blocks.2.mlp.fc1.weight
Importing : blocks.2.mlp.fc1.bias
Importing : blocks.2.mlp.fc2.weight
Importing : blocks.2.mlp.fc2.bias
Importing : blocks.3.norm1.weight
Importing : blocks.3.norm1.bias
Importing : blocks.3.attn.qkv.weight
Importing : blocks.3.attn.proj.weight
Importing : blocks.3.attn.proj.bias
Importing : blocks.3.norm2.weight
Importing : blocks.3.norm2.bias
Importing : blocks.3.mlp.fc1.weight
Importing : blocks.3.mlp.fc1.bias
Importing : blocks.3.mlp.fc2.weight
Importing : blocks.3.mlp.fc2.bias
Importing : blocks.4.norm1.weight
Importing : blocks.4.norm1.bias
Importing : blocks.4.attn.qkv.weight
Importing : blocks.4.attn.proj.weight
Importing : blocks.4.attn.proj.bias
Importing : blocks.4.norm2.weight
Importing : blocks.4.norm2.bias
Importing : blocks.4.mlp.fc1.weight
Importing : blocks.4.mlp.fc1.bias
Importing : blocks.4.mlp.fc2.weight
Importing : blocks.4.mlp.fc2.bias
Importing : blocks.5.norm1.weight
Importing : blocks.5.norm1.bias
Importing : blocks.5.attn.qkv.weight
Importing : blocks.5.attn.proj.weight
Importing : blocks.5.attn.proj.bias
Importing : blocks.5.norm2.weight
Importing : blocks.5.norm2.bias
Importing : blocks.5.mlp.fc1.weight
Importing : blocks.5.mlp.fc1.bias
Importing : blocks.5.mlp.fc2.weight
Importing : blocks.5.mlp.fc2.bias
Importing : blocks.6.norm1.weight
Importing : blocks.6.norm1.bias
Importing : blocks.6.attn.qkv.weight
Importing : blocks.6.attn.proj.weight
Importing : blocks.6.attn.proj.bias
Importing : blocks.6.norm2.weight
Importing : blocks.6.norm2.bias
Importing : blocks.6.mlp.fc1.weight
Importing : blocks.6.mlp.fc1.bias
Importing : blocks.6.mlp.fc2.weight
Importing : blocks.6.mlp.fc2.bias
Importing : blocks.7.norm1.weight
Importing : blocks.7.norm1.bias
Importing : blocks.7.attn.qkv.weight
Importing : blocks.7.attn.proj.weight
Importing : blocks.7.attn.proj.bias
Importing : blocks.7.norm2.weight
Importing : blocks.7.norm2.bias
Importing : blocks.7.mlp.fc1.weight
Importing : blocks.7.mlp.fc1.bias
Importing : blocks.7.mlp.fc2.weight
Importing : blocks.7.mlp.fc2.bias
Importing : blocks.8.norm1.weight
Importing : blocks.8.norm1.bias
Importing : blocks.8.attn.qkv.weight
Importing : blocks.8.attn.proj.weight
Importing : blocks.8.attn.proj.bias
Importing : blocks.8.norm2.weight
Importing : blocks.8.norm2.bias
Importing : blocks.8.mlp.fc1.weight
Importing : blocks.8.mlp.fc1.bias
Importing : blocks.8.mlp.fc2.weight
Importing : blocks.8.mlp.fc2.bias
Importing : blocks.9.norm1.weight
Importing : blocks.9.norm1.bias
Importing : blocks.9.attn.qkv.weight
Importing : blocks.9.attn.proj.weight
Importing : blocks.9.attn.proj.bias
Importing : blocks.9.norm2.weight
Importing : blocks.9.norm2.bias
Importing : blocks.9.mlp.fc1.weight
Importing : blocks.9.mlp.fc1.bias
Importing : blocks.9.mlp.fc2.weight
Importing : blocks.9.mlp.fc2.bias
Importing : blocks.10.norm1.weight
Importing : blocks.10.norm1.bias
Importing : blocks.10.attn.qkv.weight
Importing : blocks.10.attn.proj.weight
Importing : blocks.10.attn.proj.bias
Importing : blocks.10.norm2.weight
Importing : blocks.10.norm2.bias
Importing : blocks.10.mlp.fc1.weight
Importing : blocks.10.mlp.fc1.bias
Importing : blocks.10.mlp.fc2.weight
Importing : blocks.10.mlp.fc2.bias
Importing : blocks.11.norm1.weight
Importing : blocks.11.norm1.bias
Importing : blocks.11.attn.qkv.weight
Importing : blocks.11.attn.proj.weight
Importing : blocks.11.attn.proj.bias
Importing : blocks.11.norm2.weight
Importing : blocks.11.norm2.bias
Importing : blocks.11.mlp.fc1.weight
Importing : blocks.11.mlp.fc1.bias
Importing : blocks.11.mlp.fc2.weight
Importing : blocks.11.mlp.fc2.bias
Importing : norm.weight
Importing : norm.bias
Importing : feature.0.weight
Importing : feature.1.weight
Importing : feature.1.bias
Importing : feature.1.running_mean
Importing : feature.1.running_var
Importing : feature.2.weight
Importing : feature.3.weight
Importing : feature.3.bias
Importing : feature.3.running_mean
Importing : feature.3.running_var
Parsing node: unicom_vision_transformer_PatchEmbedding_patch_embed_1_1 [unicom_vision_transformer_PatchEmbedding_patch_embed_1]
Searching for input: l_x_
Searching for input: patch_embed.proj.weight
Searching for input: patch_embed.proj.bias
unicom_vision_transformer_PatchEmbedding_patch_embed_1_1 [unicom_vision_transformer_PatchEmbedding_patch_embed_1] inputs: [l_x_ -> (1, 3, 224, 224)[FLOAT]], [patch_embed.proj.weight -> (768, 3, 32, 32)[FLOAT]], [patch_embed.proj.bias -> (768)[FLOAT]],
No importer registered for op: unicom_vision_transformer_PatchEmbedding_patch_embed_1. Attempting to import as plugin.
Searching for plugin: unicom_vision_transformer_PatchEmbedding_patch_embed_1, plugin_version: 1, plugin_namespace:
Local registry did not find unicom_vision_transformer_PatchEmbedding_patch_embed_1 creator. Will try parent registry if enabled.
Global registry did not find unicom_vision_transformer_PatchEmbedding_patch_embed_1 creator. Will try parent registry if enabled.
3: getPluginCreator could not find plugin: unicom_vision_transformer_PatchEmbedding_patch_embed_1 version: 1
ModelImporter.cpp:768: While parsing node number 0 [unicom_vision_transformer_PatchEmbedding_patch_embed_1 -> "patch_embed_1"]:
ModelImporter.cpp:769: --- Begin node ---
ModelImporter.cpp:770: input: "l_x_"
input: "patch_embed.proj.weight"
input: "patch_embed.proj.bias"
output: "patch_embed_1"
name: "unicom_vision_transformer_PatchEmbedding_patch_embed_1_1"
op_type: "unicom_vision_transformer_PatchEmbedding_patch_embed_1"
doc_string: ""
domain: "pkg.unicom"
[E] ModelImporter.cpp:771: --- End node ---
[E] ModelImporter.cpp:773: ERROR: builtin_op_importers.cpp:5403 In function importFallbackPluginImporter:
[8] Assertion failed: creator && "Plugin not found, are the plugin name, version, and namespace correct?"
[E] Failed to parse onnx file
[I] Finished parsing network model. Parse time: 13.1481
[E] Parsing model failed
[E] Failed to create engine from model or file.
[E] Engine set up failed
```
The problem seems to arise from the `PatchEmbedding` class here, yet the model doesn't appear to use any unusual methods or layers that TensorRT can't convert. Here's the class's source code:
```python
from torch import nn


class PatchEmbedding(nn.Module):
    def __init__(self, input_size=224, patch_size=32, in_channels: int = 3, dim: int = 768):
        super().__init__()
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        H = input_size[0] // patch_size[0]
        W = input_size[1] // patch_size[1]
        self.num_patches = H * W
        # Non-overlapping patches: kernel_size == stride == patch_size
        self.proj = nn.Conv2d(
            in_channels, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, dim, H/P, W/P) -> (B, dim, N) -> (B, N, dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
```
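The `PatchEmbedding` forward pass really is just a strided convolution followed by a reshape. A minimal NumPy sketch (not UNICOM code; both function names are hypothetical) checks that a convolution with `kernel_size == stride == P`, followed by `flatten(2).transpose(1, 2)`, gives the same result as cutting the image into non-overlapping P x P patches and applying one linear projection. Nothing in it should need a TensorRT plugin.

```python
import numpy as np


def patch_embed_conv(x, weight, bias):
    """Reference: Conv2d with kernel_size == stride == P, then flatten(2).transpose(1, 2)."""
    B, C, H, W = x.shape
    D, _, P, _ = weight.shape
    gh, gw = H // P, W // P
    out = np.empty((B, D, gh, gw), dtype=x.dtype)
    for i in range(gh):
        for j in range(gw):
            patch = x[:, :, i * P:(i + 1) * P, j * P:(j + 1) * P]  # (B, C, P, P)
            # Each output channel sums over C, P, P for this patch.
            out[:, :, i, j] = np.einsum("bcpq,dcpq->bd", patch, weight) + bias
    # flatten(2) -> (B, D, gh*gw); transpose(1, 2) -> (B, gh*gw, D)
    return out.reshape(B, D, gh * gw).transpose(0, 2, 1)


def patch_embed_matmul(x, weight, bias):
    """Same result via non-overlapping patch extraction and a single matrix multiply."""
    B, C, H, W = x.shape
    D, _, P, _ = weight.shape
    assert H % P == 0 and W % P == 0, "this sketch assumes the image divides into whole patches"
    gh, gw = H // P, W // P
    # (B, C, gh, P, gw, P) -> (B, gh, gw, C, P, P) -> (B, gh*gw, C*P*P)
    patches = x.reshape(B, C, gh, P, gw, P).transpose(0, 2, 4, 1, 3, 5)
    patches = patches.reshape(B, gh * gw, C * P * P)
    # Flatten each (C, P, P) kernel into a row so the conv becomes a linear layer.
    return patches @ weight.reshape(D, C * P * P).T + bias
```

For the UNICOM model, `x` would be `(1, 3, 224, 224)` and `weight` `(768, 3, 32, 32)` (shapes taken from the verbose log), giving `(1, 49, 768)` patch embeddings.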
What should I do to make the model convertible to TensorRT?
## Environment
**TensorRT Version**: tensorrt_version_8_6_2_3
**GPU Type**: Jetson Orin Nano
**Nvidia Driver Version**:
**CUDA Version**: 12.2
**CUDNN Version**: 8.9.4.25-1+cuda12.2
**Operating System + Version**: Jetpack 6.0
**Python Version (if applicable)**: 3.10
**PyTorch Version (if applicable)**: 2.3.0
**ONNX Version (if applicable)**: 1.16.1
**onnxruntime-gpu Version (if applicable)**: 1.17.0
**onnxscript Version (if applicable)**: 0.1.0.dev20240721
UPDATE:
Running the code with `torch.onnx.export` instead of `torch.onnx.dynamo_export` gives this error:
```
/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension:
  warn(f"Failed to load image Python extension: {e}")
/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:36: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
  return fn(*args, **kwargs)
Traceback (most recent call last):
  File "/home/jetson/HPS/Scripts_Utilities/ONNX/HPS_ExportModelToONNX.py", line 31, in <module>
    torch.onnx.export(model, torch_input,onnx_model_path)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/onnx/utils.py", line 1612, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/onnx/utils.py", line 1134, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/onnx/utils.py", line 1010, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/onnx/utils.py", line 914, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/jit/_trace.py", line 1315, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/jit/_trace.py", line 141, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/jit/_trace.py", line 132, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/unicom/vision_transformer.py", line 57, in forward
    x = self.forward_features(x)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/unicom/vision_transformer.py", line 52, in forward_features
    x = func(x)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/unicom/vision_transformer.py", line 122, in forward
    return checkpoint(self.forward_impl, x)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 403, in _fn
    return fn(*args, **kwargs)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/autograd/function.py", line 571, in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
RuntimeError: _Map_base::at
```
The fact that `trtexec` is failing on `PatchEmbedding` is definitely weird, considering the entire module is literally just a convolution. I have absolutely no idea what is going wrong here.
Regarding your exception when using the legacy TorchScript exporter (`torch.onnx.export`): it fails here because parts of the transformer are wrapped in `torch.utils.checkpoint`, which apparently does not support tracing. If we disable checkpointing by initializing `VisionTransformer` with `using_checkpoint=False`, the model exports without any errors.
```python
import torch
from unicom.vision_transformer import VisionTransformer

model_name_fp16 = "FP16-ViT-B-32"
onnx_model_path = f"{model_name_fp16}.onnx"

model = VisionTransformer(
    input_size=224,
    patch_size=32,
    in_channels=3,
    dim=768,
    embedding_size=512,
    depth=12,
    num_heads=12,
    drop_path_rate=0.1,
    using_checkpoint=False,  # default value of True breaks torch.onnx.export
)
model.eval()
model = model.to('cuda')

torch_input = torch.randn(1, 3, 224, 224).to('cuda')

# Using the TorchScript exporter instead of TorchDynamo.
# torch.onnx.export writes the file itself and returns None, so nothing is assigned.
torch.onnx.export(model, torch_input, onnx_model_path)
```
And the resulting ONNX model happily compiles into a TensorRT engine:

```
$ trtexec --onnx=FP16-ViT-B-32.onnx --fp16
...
[I] Engine built in 24.711 sec.
&&&& PASSED TensorRT.trtexec [TensorRT v100200]
```
The above worked using `torch==2.2.0` and `TensorRT==10.2`.
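Activation checkpointing only saves memory during backpropagation, so an inference or export path does not need it. The hypothetical sketch below (the `Block` class is not UNICOM's code) calls `torch.utils.checkpoint.checkpoint` only while training with autograd enabled. In eval mode, which `torch.onnx.export` uses by default, it runs the plain forward pass and can be traced. This is a similar effect to passing `using_checkpoint=False` as in the export script above.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Hypothetical transformer-style MLP block with an optional checkpoint."""

    def __init__(self, dim: int, using_checkpoint: bool = True):
        super().__init__()
        self.using_checkpoint = using_checkpoint
        self.mlp = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward_impl(self, x):
        return x + self.mlp(x)

    def forward(self, x):
        # Only checkpoint when it can actually save memory: training with autograd on.
        if self.using_checkpoint and self.training and torch.is_grad_enabled():
            return checkpoint(self.forward_impl, x, use_reentrant=False)
        # Eval / no-grad path: a plain call that tracing and ONNX export can follow.
        return self.forward_impl(x)
```

Guarding on `self.training` keeps the memory savings for fine-tuning while leaving the exported graph free of checkpoint calls.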