Tags: java, deeplearning4j, dl4j

Deeplearning4j - how to iterate multiple DataSets for large data?


I'm studying Deeplearning4j (ver. 1.0.0-M1.1) for building neural networks.

I'm using the IrisClassifier example from Deeplearning4j as a starting point, and it works fine:

//First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
int numLinesToSkip = 0;
char delimiter = ',';
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
recordReader.initialize(new FileSplit(new File(DownloaderUtility.IRISDATA.Download(),"iris.txt")));

//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next();
allData.shuffle();
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training

DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();

//We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.transform(trainingData);     //Apply normalization to the training data
normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set

final int numInputs = 4;
int outputNum = 3;
long seed = 6;

log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(seed)
    .activation(Activation.TANH)
    .weightInit(WeightInit.XAVIER)
    .updater(new Sgd(0.1))
    .l2(1e-4)
    .list()
    .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
        .build())
    .layer(new DenseLayer.Builder().nIn(3).nOut(3)
        .build())
    .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
        .nIn(3).nOut(outputNum).build())
    .build();

//run the model
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
//record score once every 100 iterations
model.setListeners(new ScoreIterationListener(100));

for(int i=0; i<1000; i++ ) {
    model.fit(trainingData);
}

//evaluate the model on the test set
Evaluation eval = new Evaluation(3);
INDArray output = model.output(testData.getFeatures());

eval.eval(testData.getLabels(), output);
log.info(eval.stats());

For my project, I have ~30,000 input records (150 in the Iris example), and each record is a vector of size ~7,000 (4 in the Iris example).

Obviously, I can't load all of the data into one DataSet: it would cause an OutOfMemoryError in the JVM.
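To put a rough number on this (a back-of-the-envelope sketch, assuming 32-bit float features and counting nothing else): 30,000 × 7,000 × 4 bytes = 840,000,000 bytes, about 840 MB for the feature matrix alone, before labels, CSV parsing overhead, or any intermediate copies. The class below is hypothetical and only does that arithmetic:

```java
public class DatasetMemoryEstimate {

    /** Bytes needed for a dense numRows x numColumns matrix with the given element size. */
    public static long denseMatrixBytes(long numRows, long numColumns, long bytesPerElement) {
        return numRows * numColumns * bytesPerElement;
    }

    public static void main(String[] args) {
        long rows = 30_000;     // approximate record count from the question
        long columns = 7_000;   // approximate feature-vector length from the question
        long floatBytes = 4;    // assumption: 32-bit FLOAT elements
        long bytes = denseMatrixBytes(rows, columns, floatBytes);
        // Features only: labels, CSV parsing objects and split copies come on top of this
        System.out.println("Feature matrix: " + bytes + " bytes (about " + bytes / 1_000_000 + " MB)");
    }
}
```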

How can I process the data in multiple DataSets?

I assume it should be something like this (store the DataSets in a List and iterate over them):

...
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
    List<DataSet> trainingData = new ArrayList<>();
    List<DataSet> testData = new ArrayList<>();

    while (iterator.hasNext()) {
        DataSet allData = iterator.next();
        allData.shuffle();
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training
        trainingData.add(testAndTrain.getTrain());
        testData.add(testAndTrain.getTest());
    }
    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    for (DataSet dataSetTraining : trainingData) {
        normalizer.fit(dataSetTraining);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
        normalizer.transform(dataSetTraining);     //Apply normalization to the training data
    }
    for (DataSet dataSetTest : testData) {
        normalizer.transform(dataSetTest);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
    }

...

    for(int i=0; i<1000; i++ ) {
        for (DataSet dataSetTraining : trainingData) {
            model.fit(dataSetTraining);
        }
    }

But when I run the evaluation, I get this error:

Exception in thread "main" java.lang.NullPointerException: Cannot read field "javaShapeInformation" because "this.jvmShapeInfo" is null
    at org.nd4j.linalg.api.ndarray.BaseNDArray.dataType(BaseNDArray.java:5507)
    at org.nd4j.linalg.api.ndarray.BaseNDArray.validateNumericalArray(BaseNDArray.java:5575)
    at org.nd4j.linalg.api.ndarray.BaseNDArray.add(BaseNDArray.java:3087)
    at com.aarcapital.aarmlclassifier.classification.FAClassifierLearning.main(FAClassifierLearning.java:117)

...

    Evaluation eval = new Evaluation(26);

    INDArray output = new NDArray();
    for (DataSet dataSetTest : testData) {
        output.add(model.output(dataSetTest.getFeatures())); // ERROR HERE
    }

    System.out.println("--- Output ---");
    System.out.println(output);

    INDArray labels = new NDArray();
    for (DataSet dataSetTest : testData) {
        labels.add(dataSetTest.getLabels());
    }

    System.out.println("--- Labels ---");
    System.out.println(labels);

    eval.eval(labels, output);
    log.info(eval.stats());

What is the correct way to iterate over multiple DataSets when training a network?

Thanks!


Solution

  • Firstly, always use Nd4j.create(..) to create ndarrays; never instantiate the implementation class (such as NDArray) directly. Nd4j.create(..) gives you ndarrays that work whether you run on CPUs or GPUs.

    Second, always use the RecordReaderDataSetIterator builder rather than the constructor. The constructor's parameter list is very long and error-prone.

    That is why we made the builder in the first place.

    Your NullPointerException isn't actually coming from where you think it is. It's caused by how you're creating the ndarray: new NDArray() has no shape or data type, so it can't know what to expect. Nd4j.create(..) will properly set up the ndarray for you.

    Beyond that, you are doing things the right way. The record reader, through the DataSetIterator, handles the batching for you.
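A minimal sketch of that iterator-driven approach, assuming the training and test records are in two separate CSV files (the class name LargeDataTraining and its method names are hypothetical helpers, not DL4J API). The iterator is built with RecordReaderDataSetIterator.Builder, the normalizer is fitted once over all training batches with fit(DataSetIterator) and attached as a pre-processor, and the network trains and evaluates straight from the iterators, so no List<DataSet> has to stay in memory. Note that calling fit(..) once per DataSet, as in the question, recomputes the statistics each time, so the test data ends up normalized with the last training batch's statistics. For an existing List<DataSet>, evaluateBatches relies on Evaluation accumulating statistics across eval(..) calls; there is no need to combine the outputs first (INDArray.add(..) is element-wise addition returning a new array, not an append).

```java
import java.io.File;
import java.util.List;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;

public class LargeDataTraining {

    /** Reads the CSV lazily, batchSize records at a time, using the builder instead of the long constructor. */
    public static DataSetIterator csvIterator(File csv, int batchSize, int labelIndex, int numClasses)
            throws Exception {
        RecordReader reader = new CSVRecordReader(0, ',');  // skip 0 lines, comma-delimited
        reader.initialize(new FileSplit(csv));
        return new RecordReaderDataSetIterator.Builder(reader, batchSize)
                .classification(labelIndex, numClasses)
                .build();
    }

    /** Fits the normalizer over all training batches, then trains and evaluates straight from the iterators. */
    public static Evaluation trainAndEvaluate(MultiLayerNetwork model, DataSetIterator train,
                                              DataSetIterator test, int numEpochs) {
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(train);              // one pass over every training batch to collect mean/stdev
        train.reset();
        train.setPreProcessor(normalizer);  // each training batch is normalized as it is loaded
        test.setPreProcessor(normalizer);   // test batches use the *training* statistics

        model.fit(train, numEpochs);        // numEpochs passes over the training iterator
        return model.evaluate(test);        // evaluation over every test batch
    }

    /** For an existing List<DataSet>: Evaluation accumulates its counts across eval(..) calls. */
    public static Evaluation evaluateBatches(MultiLayerNetwork model, List<DataSet> testData, int numClasses) {
        Evaluation eval = new Evaluation(numClasses);
        for (DataSet batch : testData) {
            INDArray output = model.output(batch.getFeatures());
            eval.eval(batch.getLabels(), output);
        }
        return eval;
    }
}
```

With the question's data this could be called as trainAndEvaluate(model, csvIterator(new File("train.csv"), 64, labelIndex, 26), csvIterator(new File("test.csv"), 64, labelIndex, 26), numEpochs), where the file names, the batch size of 64 and labelIndex (the label column in your CSV) are assumptions, and the model's first layer must have nIn equal to the number of feature columns.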