I have built a DL4J project. Everything works when I use the MNIST dataset as follows:
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
However, I want to switch to my own CSV file, which has the following format:
A | B | C | X | Y
-------------------------
1 | 100 | 5 | 15 | 6
...
X and Y are the outcomes (or labels). Since I plan to perform regression analysis, both X and Y are real numbers. I read the CSV file using the following code:
RecordReader recordReaderTrain = new CSVRecordReader(1, ",");
recordReaderTrain.initialize(new FileSplit(new File("src/main/resources/data/Data.csv")));
DataSetIterator dataIterTrain = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 3, 2);
In this code, 3 is the "index of the labels" and 2 is the "number of possible labels". There is not much explanation of these two parameters. I guess they mean that the labels start at the 4th column and that there are 2 labels.
When I run the code, it shows the following exception:
Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 14
I think it is because DL4J does not recognize 15 as a label.
So my question is: how can I properly read the csv file for a regression analysis?
Many thanks.
Right, so we have examples for regression: https://github.com/deeplearning4j/dl4j-examples/tree/cc383de91bdf4e28e36859aa2e8749100cd63177/dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/regression
You need to pass true for the regression flag (it is an extra constructor argument) of the RecordReaderDataSetIterator.
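To make the effect of the regression flag concrete: in DL4J releases whose five-argument constructor takes (recordReader, batchSize, labelIndexFrom, labelIndexTo, regression), the call for this file would presumably be new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 3, 4, true). The meaning of the two integer arguments has differed between releases, so check the Javadoc for your version. The sketch below is plain Java, not DL4J code. It only illustrates the row split one would expect in that mode, using the sample row from the question, and the class and method names are hypothetical.

```java
import java.util.Arrays;

// Illustration only: mimics how a CSV row would be split into features and
// real-valued regression targets when columns 3..4 are the labels.
// This is NOT DL4J's implementation; all names here are hypothetical.
// With DL4J on the classpath, the assumed equivalent call is:
//   new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 3, 4, true)
class RegressionSplitSketch {

    /** Parses one comma-separated line into doubles. */
    static double[] parse(String line) {
        String[] parts = line.split(",");
        double[] row = new double[parts.length];
        for (int i = 0; i < parts.length; i++) {
            row[i] = Double.parseDouble(parts[i].trim());
        }
        return row;
    }

    /** Copies columns labelFrom..labelTo (inclusive) as real-valued targets. */
    static double[] labels(double[] row, int labelFrom, int labelTo) {
        return Arrays.copyOfRange(row, labelFrom, labelTo + 1);
    }

    /** Every column outside labelFrom..labelTo is treated as an input feature. */
    static double[] features(double[] row, int labelFrom, int labelTo) {
        double[] out = new double[row.length - (labelTo - labelFrom + 1)];
        int k = 0;
        for (int i = 0; i < row.length; i++) {
            if (i < labelFrom || i > labelTo) {
                out[k++] = row[i];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] row = parse("1,100,5,15,6"); // A, B, C, X, Y from the question
        System.out.println("features = " + Arrays.toString(features(row, 3, 4)));
        System.out.println("labels   = " + Arrays.toString(labels(row, 3, 4)));
    }
}
```

In regression mode the label values (15 and 6 here) are kept as real numbers and are not interpreted as class indices, which is why the flag matters for this kind of data.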