
DL4J - Is there a way to restrict the prediction of a model to a subset of classes?


I trained an MNIST model with DL4J. When I use this model in inference mode:

INDArray prediction = myModel.output(myINDArrayImage);

This returns the prediction as an INDArray, and it works correctly. The length of this INDArray equals the number of outputs of my model's OutputLayer.

Is there a way to restrict the prediction to a given set of characters, i.e. something like this:

INDArray prediction = myModel.output(myINDArrayImage, charactersPossible);

where charactersPossible is the list of allowed output indices?


Solution

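  • You can create an INDArray (using Nd4j.create(double[])) with 1.0 for the possible characters and 0.0 for the impossible ones. Multiply it element-wise with the prediction INDArray, then apply Nd4j.argMax to the result. Because the softmax outputs are non-negative, zeroing out the disallowed classes means argMax returns the most probable allowed character.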
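The steps in the answer can be sketched in plain Java, using double[] in place of INDArray so the masking logic is visible without an ND4J dependency. The class and method names (MaskedPrediction, buildMask, applyMask, restrictedArgMax) and the sample probabilities are hypothetical. In ND4J the corresponding operations would be Nd4j.create(double[]), INDArray.mul and Nd4j.argMax; since the output for a single image is typically a [1, numClasses] row, make sure the mask has the same shape before multiplying.

```java
/**
 * Hypothetical plain-Java sketch of the masking approach:
 * build a 0/1 mask from the allowed class indices, multiply it
 * element-wise with the model's output probabilities, then take argmax.
 */
class MaskedPrediction {

    /** Returns a mask of length numClasses: 1.0 at allowed indices, 0.0 elsewhere. */
    static double[] buildMask(int numClasses, int[] allowedIndices) {
        double[] mask = new double[numClasses]; // Java initializes elements to 0.0
        for (int idx : allowedIndices) {
            if (idx < 0 || idx >= numClasses) {
                throw new IllegalArgumentException("Index out of range: " + idx);
            }
            mask[idx] = 1.0;
        }
        return mask;
    }

    /** Element-wise product of the prediction and the mask. */
    static double[] applyMask(double[] prediction, double[] mask) {
        if (prediction.length != mask.length) {
            throw new IllegalArgumentException("Prediction and mask lengths differ");
        }
        double[] out = new double[prediction.length];
        for (int i = 0; i < prediction.length; i++) {
            out[i] = prediction[i] * mask[i];
        }
        return out;
    }

    /** Index of the largest value; the first index wins on ties. */
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Most probable class among the allowed indices. */
    static int restrictedArgMax(double[] prediction, int[] allowedIndices) {
        double[] mask = buildMask(prediction.length, allowedIndices);
        return argMax(applyMask(prediction, mask));
    }

    public static void main(String[] args) {
        // Made-up softmax output for the 10 digit classes 0-9.
        double[] prediction = {0.01, 0.02, 0.05, 0.70, 0.02, 0.10, 0.03, 0.02, 0.04, 0.01};
        int[] charactersPossible = {1, 5, 8}; // only these classes are allowed
        System.out.println("Unrestricted: " + argMax(prediction));
        System.out.println("Restricted: " + restrictedArgMax(prediction, charactersPossible));
    }
}
```

Running main prints "Unrestricted: 3" and "Restricted: 5": class 3 has the highest overall probability, but among the allowed classes 1, 5 and 8, class 5 (0.10) wins. One edge case: if every allowed probability were exactly 0.0 (for example after floating-point underflow), the masked array would be all zeros and argMax would return index 0 even if it is not allowed.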