Tags: java, object-detection, yolo, onnx, yolov5

Issue with Object Detection Results in Java using YOLOv5 ONNX Model


I trained a neural network for object detection using YOLOv5 and exported it to the ONNX format, as I need to use it in a Java application. However, I am encountering issues with incorrect class IDs and bounding boxes in my Java code, while the Python implementation works as expected.

Issues Encountered:

Incorrect Class IDs: All of the detections in the Java code are returning class ID 0, which is not expected. In the Python code, I receive valid class IDs corresponding to the detected objects.

Bounding Box Offsets: The bounding boxes appear to have a symmetrical offset, being shifted too far positively on the x-axis and too far down on the y-axis. This results in inaccurate positioning of the detected objects in the output image.

Wrong interpretation of Model Output?

(NMS is not yet implemented in the Java code.)

For reference, here is some information about the model:

Model Properties:
Format: ONNX v8
Producer: PyTorch 2.6.0
Version: 0
Imports: ai.onnx v17
Graph: main_graph

Metadata:
Stride: 32
Class Names: {0: '1', 1: '2', 2: '3', 3: '4', 4: '5', 5: '6', 6: '7', 7: '8', 8: '9', 9: '10'}
Inputs:

Name: images
Tensor: float32[1,3,640,640] (i.e., 1 image, 3 RGB channels, 640×640 pixels)
Outputs:

Name: output0
Tensor: float32[1,25200,15] (I am not fully sure what this means)

Now to the code (the same test image is used in both cases):

This is the working Python code that gives the expected results:

# !pip install torch torchvision pillow

import torch
from PIL import Image

# Path to the YOLOv5 repository
path_to_yolo_library = '/content/yolov5'  
onnx_path = '/content/best.onnx'  
image_path = '/content/img101.png'  

# Import the YOLOv5 model from the local path
model = torch.hub.load(path_to_yolo_library, 'custom', path=onnx_path, source='local') 

# Load and preprocess the image
img = Image.open(image_path)  # Load image as a PIL image
img = img.resize((640, 640))  # Resize the image to fit YOLO input size (640x640)

# Inference (includes NMS)
results = model(img, size=640)  # Inference with NMS

# Results
results.print()  # Print the results (detections, classes, confidence)
results.show()   # Show the image with bounding boxes
results.save()   # Save the result images

# Data: Print the bounding boxes, confidence scores, and class ids
print('\n', results.xyxy[0])  # Print predictions in the format (x1, y1, x2, y2, confidence, class)

Output: (includes correct visualization)

image 1/1: 640x640 2 1s, 2 2s, 1 3, 2 4s, 1 5, 1 6, 1 7
Speed: 16.7ms pre-process, 407.8ms inference, 6.0ms NMS per image at shape (1, 3, 640, 640)
tensor([[4.49145e+01, 1.94186e+02, 1.14293e+02, 3.11326e+02, 8.03208e-01, 0.00000e+00],
        [4.44819e+01, 3.47444e+02, 1.18352e+02, 4.73753e+02, 7.96138e-01, 1.00000e+00],
        [3.68868e+02, 2.70193e+01, 4.38986e+02, 1.55611e+02, 7.92952e-01, 1.00000e+00],
        [4.62871e+01, 3.24609e+01, 1.15780e+02, 1.50192e+02, 7.83159e-01, 0.00000e+00],
        [3.47603e+02, 4.95154e+02, 4.30347e+02, 6.35069e+02, 7.63681e-01, 4.00000e+00],....

Since this worked very well and accurately, I tried translating it into Java:

import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.image.BufferedImage;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.Collections;
import java.util.Map;

import javax.imageio.ImageIO;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;

public class YOLOv5ONNXJava {
    public static void main(String[] args) {
        try {
            // Load ONNX model
            String modelPath = "...best.onnx";
            OrtEnvironment env = OrtEnvironment.getEnvironment();
            OrtSession session = env.createSession(modelPath, new OrtSession.SessionOptions());

            // Load Image
            BufferedImage image = ImageIO.read(new File("...img101.png"));
            int origWidth = image.getWidth();
            int origHeight = image.getHeight();
            int inputSize = 640;

            BufferedImage resizedImage = resizeImage(image, inputSize, inputSize);

            // Convert Image to Tensor
            float[] inputTensor = preprocessImage(resizedImage, inputSize);
            long[] shape = {1, 3, inputSize, inputSize}; // Batch size 1, RGB channels
            OnnxTensor tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputTensor), shape);

            // Run inference
            Map<String, OnnxTensor> inputMap = Collections.singletonMap(session.getInputNames().iterator().next(), tensor);
            OrtSession.Result result = session.run(inputMap);

            // Process Output
            float[][][] outputData = (float[][][]) result.get(0).getValue();
            float[][] detections = outputData[0]; // Extract first batch

            // Post-process detections and draw rectangles
            postProcess(detections, origWidth, origHeight, image);

            // Cleanup
            session.close();
            tensor.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // Resize image function
    public static BufferedImage resizeImage(BufferedImage originalImage, int width, int height) {
        BufferedImage resizedImage = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        resizedImage.getGraphics().drawImage(originalImage, 0, 0, width, height, null);
        return resizedImage;
    }

    // Preprocess image into tensor format
    public static float[] preprocessImage(BufferedImage image, int inputSize) {
        float[] tensor = new float[3 * inputSize * inputSize]; // RGB channels
        int[] rgbArray = image.getRGB(0, 0, inputSize, inputSize, null, 0, inputSize);

        for (int i = 0; i < inputSize * inputSize; i++) {
            int pixel = rgbArray[i];
            tensor[i] = ((pixel >> 16) & 0xFF) / 255.0f; // Red
            tensor[i + inputSize * inputSize] = ((pixel >> 8) & 0xFF) / 255.0f; // Green
            tensor[i + 2 * inputSize * inputSize] = (pixel & 0xFF) / 255.0f; // Blue
        }
        return tensor;
    }

    // Updated postProcess method to draw rectangles and labels
    public static void postProcess(float[][] detections, int origWidth, int origHeight, BufferedImage originalImage) {
        System.out.println("\nDetections:");

        int inputSize = 640; // YOLOv5 default input size
        float gain = Math.min((float) inputSize / origWidth, (float) inputSize / origHeight);
        float padX = (inputSize - origWidth * gain) / 2;
        float padY = (inputSize - origHeight * gain) / 2;

        // Create a copy of the original image to draw on
        BufferedImage outputImage = new BufferedImage(origWidth, origHeight, BufferedImage.TYPE_INT_RGB);
        Graphics g = outputImage.getGraphics();
        g.drawImage(originalImage, 0, 0, null);

        for (float[] row : detections) {
            if (row.length < 6) continue;

            // Extract bounding box values
            float x1 = row[0], y1 = row[1], x2 = row[2], y2 = row[3];
            float confidence = row[4];
            int classId = (int) row[5]; // Extract raw class ID

            // Apply YOLOv5 scaling transformation
            x1 = (x1 - padX) / gain; // Adjust x1
            y1 = (y1 - padY) / gain; // Adjust y1
            x2 = (x2 - padX) / gain; // Adjust x2
            y2 = (y2 - padY) / gain; // Adjust y2

            // TODO: clip bounding boxes to image boundaries (not implemented yet)

            if (confidence > 0.5) {
                System.out.printf("BBox: [%.2f, %.2f, %.2f, %.2f], Confidence: %.2f, Class ID: %d%n",
                        x1, y1, x2, y2, confidence, classId);
                
                // Draw the bounding box
                g.setColor(Color.RED); // Set color for bounding box
                g.drawRect((int) x1, (int) y1, (int) x2, (int) y2);
                
                // Draw the label with confidence score
                g.setColor(Color.WHITE);
                g.drawString(String.format("ID: %d Conf: %.2f", classId, confidence), (int) Math.round(x1), (int) Math.round(y1 - 10));
            }
        }
        g.dispose();

        // Save the output image
        try {
            ImageIO.write(outputImage, "jpg", new File("output.jpg"));
            System.out.println("Output image saved as output.jpg");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

}
Output:
(The rectangles have a wrong offset, and I don't know why the class IDs are 0 everywhere.)
BBox: [79,97, 92,82, 56,06, 119,68], Confidence: 0,70, Class ID: 0
BBox: [81,12, 92,92, 57,66, 115,62], Confidence: 0,76, Class ID: 0
BBox: [405,68, 88,88, 56,72, 119,84], Confidence: 0,67, Class ID: 0
BBox: [81,99, 94,62, 61,15, 112,69], Confidence: 0,51, Class ID: 0

Conclusion: I am looking for guidance on how to resolve these discrepancies between the Python and Java implementations. Any insights on adjusting the preprocessing steps, interpreting the ONNX model output, or debugging the bounding box coordinates would be greatly appreciated.


Solution

  • I haven't worked with YOLOv5 for some time, but here is what I remember. Since you have 15 features for each box, this is how they are interpreted:
    [x, y, w, h, objectness, class1_conf, class2_conf, ..., class10_conf]

    int classId = (int) row[5];
    

    This is wrong. If you want to get the class ID for each box, you need to find the maximum confidence among your 10 classes and take the index of that maximum; that index is the class ID of the box (see the sketch below).
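    Here is a minimal sketch of that decoding step, assuming the layout above and the usual YOLOv5 convention that row[0..3] are center-x, center-y, width, and height in the 640×640 input space (the helper name decodeRow is illustrative, not from your code):

    public static void decodeRow(float[] row, int numClasses) {
        float objectness = row[4]; // confidence that the box contains any object

        // Class ID = index of the maximum class confidence in row[5..5+numClasses-1]
        int classId = 0;
        float maxClassConf = row[5];
        for (int c = 1; c < numClasses; c++) {
            if (row[5 + c] > maxClassConf) {
                maxClassConf = row[5 + c];
                classId = c;
            }
        }

        // Usable score = objectness multiplied by the class confidence
        float confidence = objectness * maxClassConf;

        // Convert center/width/height to corner coordinates before drawing
        float cx = row[0], cy = row[1], w = row[2], h = row[3];
        float x1 = cx - w / 2, y1 = cy - h / 2;
        float x2 = cx + w / 2, y2 = cy + h / 2;

        if (confidence > 0.5f) {
            System.out.printf("BBox: [%.2f, %.2f, %.2f, %.2f], Conf: %.2f, Class ID: %d%n",
                    x1, y1, x2, y2, confidence, classId);
        }
    }

    You would call something like decodeRow(row, 10) inside your loop over the rows; the 25200 in float32[1,25200,15] is simply the number of candidate boxes the model proposes before NMS.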