I trained a neural network for object detection using YOLOv5 and exported it to the ONNX format, as I need to use it in a Java application. However, I am encountering issues with incorrect class IDs and bounding boxes in my Java code, while the Python implementation works as expected.
Issues Encountered:
Incorrect Class IDs: All of the detections in the Java code are returning class ID 0, which is not expected. In the Python code, I receive valid class IDs corresponding to the detected objects.
Bounding Box Offsets: The bounding boxes appear to have a symmetrical offset: they are shifted too far in the positive x direction and too far down on the y-axis. As a result, the detected objects are positioned inaccurately in the output image.
Am I interpreting the model output wrongly?
(NMS is not yet implemented in the Java code.)
For reference, here is some information about the model:
Model Properties:
Format: ONNX v8
Producer: PyTorch 2.6.0
Version: 0
Imports: ai.onnx v17
Graph: main_graph
Metadata:
Stride: 32
Class Names: {0: '1', 1: '2', 2: '3', 3: '4', 4: '5', 5: '6', 6: '7', 7: '8', 8: '9', 9: '10'}
Inputs:
Name: images
Tensor: float32[1,3,640,640] (that means 1 image, 3 RGB channels, 640x640 pixels)
Outputs:
Name: output0
Tensor: float32[1,25200,15] (not fully sure what this means)
Now to the code (the same test image is used in both cases):
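For orientation, this shape is consistent with the default YOLOv5 detection head, assuming three output scales with strides 8, 16 and 32 (the Stride: 32 metadata entry presumably being the largest of them) and three anchors per grid cell. Each of the 25200 rows is one candidate box, and each row holds 15 values: 4 box values, 1 objectness score and one score for each of the 10 classes. A small Java sketch of the arithmetic (YoloOutputShape is a hypothetical helper, not part of any library):

```java
// Sketch: where a [1, 25200, 15] YOLOv5 output shape comes from.
// Assumptions: square 640x640 input, detection heads with strides 8, 16 and 32,
// three anchors per grid cell, and 10 classes as in the model metadata.
class YoloOutputShape {
    // Number of candidate boxes = sum over heads of anchors * gridW * gridH.
    static int numRows(int inputSize, int[] strides, int anchorsPerCell) {
        int rows = 0;
        for (int s : strides) {
            int grid = inputSize / s; // 80, 40 and 20 cells per side for a 640 input
            rows += anchorsPerCell * grid * grid;
        }
        return rows;
    }

    // Values per row: x, y, w, h, objectness, then one confidence per class.
    static int rowLength(int numClasses) {
        return 4 + 1 + numClasses;
    }

    public static void main(String[] args) {
        int rows = numRows(640, new int[] {8, 16, 32}, 3);
        System.out.println(rows + " x " + rowLength(10)); // prints "25200 x 15"
    }
}
```

In other words, 3 * (80*80 + 40*40 + 20*20) = 3 * 8400 = 25200 rows, and 4 + 1 + 10 = 15 columns.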
This is the working python code that gives the expected results:
# !pip install torch torchvision pillow
import torch
from PIL import Image
# Path to the YOLOv5 repository
path_to_yolo_library = '/content/yolov5'
onnx_path = '/content/best.onnx'
image_path = '/content/img101.png'
# Import the YOLOv5 model from the local path
model = torch.hub.load(path_to_yolo_library, 'custom', path=onnx_path, source='local')
# Load and preprocess the image
img = Image.open(image_path) # Load image as a PIL image
img = img.resize((640, 640)) # Resize the image to fit YOLO input size (640x640)
# Inference (includes NMS)
results = model(img, size=640) # Inference with NMS
# Results
results.print() # Print the results (detections, classes, confidence)
results.show() # Show the image with bounding boxes
results.save() # Save the result images
# Data: Print the bounding boxes, confidence scores, and class ids
print('\n', results.xyxy[0]) # Print predictions in the format (x1, y1, x2, y2, confidence, class)
Output: (includes correct visualization)
image 1/1: 640x640 2 1s, 2 2s, 1 3, 2 4s, 1 5, 1 6, 1 7
Speed: 16.7ms pre-process, 407.8ms inference, 6.0ms NMS per image at shape (1, 3, 640, 640)
tensor([[4.49145e+01, 1.94186e+02, 1.14293e+02, 3.11326e+02, 8.03208e-01, 0.00000e+00],
[4.44819e+01, 3.47444e+02, 1.18352e+02, 4.73753e+02, 7.96138e-01, 1.00000e+00],
[3.68868e+02, 2.70193e+01, 4.38986e+02, 1.55611e+02, 7.92952e-01, 1.00000e+00],
[4.62871e+01, 3.24609e+01, 1.15780e+02, 1.50192e+02, 7.83159e-01, 0.00000e+00],
[3.47603e+02, 4.95154e+02, 4.30347e+02, 6.35069e+02, 7.63681e-01, 4.00000e+00],....
Since this worked well and accurately, I tried translating it into Java:
import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.image.BufferedImage;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.Collections;
import java.util.Map;
import javax.imageio.ImageIO;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
public class YOLOv5ONNXJava {
    public static void main(String[] args) {
        try {
            // Load ONNX model
            String modelPath = "...best.onnx";
            OrtEnvironment env = OrtEnvironment.getEnvironment();
            OrtSession session = env.createSession(modelPath, new OrtSession.SessionOptions());
            // Load Image
            BufferedImage image = ImageIO.read(new File("...img101.png"));
            int origWidth = image.getWidth();
            int origHeight = image.getHeight();
            int inputSize = 640;
            BufferedImage resizedImage = resizeImage(image, inputSize, inputSize);
            // Convert Image to Tensor
            float[] inputTensor = preprocessImage(resizedImage, inputSize);
            long[] shape = {1, 3, inputSize, inputSize}; // Batch size 1, RGB channels
            OnnxTensor tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputTensor), shape);
            // Run inference
            Map<String, OnnxTensor> inputMap = Collections.singletonMap(session.getInputNames().iterator().next(), tensor);
            OrtSession.Result result = session.run(inputMap);
            // Process Output
            float[][][] outputData = (float[][][]) result.get(0).getValue();
            float[][] detections = outputData[0]; // Extract first batch
            // Post-process detections and draw rectangles
            postProcess(detections, origWidth, origHeight, image);
            // Cleanup
            session.close();
            tensor.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // Resize image function
    public static BufferedImage resizeImage(BufferedImage originalImage, int width, int height) {
        BufferedImage resizedImage = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        resizedImage.getGraphics().drawImage(originalImage, 0, 0, width, height, null);
        return resizedImage;
    }

    // Preprocess image into tensor format
    public static float[] preprocessImage(BufferedImage image, int inputSize) {
        float[] tensor = new float[3 * inputSize * inputSize]; // RGB channels
        int[] rgbArray = image.getRGB(0, 0, inputSize, inputSize, null, 0, inputSize);
        for (int i = 0; i < inputSize * inputSize; i++) {
            int pixel = rgbArray[i];
            tensor[i] = ((pixel >> 16) & 0xFF) / 255.0f; // Red
            tensor[i + inputSize * inputSize] = ((pixel >> 8) & 0xFF) / 255.0f; // Green
            tensor[i + 2 * inputSize * inputSize] = (pixel & 0xFF) / 255.0f; // Blue
        }
        return tensor;
    }

    // Updated postProcess method to draw rectangles and labels
    public static void postProcess(float[][] detections, int origWidth, int origHeight, BufferedImage originalImage) {
        System.out.println("\nDetections:");
        int inputSize = 640; // YOLOv5 default input size
        float gain = Math.min((float) inputSize / origWidth, (float) inputSize / origHeight);
        float padX = (inputSize - origWidth * gain) / 2;
        float padY = (inputSize - origHeight * gain) / 2;
        // Create a copy of the original image to draw on
        BufferedImage outputImage = new BufferedImage(origWidth, origHeight, BufferedImage.TYPE_INT_RGB);
        Graphics g = outputImage.getGraphics();
        g.drawImage(originalImage, 0, 0, null);
        for (float[] row : detections) {
            if (row.length < 6) continue;
            // Extract bounding box values
            float x1 = row[0], y1 = row[1], x2 = row[2], y2 = row[3];
            float confidence = row[4];
            int classId = (int) row[5]; // Extract raw class ID
            // Apply YOLOv5 scaling transformation
            x1 = (x1 - padX) / gain; // Adjust x1
            y1 = (y1 - padY) / gain; // Adjust y1
            x2 = (x2 - padX) / gain; // Adjust x2
            y2 = (y2 - padY) / gain; // Adjust y2
            // Keep only detections above the confidence threshold (no clipping is done here)
            if (confidence > 0.5) {
                System.out.printf("BBox: [%.2f, %.2f, %.2f, %.2f], Confidence: %.2f, Class ID: %d%n",
                        x1, y1, x2, y2, confidence, classId);
                // Draw the bounding box
                g.setColor(Color.RED); // Set color for bounding box
                g.drawRect((int) x1, (int) y1, (int) x2, (int) y2);
                // Draw the label with confidence score
                g.setColor(Color.WHITE);
                g.drawString(String.format("ID: %d Conf: %.2f", classId, confidence), (int) Math.round(x1), (int) Math.round(y1 - 10));
            }
        }
        g.dispose();
        // Save the output image
        try {
            ImageIO.write(outputImage, "jpg", new File("output.jpg"));
            System.out.println("Output image saved as output.jpg");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
Output:
(The rectangles have a wrong offset, and I don't know why the class ID is 0 everywhere.)
BBox: [79,97, 92,82, 56,06, 119,68], Confidence: 0,70, Class ID: 0
BBox: [81,12, 92,92, 57,66, 115,62], Confidence: 0,76, Class ID: 0
BBox: [405,68, 88,88, 56,72, 119,84], Confidence: 0,67, Class ID: 0
BBox: [81,99, 94,62, 61,15, 112,69], Confidence: 0,51, Class ID: 0
Conclusion: I am looking for guidance on how to resolve these discrepancies between the Python and Java implementations. Any insights on adjusting the preprocessing steps, interpreting the ONNX model output, or debugging the bounding box coordinates would be greatly appreciated.
I haven't worked with YOLOv5 for some time, but here is what I remember. Since you have 15 features for each box, this is how they are interpreted:
[x, y, w, h, objectness, class1_conf, class2_conf, ..., class10_conf]
int classId = (int) row[5];
This is wrong. To get the class ID of a box, you need to find the maximum confidence among your 10 class scores and take the index of that maximum; that index is the class ID of the box.
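The same row layout also matters for the boxes: in the raw YOLOv5 output, x and y are the box center (in pixels of the 640x640 network input), not the top-left corner, and YOLOv5's Python post-processing multiplies the objectness by the class confidence to get the final score. Below is a minimal sketch of decoding the rows plus a simple greedy per-class NMS. It assumes the image was stretched to 640x640 as resizeImage() does (no letterbox padding); YoloDecoder and Detection are hypothetical names, not part of ONNX Runtime.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Sketch: decode raw YOLOv5 rows [cx, cy, w, h, objectness, cls_0 .. cls_N-1] and run NMS.
// Assumptions: coordinates are pixels of the square network input, the image was stretched
// to that size without letterbox padding, and score = objectness * best class confidence.
class YoloDecoder {
    // Hypothetical holder for one decoded box in original-image pixels (x1, y1 = top-left).
    static final class Detection {
        final float x1, y1, x2, y2, score;
        final int classId;

        Detection(float x1, float y1, float x2, float y2, float score, int classId) {
            this.x1 = x1; this.y1 = y1; this.x2 = x2; this.y2 = y2;
            this.score = score; this.classId = classId;
        }
    }

    static List<Detection> decode(float[][] rows, int origWidth, int origHeight,
                                  int inputSize, float scoreThreshold) {
        float sx = (float) origWidth / inputSize;  // undo the horizontal stretch
        float sy = (float) origHeight / inputSize; // undo the vertical stretch
        List<Detection> out = new ArrayList<>();
        for (float[] r : rows) {
            int numClasses = r.length - 5;
            // Class ID = index of the highest class confidence, not the value stored in r[5].
            int best = 0;
            for (int c = 1; c < numClasses; c++) {
                if (r[5 + c] > r[5 + best]) best = c;
            }
            float score = r[4] * r[5 + best];
            if (score < scoreThreshold) continue;
            // (r[0], r[1]) is the box center, so subtract/add half the width and height.
            float halfW = r[2] / 2, halfH = r[3] / 2;
            out.add(new Detection((r[0] - halfW) * sx, (r[1] - halfH) * sy,
                                  (r[0] + halfW) * sx, (r[1] + halfH) * sy, score, best));
        }
        return out;
    }

    // Intersection over union of two corner-format boxes.
    static float iou(Detection a, Detection b) {
        float w = Math.max(0, Math.min(a.x2, b.x2) - Math.max(a.x1, b.x1));
        float h = Math.max(0, Math.min(a.y2, b.y2) - Math.max(a.y1, b.y1));
        float inter = w * h;
        float union = (a.x2 - a.x1) * (a.y2 - a.y1) + (b.x2 - b.x1) * (b.y2 - b.y1) - inter;
        return union <= 0 ? 0 : inter / union;
    }

    // Greedy per-class NMS: keep the best-scoring box, drop same-class boxes overlapping it.
    static List<Detection> nms(List<Detection> dets, float iouThreshold) {
        List<Detection> sorted = new ArrayList<>(dets);
        sorted.sort(Comparator.comparingDouble((Detection d) -> d.score).reversed());
        List<Detection> kept = new ArrayList<>();
        for (Detection d : sorted) {
            boolean suppressed = false;
            for (Detection k : kept) {
                if (k.classId == d.classId && iou(k, d) > iouThreshold) {
                    suppressed = true;
                    break;
                }
            }
            if (!suppressed) kept.add(d);
        }
        return kept;
    }
}
```

Usage, with the detections array from the question and example thresholds: call YoloDecoder.decode(detections, origWidth, origHeight, 640, 0.25f) and pass the result to YoloDecoder.nms(..., 0.45f). When drawing, keep in mind that Graphics.drawRect takes (x, y, width, height), so the last two arguments should be Math.round(d.x2 - d.x1) and Math.round(d.y2 - d.y1). Treating the center as the top-left corner shifts every box right by half its width and down by half its height, which would explain the offset described in the question. Also, because resizeImage() stretches the image rather than letterboxing it, separate x and y scale factors (as above) fit better than the gain/padX/padY correction in postProcess().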