Search code examples
javatensorflow

DataType 20 is not recognized in Java (version 1.15.0)


I am trying to perform prediction in Java using the trained model from Python. While the pipeline works fine in Python, I am trying to perform similar prediction in Java. I have named the input_1 in the model as "inputTensor" and output as "outputTensor". In Java, the name has been changed from "inputTensor" to "serving_default_inputTensor".

The test image has already been normalized and has single channel or gray only, no RGB.

I have converted the BufferedImage to Tensor. But got the error "DataType 20 is not recognized in Java (version 1.15.0)". The 1.15.0 is the tensorflow version in Java.

import java.awt.Graphics2D;
import java.awt.List;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.util.Iterator;

import javax.imageio.ImageIO;
// tensorflow libraries
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;

@SuppressWarnings("unused")
public class JavaTensorflowPredict {
    public static void main(String[] args) throws IOException {
        SavedModelBundle theModel = SavedModelBundle.load("mymodel/", "serve");
        Session sess = theModel.session();
        Graph graph = theModel.graph();
        
        File inputImage = new File("../test_images/test_01.png");
        BufferedImage image1 = ImageIO.read(inputImage);
        
        System.out.println(TensorFlow.version());  
        Tensor<Float> inputTensor1 = convertBufferedImageToTensor(image1, 128, 128);
        System.out.println(inputTensor1.dataType().toString());
        Iterator<Operation> iterOpts = graph.operations();
        while (iterOpts.hasNext()) {
            Operation oprt = iterOpts.next();
            System.out.println(oprt.name()); // "15" entfernt 
        } 
        System.out.println("...predicting...\n");
        java.util.List<Tensor<?>> y = sess.runner().feed("serving_default_inputTensor", inputTensor1).fetch("outputTensor_1/kernel").run();
        System.out.println("...done...\n"); 
    }
    public static BufferedImage resize(BufferedImage img, int width, int height) {
        // obtained from https://github.com/mstritt/orbit-image-analysis
        //int type = img.getType()>0?img.getType():BufferedImage.TYPE_INT_RGB;
        int type = BufferedImage.TYPE_INT_RGB;
        //BufferedImage resizedImage = new BufferedImage(roundP2(width), roundP2(height), type);
        BufferedImage resizedImage = new BufferedImage(width, height, type);
        Graphics2D g = resizedImage.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(img, 0, 0, width, height, null);
        g.dispose();
        return resizedImage;
    }
    public static Tensor<Float> convertBufferedImageToTensor(BufferedImage image, int targetWidth, int targetHeight) {
        // obtained from https://github.com/mstritt/orbit-image-analysis
        //if (image.getWidth()!=DESIRED_SIZE || image.getHeight()!=DESIRED_SIZE)
        {
            // also make it an RGB image
            // image = resize(image, targetWidth, targetHeight);
            // image = resize(image,image.getWidth(), image.getHeight());
        }
        int width = image.getWidth();
        int height = image.getHeight();
        Raster r = image.getRaster();
        int[] rgb = new int[1];
        //int[] data = new int[width * height];
        //image.getRGB(0, 0, width, height, data, 0, width);
        float[][][][] rgbArray = new float[1][height][width][1];
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                rgb = r.getPixel(j,i,rgb);
                rgbArray[0][i][j][0] = rgb[0];
            }
        }
        return Tensor.create(rgbArray, Float.class);
    }
}

and here is the output :

Exception in thread "main" java.lang.IllegalArgumentException: DataType 20 is not recognized in Java (version 1.15.0)
at org.tensorflow.DataType.fromC(DataType.java:85)
at org.tensorflow.Tensor.fromHandle(Tensor.java:540)
at org.tensorflow.Session$Runner.runHelper(Session.java:343)
at org.tensorflow.Session$Runner.run(Session.java:276)
at tftest2.main(tftest2.java:39)

Of note: I use the following script to do the prediction in Python

import tensorflow as tf
import cv2
import numpy as np

the_model = tf.keras.models.laod_model("mymodel")
image1 = cv2.imread("test_01.png",0)
# convert image from (128,128) to (1,128,128,1)
image1 = np.reshape(image1, (1,)+image1.shape+(1,))
predict = the_model(image1, training = True)

Solution

  • The answer is as following; I have used Python Tensorflow version 2.4.1 for training. Then, I used TF1 in Java (version 1.15.0) to load the model. However, using TF2 (in Java, tensorflow-core-platform, version 0.3.1) solves the problem. Because Java tesnorflow-core-platform 0.3.0 can be used to load models from tensforflow 2.4.1 and later in Python.