neural-network, apache-spark-mllib, cross-validation, apache-spark-ml

Using cross-validation to choose network-architecture for multilayer perceptron in Apache Spark


I'm trying to decide on the best architecture for a multilayer perceptron (MultilayerPerceptronClassifier) in Apache Spark, and I'm wondering whether I can use cross-validation to choose it.

Some code:

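import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;

// define layers: first entry is the input size, last entry is the number of classes
int[] layers = new int[] {784, 78, 35, 10};
int[] layers2 = new int[] {784, 28, 28, 10};
int[] layers3 = new int[] {784, 84, 10};
int[] layers4 = new int[] {784, 392, 171, 78, 10};

// the architecture is currently fixed by hand (layers4)
MultilayerPerceptronClassifier mlp = new MultilayerPerceptronClassifier()
        .setMaxIter(25)
        .setLayers(layers4);

// grid over seeds only; the commented-out line is what I would like to use
ParamMap[] paramGrid = new ParamGridBuilder()
        .addGrid(mlp.seed(), new long[] {895L, 12345L})
        //.addGrid(mlp.layers(), new int[][] {layers, layers2, layers3})
        .build();

// 10-fold cross-validation over the grid
CrossValidator cv = new CrossValidator()
        .setEstimator(mlp)
        .setEvaluator(new MulticlassClassificationEvaluator())
        .setEstimatorParamMaps(paramGrid)
        .setNumFolds(10);

// train is the training dataset, defined elsewhere
CrossValidatorModel model = cv.fit(train);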

As you can see, I've defined several architectures as integer arrays (layers through layers4).

As is, I have to fit the model multiple times, manually changing the layers parameter for the learning algorithm.

What I want is to provide the different architectures in the parameter grid that I pass to the CrossValidator (the commented-out line in the ParamGridBuilder call).

I suspect this is possible, since the ParamGridBuilder seems to know about the layers() parameter, but addGrid doesn't accept the arguments I provide.

If I am correct in this assumption, what am I doing wrong and how can I get this to work as intended?


Solution

  • The code looks syntactically correct. That it doesn't work may be a bug, or it may be intentional, since cross-validating over whole architectures would be rather expensive computationally. So I guess no, you can't use cross-validation for that.

    I ended up using the following rule-of-thumb formula instead:

    Number of units in hidden-layer = ceil((Number of inputs + outputs) * (2/3))
    

    Source: http://www.faqs.org/faqs/ai-faq/neural-nets/part3/section-10.html.
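
As a worked example of the formula above, here is a minimal Java sketch that applies it to the 784-input, 10-output setup from the question. The class and method names (HiddenLayerRule, hiddenUnits, threeLayerArchitecture) are hypothetical and only for illustration; the sketch assumes a single hidden layer and uses integer arithmetic for the ceiling to avoid floating-point rounding.

```java
// Rule-of-thumb sizing for a single hidden layer, as quoted above:
// ceil((inputs + outputs) * 2/3). Class and method names are hypothetical.
final class HiddenLayerRule {

    // Integer ceiling of 2 * (inputs + outputs) / 3.
    static int hiddenUnits(int inputs, int outputs) {
        int total = inputs + outputs;
        return (2 * total + 2) / 3;
    }

    // Builds a layers array {inputs, hidden, outputs}, the same shape
    // as the int[] arrays passed to setLayers(...) in the question.
    static int[] threeLayerArchitecture(int inputs, int outputs) {
        return new int[] {inputs, hiddenUnits(inputs, outputs), outputs};
    }

    public static void main(String[] args) {
        int[] layers = threeLayerArchitecture(784, 10);
        // prints [784, 530, 10]
        System.out.println(java.util.Arrays.toString(layers));
    }
}
```

For 784 inputs and 10 outputs this gives ceil(794 * 2/3) = ceil(529.33) = 530 hidden units, so the resulting array could be passed to setLayers in place of layers4. This is only a starting point from a rule of thumb, not a substitute for validating the architecture on held-out data.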