machine-learning deeplearning4j dl4j

Different predictions for the same data


I use Deeplearning4j to classify equipment names. I labeled about 50,000 items with 495 classes, and I use this data to train the neural network.

That is, the input is a set of 50,000 vectors of 0s and 1s, plus the expected class (0 to 494) for each vector.
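To make the input format concrete, here is a small stdlib-only sketch of how one CSV row maps to a feature vector and a label. It assumes the layout implied by labelIndex = 3331 in the training code: columns 0 to 3330 hold the 0/1 features and the last column holds the class index. The class name CsvRowLayout and the tiny sizes are hypothetical; as far as I know, RecordReaderDataSetIterator with classification(labelIndex, numClasses) likewise turns the class index into a one-hot vector of length numClasses.

```java
import java.util.Arrays;

// Hypothetical helper: splits one CSV row into 0/1 features and a one-hot label.
// Assumes columns 0..labelIndex-1 are features and column labelIndex is the class index.
final class CsvRowLayout {
    static final class Example {
        final double[] features;
        final double[] oneHotLabel;

        Example(double[] features, double[] oneHotLabel) {
            this.features = features;
            this.oneHotLabel = oneHotLabel;
        }
    }

    static Example parse(String line, int labelIndex, int numClasses) {
        String[] cols = line.split(",");
        if (cols.length != labelIndex + 1) {
            throw new IllegalArgumentException(
                "expected " + (labelIndex + 1) + " columns, got " + cols.length);
        }
        double[] features = new double[labelIndex];
        for (int i = 0; i < labelIndex; i++) {
            features[i] = Double.parseDouble(cols[i].trim());
        }
        int classIndex = Integer.parseInt(cols[labelIndex].trim());
        if (classIndex < 0 || classIndex >= numClasses) {
            throw new IllegalArgumentException("class index out of range: " + classIndex);
        }
        double[] label = new double[numClasses];
        label[classIndex] = 1.0; // one-hot: a single 1 at the class position
        return new Example(features, label);
    }

    public static void main(String[] args) {
        // Tiny made-up row: 5 features, class 2 out of 4 (the real data has 3331 and 495)
        Example ex = parse("1,0,1,1,0,2", 5, 4);
        System.out.println(Arrays.toString(ex.features));    // [1.0, 0.0, 1.0, 1.0, 0.0]
        System.out.println(Arrays.toString(ex.oneHotLabel)); // [0.0, 0.0, 1.0, 0.0]
    }
}
```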

I based the code on the IrisClassifier example.

I saved the trained model to a file, and now I can use it to predict the class of equipment.

As a test, I ran prediction on the same 50,000 items I used for training and compared the predictions with my own labels.

The result was very good: the network's error is about 1%.
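An error figure like this can be computed by taking the index of the largest value in each output row (the same argMax the prediction code writes to the file) and comparing it with the expected class. Below is a stdlib-only sketch with made-up probabilities; ErrorRate is a hypothetical name. DL4J's Evaluation.accuracy() should correspond to 1 minus this value.

```java
// Hypothetical helper: error rate = fraction of rows whose argmax differs from the expected class.
final class ErrorRate {
    // Index of the largest value in a row (ties go to the lowest index).
    static int argMax(double[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) best = i;
        }
        return best;
    }

    static double errorRate(double[][] probabilities, int[] expected) {
        if (probabilities.length == 0 || probabilities.length != expected.length) {
            throw new IllegalArgumentException("need the same, non-zero number of rows and labels");
        }
        int wrong = 0;
        for (int r = 0; r < probabilities.length; r++) {
            if (argMax(probabilities[r]) != expected[r]) wrong++;
        }
        return (double) wrong / probabilities.length;
    }

    public static void main(String[] args) {
        double[][] out = {
            {0.7, 0.2, 0.1}, // predicted 0
            {0.1, 0.8, 0.1}, // predicted 1
            {0.3, 0.3, 0.4}, // predicted 2
            {0.5, 0.4, 0.1}  // predicted 0
        };
        int[] labels = {0, 1, 2, 1}; // last row is misclassified
        System.out.println(errorRate(out, labels)); // 0.25
    }
}
```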

Then I ran prediction on only the first 100 of these 50,000 records, after removing the remaining 49,900.

For these 100 vectors, the predictions differ from the predictions for the same 100 vectors when they were part of the full 50,000-record set.

In other words, the less data I give the trained model, the greater the prediction error, even for exactly the same vectors.

Why does this happen?

My code.

Training:

// First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
int numLinesToSkip = 0;
char delimiter = ',';
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
recordReader.initialize(new FileSplit(new File(args[0])));

//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
int labelIndex = 3331;
int numClasses = 495;
int batchSize = 4000;

// DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSetIterator iterator = new RecordReaderDataSetIterator.Builder(recordReader, batchSize).classification(labelIndex, numClasses).build();

List<DataSet> trainingData = new ArrayList<>();
List<DataSet> testData = new ArrayList<>();

while (iterator.hasNext()) {
    DataSet allData = iterator.next();
    allData.shuffle();
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.8);  //Use 80% of data for training
    trainingData.add(testAndTrain.getTrain());
    testData.add(testAndTrain.getTest());
}

DataSet allTrainingData = DataSet.merge(trainingData);
DataSet allTestData = DataSet.merge(testData);

//We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(allTrainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.transform(allTrainingData);     //Apply normalization to the training data
normalizer.transform(allTestData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set

long seed = 6;
int firstHiddenLayerSize = labelIndex/6;
int secondHiddenLayerSize = firstHiddenLayerSize/4;

//log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .activation(Activation.TANH)
        .weightInit(WeightInit.XAVIER)
        .updater(new Sgd(0.1))
        .l2(1e-4)
        .list()
        .layer(new DenseLayer.Builder().nIn(labelIndex).nOut(firstHiddenLayerSize)
                .build())
        .layer(new DenseLayer.Builder().nIn(firstHiddenLayerSize).nOut(secondHiddenLayerSize)
                .build())
        .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
                .nIn(secondHiddenLayerSize).nOut(numClasses).build())
        .build();

//run the model
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();

//record score once every 100 iterations
model.setListeners(new ScoreIterationListener(100));

for(int i=0; i<5000; i++ ) {
    model.fit(allTrainingData);
}

//evaluate the model on the test set
Evaluation eval = new Evaluation(numClasses);

INDArray output = model.output(allTestData.getFeatures());

eval.eval(allTestData.getLabels(), output);
log.info(eval.stats());

// Save the Model
File locationToSave = new File(args[1]);
model.save(locationToSave, false);

Prediction:

// Open the network file
File locationToLoad = new File(args[0]);
MultiLayerNetwork model = MultiLayerNetwork.load(locationToLoad, false);
model.init();

// First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
int numLinesToSkip = 0;
char delimiter = ',';

// Data to predict
CSVRecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);  //skip no lines at the top - i.e. no header
recordReader.initialize(new FileSplit(new File(args[1])));

//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
int batchSize = 4000;

DataSetIterator iterator = new RecordReaderDataSetIterator.Builder(recordReader, batchSize).build();

List<DataSet> dataSetList = new ArrayList<>();

while (iterator.hasNext()) {
    DataSet allData = iterator.next();
    dataSetList.add(allData);
}

DataSet dataSet = DataSet.merge(dataSetList);

DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(dataSet);
normalizer.transform(dataSet);

// Now use it to classify some data
INDArray output = model.output(dataSet.getFeatures());

// Save result
BufferedWriter writer = new BufferedWriter(new FileWriter(args[2], true));
for (int i=0; i<output.rows(); i++) {
    writer
            .append(output.getRow(i).argMax().toString())
            .append(" ")
            .append(String.valueOf(i))
            .append(" ")
            .append(output.getRow(i).toString())
            .append('\n');
}
writer.close();

Solution

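  • The prediction code fits a new NormalizerStandardize on the data being predicted, so the mean and standard deviation used for scaling come from that data instead of from the training set. With all 50,000 rows these statistics are close to the training ones, but with only 100 rows they differ, so the same vector is scaled differently and can get a different prediction. Save the normalizer fitted on the training data alongside the model, then at prediction time restore it and call only transform(), not fit():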

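    import java.io.File;
    import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
    import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer;

    NormalizerSerializer serializer = NormalizerSerializer.getDefault();

    // Training: save the normalizer that was fitted on the training data
    serializer.write(normalizer, new File("outputFile.bin"));

    // Prediction: restore it and apply it without calling fit() again
    NormalizerStandardize restored = serializer.restore(new File("outputFile.bin"));
    restored.transform(dataSet);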
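To see why this matters, here is a plain-Java stand-in for a mean/standard-deviation standardizer. Standardizer is a hypothetical class, not DL4J's NormalizerStandardize, and its binary file format is made up for this sketch. Fitting it on a different set of rows changes the scaled value of the very same vector, which is what happens when the prediction code calls normalizer.fit(dataSet) on 100 rows; fitting once on training data, saving the statistics, and restoring them gives the same value regardless of which other rows are in the batch.

```java
import java.io.*;

// Hypothetical stand-in for a mean/std standardizer (NOT DL4J's NormalizerStandardize),
// used only to show how fitted statistics change the scaled value of a row.
final class Standardizer {
    private double[] mean;
    private double[] std;

    // Fit per-column mean and standard deviation on the given rows.
    static Standardizer fit(double[][] data) {
        int rows = data.length;
        int cols = data[0].length;
        Standardizer s = new Standardizer();
        s.mean = new double[cols];
        s.std = new double[cols];
        for (double[] row : data) {
            for (int c = 0; c < cols; c++) s.mean[c] += row[c];
        }
        for (int c = 0; c < cols; c++) s.mean[c] /= rows;
        for (double[] row : data) {
            for (int c = 0; c < cols; c++) {
                double d = row[c] - s.mean[c];
                s.std[c] += d * d;
            }
        }
        for (int c = 0; c < cols; c++) {
            s.std[c] = Math.sqrt(s.std[c] / rows);
            if (s.std[c] == 0) s.std[c] = 1; // constant column: avoid dividing by zero
        }
        return s;
    }

    // Scale one row with the stored statistics; nothing is re-fitted here.
    double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int c = 0; c < row.length; c++) out[c] = (row[c] - mean[c]) / std[c];
        return out;
    }

    // Persist the statistics so prediction can reuse the training-time values.
    void save(File file) {
        try (DataOutputStream out = new DataOutputStream(new FileOutputStream(file))) {
            out.writeInt(mean.length);
            for (int c = 0; c < mean.length; c++) {
                out.writeDouble(mean[c]);
                out.writeDouble(std[c]);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    static Standardizer load(File file) {
        try (DataInputStream in = new DataInputStream(new FileInputStream(file))) {
            int cols = in.readInt();
            Standardizer s = new Standardizer();
            s.mean = new double[cols];
            s.std = new double[cols];
            for (int c = 0; c < cols; c++) {
                s.mean[c] = in.readDouble();
                s.std[c] = in.readDouble();
            }
            return s;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) throws IOException {
        double[][] train = {{0}, {1}, {1}, {0}}; // made-up "training" column: mean 0.5, std 0.5
        double[][] subset = {{1}, {1}, {0}};     // made-up small prediction batch
        double[] x = {1};

        Standardizer trained = fit(train);
        System.out.println(trained.transform(x)[0]);     // 1.0
        System.out.println(fit(subset).transform(x)[0]); // about 0.7071: same row, different scaling

        File file = File.createTempFile("standardizer", ".bin");
        file.deleteOnExit();
        trained.save(file);
        System.out.println(load(file).transform(x)[0]);  // 1.0 again: restored training statistics
    }
}
```

The network only ever saw inputs scaled with the training statistics, so feeding it inputs scaled with any other statistics shifts them away from what it learned.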