I am trying to run prediction in Java with a model trained in Python. The pipeline works fine in Python, but I cannot get the same prediction to work in Java. I named the model's input layer (input_1) "inputTensor" and its output "outputTensor". In Java, however, the input name appears as "serving_default_inputTensor" instead of "inputTensor".
The test image has already been normalized and has a single (grayscale) channel, not RGB.
I converted the BufferedImage to a Tensor, but got the error "DataType 20 is not recognized in Java (version 1.15.0)", where 1.15.0 is the TensorFlow version I use in Java.
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.io.File;
import java.io.IOException;
import java.util.Iterator;
import javax.imageio.ImageIO;
// tensorflow libraries
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;

@SuppressWarnings("unused")
public class JavaTensorflowPredict {

    public static void main(String[] args) throws IOException {
        SavedModelBundle theModel = SavedModelBundle.load("mymodel/", "serve");
        Session sess = theModel.session();
        Graph graph = theModel.graph();

        File inputImage = new File("../test_images/test_01.png");
        BufferedImage image1 = ImageIO.read(inputImage);
        System.out.println(TensorFlow.version());

        Tensor<Float> inputTensor1 = convertBufferedImageToTensor(image1, 128, 128);
        System.out.println(inputTensor1.dataType().toString());

        // print the name of every operation in the loaded graph
        Iterator<Operation> iterOpts = graph.operations();
        while (iterOpts.hasNext()) {
            Operation oprt = iterOpts.next();
            System.out.println(oprt.name()); // "15" removed
        }

        System.out.println("...predicting...\n");
        java.util.List<Tensor<?>> y = sess.runner()
                .feed("serving_default_inputTensor", inputTensor1)
                .fetch("outputTensor_1/kernel")
                .run();
        System.out.println("...done...\n");
    }

    public static BufferedImage resize(BufferedImage img, int width, int height) {
        // obtained from https://github.com/mstritt/orbit-image-analysis
        // note: always creates a 3-band RGB image
        int type = BufferedImage.TYPE_INT_RGB;
        BufferedImage resizedImage = new BufferedImage(width, height, type);
        Graphics2D g = resizedImage.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(img, 0, 0, width, height, null);
        g.dispose();
        return resizedImage;
    }

    public static Tensor<Float> convertBufferedImageToTensor(BufferedImage image, int targetWidth, int targetHeight) {
        // obtained from https://github.com/mstritt/orbit-image-analysis
        // resizing is disabled; resize() would also turn the image into RGB:
        // if (image.getWidth() != targetWidth || image.getHeight() != targetHeight) {
        //     image = resize(image, targetWidth, targetHeight);
        // }
        int width = image.getWidth();
        int height = image.getHeight();
        Raster r = image.getRaster();
        int[] pixel = new int[1]; // one sample per pixel (single gray band)
        // shape [batch][height][width][channels] = [1][height][width][1]
        float[][][][] pixelArray = new float[1][height][width][1];
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                pixel = r.getPixel(j, i, pixel);
                pixelArray[0][i][j][0] = pixel[0];
            }
        }
        return Tensor.create(pixelArray, Float.class);
    }
}
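The resize helper above draws into a TYPE_INT_RGB image, so re-enabling it would give three bands per pixel instead of the single gray channel the model input expects. A minimal sketch, assuming an 8-bit grayscale source, that resizes into TYPE_BYTE_GRAY instead and copies band 0 into the same [1][height][width][1] layout; GrayNhwc, resizeGray, and toNhwc are hypothetical names, not part of any library:

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.awt.image.Raster;

public class GrayNhwc {

    // Hypothetical variant of resize(): keeps one 8-bit gray band instead of RGB.
    static BufferedImage resizeGray(BufferedImage img, int width, int height) {
        BufferedImage out = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = out.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(img, 0, 0, width, height, null);
        g.dispose();
        return out;
    }

    // Copies band 0 of every pixel into a [1][height][width][1] array (NHWC).
    static float[][][][] toNhwc(BufferedImage image) {
        Raster r = image.getRaster();
        if (r.getNumBands() != 1) {
            throw new IllegalArgumentException("expected 1 band, got " + r.getNumBands());
        }
        int h = image.getHeight();
        int w = image.getWidth();
        float[][][][] out = new float[1][h][w][1];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                out[0][y][x][0] = r.getSample(x, y, 0);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // 3x2 gray test image with one known sample value
        BufferedImage img = new BufferedImage(3, 2, BufferedImage.TYPE_BYTE_GRAY);
        img.getRaster().setSample(2, 1, 0, 200);
        float[][][][] t = toNhwc(img);
        System.out.println(t.length + "x" + t[0].length + "x" + t[0][0].length + "x" + t[0][0][0].length);
        System.out.println(t[0][1][2][0]);
    }
}
```

The resulting array could then be passed to Tensor.create(array, Float.class) as in convertBufferedImageToTensor. If the source image were RGB, Java2D's own gray conversion may not match cv2's grayscale conversion exactly, so results could differ slightly from Python.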
And here is the output:
Exception in thread "main" java.lang.IllegalArgumentException: DataType 20 is not recognized in Java (version 1.15.0)
at org.tensorflow.DataType.fromC(DataType.java:85)
at org.tensorflow.Tensor.fromHandle(Tensor.java:540)
at org.tensorflow.Session$Runner.runHelper(Session.java:343)
at org.tensorflow.Session$Runner.run(Session.java:276)
at tftest2.main(tftest2.java:39)
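The stack trace shows the exception is raised in Tensor.fromHandle while Session.Runner.run wraps the fetched outputs, so the unrecognized type appears to belong to the fetched "outputTensor_1/kernel" rather than to the float input (a layer kernel in a TF2 SavedModel is typically a resource variable). A small lookup of TensorFlow's numeric DataType codes, copied from types.proto, shows what code 20 is; the class and method names are hypothetical:

```java
import java.util.Map;
import java.util.Set;

public class TfDataTypeCodes {

    // Subset of TensorFlow's DataType enum (tensorflow/core/framework/types.proto).
    static final Map<Integer, String> NAMES = Map.ofEntries(
            Map.entry(1, "DT_FLOAT"), Map.entry(2, "DT_DOUBLE"),
            Map.entry(3, "DT_INT32"), Map.entry(4, "DT_UINT8"),
            Map.entry(7, "DT_STRING"), Map.entry(9, "DT_INT64"),
            Map.entry(10, "DT_BOOL"), Map.entry(19, "DT_HALF"),
            Map.entry(20, "DT_RESOURCE"), Map.entry(21, "DT_VARIANT"));

    // Codes covered by org.tensorflow.DataType in the 1.x Java API:
    // FLOAT, DOUBLE, INT32, UINT8, STRING, INT64, BOOL.
    static final Set<Integer> SUPPORTED_IN_JAVA_1X = Set.of(1, 2, 3, 4, 7, 9, 10);

    static String name(int code) {
        return NAMES.getOrDefault(code, "code " + code);
    }

    static boolean recognizedByJava1x(int code) {
        return SUPPORTED_IN_JAVA_1X.contains(code);
    }

    public static void main(String[] args) {
        // prints "DT_RESOURCE recognized: false"
        System.out.println(name(20) + " recognized: " + recognizedByJava1x(20));
    }
}
```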
For reference, this is the script I use for prediction in Python:
import tensorflow as tf
import cv2
import numpy as np
the_model = tf.keras.models.load_model("mymodel")
image1 = cv2.imread("test_01.png",0)
# convert image from (128,128) to (1,128,128,1)
image1 = np.reshape(image1, (1,)+image1.shape+(1,))
predict = the_model(image1, training = True)
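In the Python script, np.reshape only wraps the (128, 128) array in a batch axis and a channel axis of size 1; the row-major order of the pixels does not change. A stdlib-only sketch of the same layout in Java, as a flat FloatBuffer plus a {1, height, width, 1} shape (the form that TF Java 1.x's Tensor.create(long[], FloatBuffer) accepts); ReshapeNhwc, nhwcShape, and flatten are hypothetical names:

```java
import java.nio.FloatBuffer;
import java.util.Arrays;

public class ReshapeNhwc {

    // Equivalent of np.reshape(img, (1,) + img.shape + (1,)): same element order,
    // with size-1 batch and channel axes added to the shape.
    static long[] nhwcShape(int height, int width) {
        return new long[] {1, height, width, 1};
    }

    // Flattens a [height][width] image row by row into a FloatBuffer.
    static FloatBuffer flatten(float[][] img) {
        int h = img.length;
        int w = img[0].length;
        FloatBuffer buf = FloatBuffer.allocate(h * w);
        for (float[] row : img) {
            buf.put(row);
        }
        buf.flip(); // position back to 0 so a consumer reads all h * w values
        return buf;
    }

    public static void main(String[] args) {
        float[][] img = {{1f, 2f, 3f}, {4f, 5f, 6f}};
        FloatBuffer buf = flatten(img);
        // prints "[1, 2, 3, 1] 5.0"
        System.out.println(Arrays.toString(nhwcShape(2, 3)) + " " + buf.get(4));
    }
}
```

With TF Java 1.x on the classpath, the two pieces would be combined as Tensor.create(nhwcShape(h, w), flatten(img)), as an alternative to filling a float[1][h][w][1] array.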
The answer is as follows: I used Python TensorFlow 2.4.1 for training, and then TF1 in Java (version 1.15.0) to load the model. Switching to TF2 in Java (tensorflow-core-platform, version 0.3.1) solves the problem, because Java tensorflow-core-platform 0.3.0 can load models saved with TensorFlow 2.4.1 and later in Python.