Deeplearning4j - how to use saved model?


I'm studying Deeplearning4j (ver. 1.0.0-M1.1) for building neural networks.

I use IrisClassifier from Deeplearning4j as an example.

//First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
int numLinesToSkip = 0;
char delimiter = ',';
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
recordReader.initialize(new FileSplit(new File(DownloaderUtility.IRISDATA.Download(),"iris.txt")));

//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next();
allData.shuffle();
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training

DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();

//We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.transform(trainingData);     //Apply normalization to the training data
normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set

final int numInputs = 4;
int outputNum = 3;
long seed = 6;

log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(seed)
    .activation(Activation.TANH)
    .weightInit(WeightInit.XAVIER)
    .updater(new Sgd(0.1))
    .l2(1e-4)
    .list()
    .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
        .build())
    .layer(new DenseLayer.Builder().nIn(3).nOut(3)
        .build())
    .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
        .nIn(3).nOut(outputNum).build())
    .build();

//run the model
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
//record score once every 100 iterations
model.setListeners(new ScoreIterationListener(100));

for(int i=0; i<1000; i++ ) {
    model.fit(trainingData);
}

//evaluate the model on the test set
Evaluation eval = new Evaluation(3);
INDArray output = model.output(testData.getFeatures());

eval.eval(testData.getLabels(), output);
log.info(eval.stats());

The inputs for training the model look like this:

5.1,3.5,1.4,0.2,0
...
7.0,3.2,4.7,1.4,1
...
6.3,3.3,6.0,2.5,2

where the last value in each row is the class of that row's inputs.
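
To make the row layout concrete, here is a minimal plain-Java sketch of how one labeled row splits into four features and a class, with the class expanded to a one-hot vector. This mirrors what the `labelIndex = 4` and `numClasses = 3` settings describe for a classification iterator. The class `IrisRowSketch` and its methods are hypothetical helpers written for illustration; they are not part of Deeplearning4j.

```java
import java.util.Arrays;

// Hypothetical helper, not Deeplearning4j API: shows how one labeled row
// breaks down into a feature vector and a one-hot label.
class IrisRowSketch {
    static final int LABEL_INDEX = 4;  // position of the class value in each row
    static final int NUM_CLASSES = 3;  // classes 0, 1 and 2

    // Returns the values before LABEL_INDEX as the feature vector.
    static double[] features(String line) {
        String[] parts = line.split(",");
        double[] f = new double[LABEL_INDEX];
        for (int i = 0; i < LABEL_INDEX; i++) {
            f[i] = Double.parseDouble(parts[i].trim());
        }
        return f;
    }

    // Returns the class value at LABEL_INDEX as a one-hot vector.
    static double[] oneHotLabel(String line) {
        String[] parts = line.split(",");
        int cls = Integer.parseInt(parts[LABEL_INDEX].trim());
        double[] label = new double[NUM_CLASSES];
        label[cls] = 1.0;
        return label;
    }

    public static void main(String[] args) {
        String row = "7.0,3.2,4.7,1.4,1";
        System.out.println(Arrays.toString(features(row)));
        System.out.println(Arrays.toString(oneHotLabel(row)));
    }
}
```

A row without the trailing class value, as in a file meant for prediction, would only go through the `features` half of this.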

This works well: the model trains and is evaluated on the test set.

Now I want to use the trained model to predict the classes of new inputs, but I don't understand how to do it.

I can save the model and load it again:

// Save the Model
File locationToSave = new File("C:/Projects/deeplearning4j/trained_iris_model.zip");
ModelSerializer.writeModel(model, locationToSave, false);

// Open the model
File locationToLoad = new File("C:/Projects/deeplearning4j/trained_iris_model.zip");
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(locationToLoad);
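
One detail about saving: the mean and standard deviation collected by `normalizer.fit(trainingData)` are not network weights, so new inputs only land on the same scale if those statistics are kept alongside the model. `ModelSerializer` appears to have a `writeModel` overload that also takes a `DataNormalization`, plus a `restoreNormalizerFromFile` method; it is worth checking the Javadoc for your version. The idea itself can be sketched in plain Java. `StandardizerSketch` below is a hypothetical stand-in, not Deeplearning4j API, and it assumes no column is constant (otherwise the standard deviation would be zero).

```java
// Hypothetical plain-Java stand-in for a standardizing normalizer.
// Names are illustrative only; assumes every column has non-zero spread.
class StandardizerSketch {
    final double[] mean;
    final double[] std;

    // Rebuild a normalizer from previously saved statistics.
    StandardizerSketch(double[] mean, double[] std) {
        this.mean = mean;
        this.std = std;
    }

    // Collect per-column mean and population standard deviation from the training rows.
    static StandardizerSketch fit(double[][] rows) {
        int cols = rows[0].length;
        double[] mean = new double[cols];
        double[] std = new double[cols];
        for (double[] r : rows) {
            for (int j = 0; j < cols; j++) mean[j] += r[j] / rows.length;
        }
        for (double[] r : rows) {
            for (int j = 0; j < cols; j++) {
                double d = r[j] - mean[j];
                std[j] += d * d / rows.length;
            }
        }
        for (int j = 0; j < cols; j++) std[j] = Math.sqrt(std[j]);
        return new StandardizerSketch(mean, std);
    }

    // Scale one row using the stored statistics; nothing is recomputed here.
    double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int j = 0; j < row.length; j++) out[j] = (row[j] - mean[j]) / std[j];
        return out;
    }

    public static void main(String[] args) {
        double[][] training = {{1.0, 10.0}, {3.0, 30.0}};
        StandardizerSketch fitted = fit(training);
        // Later, at prediction time: reuse the saved statistics instead of fitting again.
        StandardizerSketch restored = new StandardizerSketch(fitted.mean, fitted.std);
        double[] scaled = restored.transform(new double[]{4.0, 20.0});
        System.out.println(scaled[0] + ", " + scaled[1]);
    }
}
```

The point of the two-step `fit` then `transform` split is that `fit` runs once on training data, while `transform` can run on any later input with the stored values.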

Next, as an example, I load the same data that was used for training, but without the class column:

5.1,3.5,1.4,0.2
...
7.0,3.2,4.7,1.4
...
6.3,3.3,6.0,2.5

int numLinesToSkip = 0;
char delimiter = ',';
CSVRecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);  //skip no lines at the top - i.e. no header
recordReader.initialize(new FileSplit(new File("C:/Projects/deeplearning4j/iris-to-predict.txt")));

But what next?

How can I get a prediction?

Thanks!


Solution

  • So, adding this code solved my problem:

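    int batchSize = 150;
    // No label index is passed: the prediction file has only the 4 feature columns
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize);
    DataSet allData = iterator.next();
    
    // Caveat: this fits a new normalizer on the data being predicted. The model was trained on
    // inputs scaled with the training-set statistics, so reusing that normalizer is the safer choice.
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(allData);
    normalizer.transform(allData);
    
    // One row per example, one column per class (softmax probabilities)
    INDArray output = model.output(allData.getFeatures());
    
    // Output
    System.out.println(output);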
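
The printed `output` holds one row per input and one column per class, because the last layer is a softmax. To turn that into a class number per input, pick the column with the largest value in each row. Below is a minimal plain-Java sketch of that step, assuming the probabilities have already been copied into a `double[][]`; `ArgMaxSketch` is a hypothetical helper, not Deeplearning4j API. `MultiLayerNetwork` also seems to provide a `predict(INDArray)` method that returns the class indices directly, which is worth checking in the Javadoc for your version.

```java
import java.util.Arrays;

// Hypothetical helper: converts rows of class probabilities into class indices.
class ArgMaxSketch {
    // For each row, returns the index of the largest value (first one wins on ties).
    static int[] argMaxPerRow(double[][] probabilities) {
        int[] predicted = new int[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            int best = 0;
            for (int c = 1; c < probabilities[i].length; c++) {
                if (probabilities[i][c] > probabilities[i][best]) {
                    best = c;
                }
            }
            predicted[i] = best;
        }
        return predicted;
    }

    public static void main(String[] args) {
        // Made-up probabilities for three inputs, for illustration only.
        double[][] probs = {
            {0.90, 0.07, 0.03},
            {0.10, 0.70, 0.20},
            {0.05, 0.15, 0.80}
        };
        System.out.println(Arrays.toString(argMaxPerRow(probs)));
    }
}
```

The resulting indices 0, 1 and 2 correspond to the same integer classes used in the last column of the training file.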