DeepLearning4j and DataVec read csv file with label


I have built a DL4J project. Everything works if I use the MNIST dataset as follows:

    DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
    DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);

However, I want to switch to my own csv file with the following format:

A  |  B  |  C  |  X  |  Y
-------------------------
1  | 100 |  5  |  15 |  6
...

X and Y are the outcomes (or labels). Since I plan to perform regression analysis, both X and Y are real numbers. I read the csv file using the following code:

    RecordReader recordReaderTrain = new CSVRecordReader(1, ",");
    recordReaderTrain.initialize(new FileSplit(new File("src/main/resources/data/Data.csv")));
    DataSetIterator dataIterTrain = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 3, 2);

The 3 in the code is the index of the labels and the 2 is the number of possible labels. There is not much explanation of these two parameters. I guess they mean that the labels start at the 4th column and that there are 2 of them.

When I run the code, it shows the following exception:

    Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 14

I think it is because DL4J does not recognize 15 as a label.

So my question is: how can I properly read the csv file for a regression analysis?

Many thanks.


Solution

  • We have examples for regression: https://github.com/deeplearning4j/dl4j-examples/tree/cc383de91bdf4e28e36859aa2e8749100cd63177/dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/regression

    You need to pass regression = true (an extra boolean argument of the constructor) to the RecordReaderDataSetIterator. Without it, the iterator treats the label column as a class index for classification.
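
    For the two label columns in the question, the overload that should apply is, as far as I know, RecordReaderDataSetIterator(recordReader, batchSize, labelIndexFrom, labelIndexTo, regression), called here as (recordReaderTrain, batchSize, 3, 4, true), with labelIndexTo inclusive. Check this against the Javadoc of your DL4J version. The following is only a plain-Java sketch of what that split amounts to for one row; the class RegressionSplitSketch and its split helper are hypothetical names for illustration, not part of DL4J or DataVec.

```java
import java.util.Arrays;

public class RegressionSplitSketch {

    // Hypothetical helper, not a DL4J API: splits one parsed CSV row into
    // features and regression targets. Columns labelFrom..labelTo (inclusive)
    // become the labels; every other column becomes a feature.
    static double[][] split(double[] row, int labelFrom, int labelTo) {
        double[] labels = Arrays.copyOfRange(row, labelFrom, labelTo + 1);
        double[] features = new double[row.length - labels.length];
        int f = 0;
        for (int i = 0; i < row.length; i++) {
            if (i < labelFrom || i > labelTo) {
                features[f++] = row[i];
            }
        }
        return new double[][] { features, labels };
    }

    public static void main(String[] args) {
        // The sample row from the question: A, B, C, X, Y
        double[] row = { 1, 100, 5, 15, 6 };
        double[][] parts = split(row, 3, 4);
        System.out.println("features = " + Arrays.toString(parts[0]));
        System.out.println("labels   = " + Arrays.toString(parts[1]));
    }
}
```

    In regression mode the values 15 and 6 are kept as real-valued targets. In classification mode a label value is interpreted as a class index, which is the likely reason a value such as 15 does not fit when only 2 possible labels are declared.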