Search code examples
Tags: r, deep-learning, neural-network, prediction, cross-validation

neuralnet function in R fails intermittently during a brute-force hyperparameter search


I am currently trying to fit a neural network with 3 hidden layers using the neuralnet package in R. This is a classification problem.

I wish to test a series of candidate hidden-layer sizes and keep the combination that gives the lowest classification error rate under 5-fold cross-validation. Please find the dput output of my data (the data frame called df_std in the code) and the code below:

structure(list(Pregnancies = c(-0.733555340911315, -0.0470227782635459, 
0.639509784384224, -1.0768216222352, -0.733555340911315, -0.0470227782635459, 
2.69910747232753, 2.35584119100365, -0.733555340911315, -0.0470227782635459, 
-0.0470227782635459, 0.296243503060339, 0.296243503060339, -0.0470227782635459, 
-0.733555340911315, -0.733555340911315, 0.639509784384224, 1.66930862835588, 
1.32604234703199, 1.32604234703199, -1.0768216222352, -1.0768216222352, 
-0.390289059587431, -0.733555340911315, 0.296243503060339, -0.390289059587431, 
0.639509784384224, 0.296243503060339, 1.32604234703199, -0.390289059587431, 
-0.390289059587431, 0.296243503060339, 1.32604234703199, -0.390289059587431, 
0.982776065708108, -0.733555340911315, 0.982776065708108, -0.733555340911315, 
-0.733555340911315, -0.733555340911315, 0.296243503060339, -0.0470227782635459, 
-1.0768216222352, -0.0470227782635459, 1.66930862835588, -0.733555340911315, 
1.32604234703199, 0.296243503060339, -1.0768216222352, -0.390289059587431
), Glucose = c(-1.06577093437178, -1.4301021648105, 1.48454767869922, 
-0.105261326851534, -0.204624389698456, 0.159706840740259, 0.722764196872819, 
0.126585819791285, -0.80080276677999, 1.21957951110743, -1.09889195532076, 
-0.602076641086146, -0.337108473494353, 1.94824197198486, -0.602076641086146, 
-0.668318682984094, -1.09889195532076, 1.81575788818896, 0.954611343515638, 
2.18008911862768, -0.701439703933068, -0.535834599188197, 0.656522154974871, 
-0.867044808677939, 0.822127259719741, -0.701439703933068, 0.590280113076922, 
0.259069903587181, -1.26449706006563, -0.370229494443327, -0.701439703933068, 
0.0603437778933366, -1.33073910196358, 0.689643175923845, 0.755885217821793, 
-1.66194931145332, -0.933286850575887, 0.0272227569443625, -1.33073910196358, 
0.159706840740259, 0.755885217821793, -1.26449706006563, -0.867044808677939, 
1.65015278344409, 1.12021644826051, -1.06577093437178, 1.28582155300538, 
-0.734560724882042, 1.35206359490333, -0.469592557290249), BloodPressure = c(-0.404485476478996, 
-1.79728923982648, 0.117815934776311, 1.16241875728693, -0.0562845356421244, 
1.5106196981238, 2.03292110937911, -0.0562845356421244, -0.404485476478996, 
0.466016875613183, -1.10088735815274, -0.926786887734304, 0.117815934776311, 
-0.578585946897432, 0.814217816450055, -1.79728923982648, -0.404485476478996, 
1.68472016854223, -0.404485476478996, -0.23038500606056, 1.5106196981238, 
-0.578585946897432, -1.10088735815274, -0.404485476478996, 1.24946899249614, 
-0.404485476478996, -0.578585946897432, 1.33651922770536, 0.640117346031619, 
0.291916405194747, -0.23038500606056, 0.814217816450055, 0.640117346031619, 
0.988318286868491, 0.117815934776311, -1.97138971024492, -1.79728923982648, 
1.68472016854223, 0.117815934776311, -1.27498782857118, -1.10088735815274, 
-1.10088735815274, 1.24946899249614, 0.117815934776311, -0.752686417315868, 
0.466016875613183, -1.44908829898961, 0.466016875613183, 0.466016875613183, 
0.291916405194747), SkinThickness = c(-0.550335739052064, 0.317989382362936, 
-0.936258015236509, 1.7651979180546, 0.125028244270714, 1.18631450377794, 
0.414469951409047, -0.260894031913731, -1.32218029142095, 0.703911658547381, 
-1.7081025676054, 0.414469951409047, 1.7651979180546, -0.357374600959842, 
-1.7081025676054, -1.32218029142095, -0.743296877144286, 0.510950520455159, 
1.28279507282405, 0.993353365685714, 3.01944531565405, 1.18631450377794, 
0.510950520455159, -1.51514142951318, -0.16441346286762, -0.839777446190398, 
0.60743108950127, -0.839777446190398, -0.260894031913731, 0.0285476752246028, 
-0.357374600959842, -1.32218029142095, 1.08983393473183, -1.03273858428262, 
-0.16441346286762, -1.03273858428262, 0.125028244270714, 2.15112019423905, 
-1.03273858428262, 0.0285476752246028, -0.0679328938215084, 0.221508813316825, 
-0.357374600959842, 0.414469951409047, -0.260894031913731, 0.510950520455159, 
0.317989382362936, -1.32218029142095, 2.6335230394696, 0.125028244270714
), Insulin = c(-0.541966337004785, -0.604415367734054, 0.301095577840341, 
0.87354502619197, -0.521149993428362, 0.925585885133028, -0.000741404017790264, 
-0.323394729452345, -0.0631904347470589, 1.02966760301514, -0.958293208533242, 
0.478034498239936, 0.634157075063107, -0.791762459921859, -0.666864398463322, 
-1.14564030072105, -1.2809465339678, 1.60211705136677, 2.03926026647165, 
1.64374973851962, -0.375435588393402, -0.042374091170636, -0.188088496205596, 
-1.12482395714463, -0.479517306275516, -0.583599024157631, -0.0631904347470589, 
1.28987189772043, -0.781354288133648, -0.21931301157023, -0.781354288133648, 
0.311503749628553, -1.02074223926251, -0.854211490651128, 0.852728682615547, 
-0.729313429192591, -0.854211490651128, 0.769463308309856, -1.1040076135682, 
0.0617076267114783, -0.0631904347470589, -1.33298739290885, -1.14564030072105, 
-0.115231293688116, 3.631710550068, -1.13523212893284, 0.301095577840341, 
-0.989517723897877, -0.479517306275516, -0.479517306275516), 
    BMI = c(-0.709569819818014, -0.25689403975692, -1.06858854193543, 
    2.05331338952039, 0.305048307905129, 1.03869526179725, 0.617238501050711, 
    -0.24128453009964, -1.47443579302469, -0.163236981813245, 
    -1.22468363850822, -1.34955971576646, 0.695286049337106, 
    0.211391249961454, -2.0675971600013, -1.3183406964519, -1.28712167713734, 
    0.164562720989617, 0.320657817562408, 0.788943107280781, 
    2.20940848609318, 1.38210447425739, -1.13102658056455, -2.03637814068674, 
    -0.584693742559781, 0.0396866437313833, -0.631522271531618, 
    0.383095856191524, -0.522255703930664, -0.0383609045550123, 
    0.913819184539014, -0.100798943184129, 2.1937989764359, -1.2402931481655, 
    0.195781740304174, -1.91150206342851, -0.615912761874339, 
    2.66208426615427, -0.9437124646772, -0.615912761874339, -0.491036684616106, 
    0.258219778933291, 0.742114578308943, 0.102124682360499, 
    0.211391249961454, -0.225675020442362, -0.334941588043315, 
    -1.47443579302469, 3.20841710415904, 0.148953211332337), 
    DiabetesPedigreeFunction = c(-1.15930102674571, -0.876165113386904, 
    0.308811116596248, 0.182972932881223, 0.10607182061093, 0.717785213670079, 
    -0.855192082767733, -1.02647183282429, -0.0407393937232661, 
    1.23162446383976, -0.809750516426196, 1.6336075507072, 3.11570171446194, 
    -0.795768496013416, -0.0267573733104856, 0.0955853053013444, 
    -0.547587633686561, -0.110649495787169, 0.766722285114811, 
    -0.855192082767733, 1.61962553029442, -1.13832799612654, 
    0.700307688154103, -0.575551674512122, -1.08239991447542, 
    1.28755254549089, -0.306397781566097, -0.935588700141221, 
    0.938002035171372, 0.696812183050908, -0.610506725544073, 
    -0.194541618263852, -0.830723547045367, 0.917029004552202, 
    -0.851696577664538, -0.614002230647269, -0.498650562241829, 
    -0.607011220440878, -0.753822434775074, 1.05684920868001, 
    -0.739840414362294, -0.568560664305732, -0.879660618490099, 
    -1.04744486344347, 0.155008892055662, -1.07191339916583, 
    0.312306621699443, -0.963552740966782, 0.910037994345811, 
    -0.330866317288462), Age = c(-0.980707216961708, -0.453522925075076, 
    2.18239853435809, 0.073661366811557, 0.179098225188884, -0.348086066697749, 
    2.18239853435809, 1.12802995058482, -0.875270358584382, -0.242649208320423, 
    -0.875270358584382, 0.28453508356621, 2.70958282624472, -0.453522925075076, 
    -0.875270358584382, -0.453522925075076, -0.0317754915657695, 
    2.92045654299937, 1.23346680896215, 1.12802995058482, 0.073661366811557, 
    -0.875270358584382, -0.664396641829729, -0.558959783452402, 
    -0.348086066697749, -0.242649208320423, -0.453522925075076, 
    -0.769833500207055, 0.60084565869819, -0.348086066697749, 
    -0.453522925075076, 0.389971941943537, 1.23346680896215, 
    -0.980707216961708, 1.0225930922075, -0.875270358584382, 
    -0.769833500207055, 0.073661366811557, -0.664396641829729, 
    -0.980707216961708, 0.706282517075516, -0.558959783452402, 
    -0.664396641829729, -0.664396641829729, 1.65521424247145, 
    -0.769833500207055, 0.917156233830169, -0.980707216961708, 
    -0.558959783452402, -0.769833500207055), Y = structure(c(1L, 
    2L, 2L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 1L, 1L, 2L, 1L, 1L, 1L, 
    1L, 2L, 1L, 2L, 1L, 1L, 1L, 1L, 1L, 2L, 1L, 1L, 1L, 1L, 1L, 
    1L, 1L, 1L, 1L, 1L, 1L, 2L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 1L, 
    2L, 1L, 2L, 1L), .Label = c("0", "1"), class = "factor")), row.names = c(NA, 
50L), class = "data.frame")
library(magrittr)
library(dplyr)
library(caret)
library(neuralnet)

# Fix the RNG seed, then split the rows of df_std into 5 folds based on Y.
set.seed(12345)
Foldos_5 <- createFolds(df_std$Y, k = 5)

# Candidate architectures: every combination of 1 to 3 neurons in each of
# the three hidden layers (one grid row per combination).
tune_grid <- expand.grid(
  layer1 = 1:3,
  layer2 = 1:3,
  layer3 = 1:3
)

# TRAIN: one row per grid combination, one column per held-out fold; each
# cell will hold the misclassification rate measured on that fold.
erreur_5cv <- matrix(nrow = nrow(tune_grid), ncol = 5)

for (params in seq_len(nrow(tune_grid))) {
  for (j in seq_along(Foldos_5)) {
    # Rows of every fold except fold j form the training set.
    fold_ids <- 1:5
    train_rows <- unlist(Foldos_5[fold_ids[fold_ids != j]])

    # neuralnet cannot use factors, so Y is recoded as numeric 0/1.
    train_df <- df_std[train_rows, ] %>%
      mutate(Y = ifelse(Y == 0, 0, 1))

    net <- neuralnet(
      Y ~ .,
      data = train_df,
      hidden = as.numeric(tune_grid[params, ]),
      linear.output = FALSE,
      threshold = 0.01,
      act.fct = "tanh"
    )

    # Fold j is held out: predictors go to predict(), Y is kept for scoring.
    held_out <- df_std[Foldos_5[[j]], ]
    test_x <- held_out %>%
      dplyr::select(-c("Y")) # "Y" has to be removed for the predict function
    test_y <- held_out %>%
      dplyr::select(c("Y"))

    raw_out <- predict(net, newdata = test_x)

    # Classify as 1 when the network output exceeds 0.5, then store the
    # share of held-out rows whose predicted class differs from Y.
    pred_class <- ifelse(raw_out[, 1] > 0.5, 1, 0)
    erreur_5cv[params, j] <- mean(test_y$Y != pred_class)
  }
}

The problem is that, at some point while running this loop, I get the following error: "Error in cbind(1, pred) %*% weights[[num_hidden_layers + 1]] : requires numeric/complex matrix/vector arguments In addition: Warning message: Algorithm did not converge in 1 of 1 repetition(s) within the stepmax."

I have seen many posts on the forum attempting to resolve this issue. Two of the suggested fixes are already implemented here: removing the response variable, "Y", from the data passed to predict, and converting the factor "Y" into a numeric 0/1 dummy variable.

The thing is, though, that if I try to reproduce the error using the last values of j and params from the loop, I no longer get it. So it seems this error only occurs when running neuralnet within a loop. In addition, each time I run this loop, it stops at a different pair of j and params values.

I am absolutely lost as to what could be causing this, and how to fix it, given that it works without the loop perfectly fine. Any pointers would be much appreciated.


Solution

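  • You're using the neuralnet package, and in some iterations, with the parameters passed, the algorithm used by the neuralnet() function does not converge (that is what the warning says), so the function returns a model without the net weights. predict() then fails because it has no weights to multiply by. Right after the error you can inspect the model and see that the weights are not there. Because the starting weights are drawn at random, whether a given fit converges changes from run to run, which is why the loop stops at different j and params values and why refitting a single combination by hand usually works. For example, I changed the threshold used to .1 and the algorithm converged for all iterations, as per the code below.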

    # 5-fold cross-validated grid search over the hidden-layer sizes in
    # tune_grid. Fills erreur_5cv[params, j] with the misclassification rate
    # of grid row `params` on held-out fold `j`, or NA when that fit did not
    # converge. Relies on df_std, Foldos_5, tune_grid and erreur_5cv already
    # existing in the workspace.
    set.seed(10)
    fold_ids <- seq_along(Foldos_5)
    for (params in seq_len(nrow(tune_grid))) {
      # Hidden-layer sizes for this grid row; constant across folds.
      hidden_sizes <- as.numeric(tune_grid[params, ])

      for (j in fold_ids) {
        # Train on every fold except fold j.
        idx_garde <- fold_ids[fold_ids != j]

        # neuralnet cannot use factors, so Y is recoded as numeric 0/1.
        donnee_train <- df_std[unlist(Foldos_5[idx_garde]), ] %>%
          mutate(Y = ifelse(Y == 0, 0, 1))

        # A looser threshold (0.1 instead of 0.01) makes convergence within
        # stepmax much more likely.
        modele <- neuralnet(Y ~ .,
                            data = donnee_train,
                            hidden = hidden_sizes,
                            linear.output = FALSE, threshold = 0.1,
                            act.fct = "tanh")

        # When the algorithm does not converge, neuralnet() only warns and
        # returns a model without weights; predict() would then stop with
        # "requires numeric/complex matrix/vector arguments". Record NA for
        # this cell and carry on instead of aborting the whole search.
        if (is.null(modele$weights)) {
          erreur_5cv[params, j] <- NA
          next
        }

        # Fold j is held out: "Y" has to be removed for the predict function,
        # and is kept separately to score the predictions.
        donnee_fold <- df_std[Foldos_5[[j]], ]
        donnee_test <- donnee_fold %>%
          dplyr::select(-c("Y"))
        donnee_check <- donnee_fold %>%
          dplyr::select(c("Y"))

        result <- predict(modele, newdata = donnee_test)

        # Classify as 1 when the network output exceeds 0.5; the error is
        # the share of held-out rows whose predicted class differs from Y.
        pred <- ifelse(result[, 1] > 0.5, 1, 0)
        erreur_5cv[params, j] <- mean(donnee_check$Y != pred)
      }
    }