(Note: I've resolved my problem and posted the working code at the bottom.)
I'm playing around with TensorFlow, and the backend processing must take place in Java. I took one of the models from the Machine Learning Crash Course (https://developers.google.com/machine-learning/crash-course) and saved it with tf.saved_model.save(my_model, "house_price_median_income") (using a Docker container). I copied the model off and loaded it into Java (using the 2.0 Java bindings built from source, because I'm on Windows). I can load the model and run it:
try (SavedModelBundle model = SavedModelBundle.load("./house_price_median_income", "serve")) {
    try (Session session = model.session()) {
        Session.Runner runner = session.runner();
        // One input row holding a single feature value
        float[][] in = new float[][]{ {2.1518f} };
        Tensor<?> jack = Tensor.create(in);
        // The input and output operation names are hardcoded here
        runner.feed("serving_default_layer1_input", jack);
        float[][] probabilities = runner.fetch("StatefulPartitionedCall").run().get(0).copyTo(new float[1][1]);
        for (int i = 0; i < probabilities.length; ++i) {
            System.out.println(String.format("-- Input #%d", i));
            for (int j = 0; j < probabilities[i].length; ++j) {
                // j (not i) indexes the output column
                System.out.println(String.format("Class %d - %f", j, probabilities[i][j]));
            }
        }
    }
}
The above hardcodes the input and output names, but I want to read them from the model and present them so the end user can select the input and output, etc.
I can get the inputs and outputs with the saved_model_cli tool that ships with the Python package: saved_model_cli show --dir ./house_price_median_income --all
What I want to do is get the inputs and outputs via Java, so my code doesn't need to execute a Python script to get them. I can get the operations via:
Graph graph = model.graph();
Iterator<Operation> itr = graph.operations();
while (itr.hasNext()) {
    GraphOperation e = (GraphOperation) itr.next();
    System.out.println(e);
}
This prints both the inputs and the outputs as "operations", but how do I know which operation is an input and/or an output? The Python tool uses the SignatureDef, but that doesn't seem to appear in the TensorFlow 2.0 Java bindings at all. Am I missing something obvious, or is it just missing from the TensorFlow 2.0 Java library?
NOTE: I've sorted out my issue with help from the answer below. Here is my full code in case somebody would like it in the future. Note that this is TF 2.0 and uses the SNAPSHOT mentioned below. I make a few assumptions, but it shows how to pull the input and output names and then use them to run a model.
import java.util.Map;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.tools.ndarray.FloatNdArray;
import org.tensorflow.tools.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;
public class v2tensor {
    public static void main(String[] args) {
        try (SavedModelBundle savedModel = SavedModelBundle.load("./house_price_median_income", "serve")) {
            // Look up the default serving signature in the model's metadata
            SignatureDef modelInfo = savedModel.metaGraphDef().getSignatureDefMap().get("serving_default");
            TensorInfo input1 = null;
            TensorInfo output1 = null;
            // Print every input and remember the first one (assumes a single-input model)
            Map<String, TensorInfo> inputs = modelInfo.getInputsMap();
            for (Map.Entry<String, TensorInfo> input : inputs.entrySet()) {
                if (input1 == null) {
                    input1 = input.getValue();
                    System.out.println(input1.getName());
                }
                System.out.println(input);
            }
            // Print every output and remember the first one (assumes a single-output model)
            Map<String, TensorInfo> outputs = modelInfo.getOutputsMap();
            for (Map.Entry<String, TensorInfo> output : outputs.entrySet()) {
                if (output1 == null) {
                    output1 = output.getValue();
                }
                System.out.println(output);
            }
            try (Session session = savedModel.session()) {
                Session.Runner runner = session.runner();
                // One input row holding a single feature value
                FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{ { 2.1518f } });
                try (Tensor<TFloat32> jack = TFloat32.tensorOf(matrix)) {
                    // Feed and fetch by the tensor names read from the signature
                    runner.feed(input1.getName(), jack);
                    try (Tensor<TFloat32> rezz = runner.fetch(output1.getName()).run().get(0).expect(TFloat32.DTYPE)) {
                        TFloat32 data = rezz.data();
                        data.scalars().forEachIndexed((i, s) -> {
                            System.out.println(s.getFloat());
                        });
                    }
                }
            }
        } catch (TensorFlowException ex) {
            ex.printStackTrace();
        }
    }
}
What you need to do is read the SavedModelBundle metadata as a MetaGraphDef; from there you can retrieve the input and output names from the SignatureDef, as in Python.
In TF Java 1.* (i.e. the client you are using in your example), the proto definitions are not available out of the box from the tensorflow artifact. You need to add a dependency on org.tensorflow:proto as well and deserialize the bytes returned by SavedModelBundle.metaGraphDef() into a MetaGraphDef proto (for example with MetaGraphDef.parseFrom).
In TF Java 2.* (the new client, currently available only as snapshots), the protos are present right away, so you can simply call this line to retrieve the right SignatureDef:
savedModel.metaGraphDef().getSignatureDefMap().get("serving_default")
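One detail about the names themselves: the TensorInfo entries in a SignatureDef carry tensor names, typically of the form operation_name:output_index (for example serving_default_layer1_input:0), whereas iterating graph.operations() gives bare operation names. The following is a minimal sketch, in plain Java with no TensorFlow dependency, of how one might relate the two and flag which listed operations are inputs or outputs. The class SignatureNames, its methods, and the layer1/kernel operation name are hypothetical and only for illustration; in real code the lists would come from graph.operations() and from getInputsMap() / getOutputsMap().

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper, not part of the TensorFlow Java API.
class SignatureNames {

    // Strips the ":index" suffix; names without a suffix are returned unchanged.
    static String opName(String tensorName) {
        int colon = tensorName.lastIndexOf(':');
        return colon < 0 ? tensorName : tensorName.substring(0, colon);
    }

    // Returns the output index after the colon; a missing suffix is treated as output 0.
    static int outputIndex(String tensorName) {
        int colon = tensorName.lastIndexOf(':');
        return colon < 0 ? 0 : Integer.parseInt(tensorName.substring(colon + 1));
    }

    // Labels every operation name as INPUT, OUTPUT or OTHER, preserving the given order.
    static Map<String, String> classify(List<String> graphOps,
                                        Collection<String> inputTensors,
                                        Collection<String> outputTensors) {
        Map<String, String> roles = new LinkedHashMap<>();
        for (String op : graphOps) {
            roles.put(op, "OTHER");
        }
        for (String tensor : inputTensors) {
            roles.computeIfPresent(opName(tensor), (k, v) -> "INPUT");
        }
        for (String tensor : outputTensors) {
            roles.computeIfPresent(opName(tensor), (k, v) -> "OUTPUT");
        }
        return roles;
    }

    public static void main(String[] args) {
        // Assumed sample values standing in for graph.operations() and the signature maps.
        List<String> ops = Arrays.asList(
                "serving_default_layer1_input", "layer1/kernel", "StatefulPartitionedCall");
        Map<String, String> roles = classify(ops,
                Arrays.asList("serving_default_layer1_input:0"),
                Arrays.asList("StatefulPartitionedCall:0"));
        roles.forEach((op, role) -> System.out.println(role + "  " + op));
    }
}
```

Session.Runner.feed and fetch accept either form (an operation name or an operation_name:output_index pair), so the names from the signature can be passed through unchanged; the split is only needed when matching them against the operation listing.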