
k-fold cross validation: how to filter data based on a randomly generated integer variable in Stata


The following seems obvious, yet it does not behave as I would expect. I want to do k-fold cross-validation without using SSC packages, and thought I could just filter my data and run my own regressions on the subsets.

First I generate a variable with a random integer between 1 and 5 (5-fold cross validation), then I loop over each fold number. I want to filter the data by the fold number, but using a boolean filter fails to filter anything. Why?

Bonus: what would be the best way to capture all of the test MSEs and average them? In Python I would just make a list or a numpy array and take the average.

gen randint = floor((6-1)*runiform()+1)

recast int randint

forval b = 1(1)5 {
    xtreg c.DepVar ///  // training set
    c.IndVar1 ///
    c.IndVar2 ///
    if randint !=`b' ///
    , fe vce(cluster uuid)

    xtreg c.DepVar /// // test set, needs to be performed with model above, not a               
    c.IndVar1 ///      // new model...
    c.IndVar2 ///
    if randint ==`b' ///
    , fe vce(cluster uuid)
}

EDIT: The test-set evaluation needs to be performed with the model fit to the training set. I changed my comment in the code to reflect this.
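For reference, this is roughly the pattern I am after (just a sketch, not tested; the c. prefixes are not needed for continuous variables, and predict's xbu option adds back the estimated panel effects, which will be missing for panels that only show up in the test fold):

forval b = 1(1)5 {
    // fit the model on the training folds only
    xtreg DepVar IndVar1 IndVar2 if randint != `b', fe vce(cluster uuid)
    // score the held-out fold with the coefficients just estimated
    predict double yhat if randint == `b', xbu
    gen double sqerr = (DepVar - yhat)^2 if randint == `b'
    quietly summarize sqerr
    display "fold `b' test MSE = " r(mean)
    drop yhat sqerr
}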

Ultimately, the solution to the filtering issue was that I was referring to a scalar with macro quotes when defining the bounds. I had:

replace randint = floor((`varscalar'-1)*runiform()+1)

instead of just

replace randint = floor((varscalar-1)*runiform()+1)

When and where to use the quotes in Stata is confusing to me. I cannot just use varscalar in a loop; I have to use `=varscalar'. But for some reason I can use varscalar - 1 and get the expected result. Interestingly, I cannot use

replace randint = floor((`varscalar')*runiform()+1)

I suppose I should just use

replace randint = floor((`=varscalar')*runiform()+1)

So why is it OK to use the version with the minus one and without the equals sign?
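For the record, here is a small experiment that, as far as I can tell, illustrates the difference between a scalar and a local macro (the scalar name is the same varscalar as above, just for demonstration):

scalar varscalar = 6

display varscalar - 1      // 5: a bare scalar name is evaluated as part of the expression
display `=varscalar' - 1   // 5: `=...' evaluates the expression first and substitutes "6" as text
display "`varscalar'"      // prints nothing here: `varscalar' refers to a LOCAL MACRO named
                           // varscalar, and if no such local macro is defined it expands to
                           // empty text

// which would explain why my original line matched nothing: if the macro expands to nothing,
// floor((`varscalar'-1)*runiform()+1) becomes floor((-1)*runiform()+1), which is always 0,
// so randint never equals any fold number between 1 and 5
display floor((`varscalar'-1)*runiform()+1)

// forvalues wants literal numbers in its range, so a scalar has to be
// turned into text with `=...' before the loop is parsed
forvalues b = 1/`=varscalar' {
    display `b'
}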

The answer below is still extremely helpful and I learned much from it.


Solution

  • As a matter of fact, two different things are going on here that are not necessarily directly related: 1) how to filter data with a randomly generated integer value, and 2) the k-fold cross-validation procedure.

    For the first one, I will leave an example below that could help you work things out in Stata, using some tools that are easily transferable to other problems (such as matrix generation and manipulation to store the metrics). However, I would call neither your sketch of code nor my example "k-fold cross-validation", mainly because both fit the model separately on the testing and on the training data. Strictly speaking, the model should be fit on the training data only, and those estimated parameters should then be used to assess its performance on the testing data (see the sketch after the code below).

    For further reference on the procedure, scikit-learn has done a brilliant job explaining it, with several visualizations included.

    That being said, here is something that could be helpful.

    clear all
    set seed 4
    set obs 100
    *Simulate model
    gen x1 = rnormal()
    gen x2 = rnormal()
    gen y = 1 + 0.5 * x1 + 1.5 *x2 + rnormal()
    gen byte randint = runiformint(1, 5)
    tab randint
    /*
        randint |      Freq.     Percent        Cum.
    ------------+-----------------------------------
              1 |         17       17.00       17.00
              2 |         18       18.00       35.00
              3 |         21       21.00       56.00
              4 |         19       19.00       75.00
              5 |         25       25.00      100.00
    ------------+-----------------------------------
          Total |        100      100.00 
    */
    // create a matrix to store results
    matrix res = J(5,4,.)
    matrix colnames res = "R2_fold"  "MSE_fold" "R2_hold"  "MSE_hold"
    matrix rownames res ="1" "2" "3" "4" "5"
    // show the formatted empty matrix
    matrix li res
    /*
    res[5,4]
        R2_fold  MSE_fold   R2_hold  MSE_hold
    1         .         .         .         .
    2         .         .         .         .
    3         .         .         .         .
    4         .         .         .         .
    5         .         .         .         .
    */
    
    // loop over different samples
    forvalues b = 1/5 {
        // run the model using fold == `b'
        qui reg y x1 x2 if randint ==`b'
        // save R-squared for the fold sample
        matrix res[`b', 1] = e(r2)
        // save root MSE for the fold sample
        matrix res[`b', 2] = e(rmse)
    
        // run the model using the complement, fold != `b'
        qui reg y x1 x2 if randint !=`b'
        // save R-squared for the hold (complement) sample
        matrix res[`b', 3] = e(r2)
        // save root MSE for the hold (complement) sample
        matrix res[`b', 4] = e(rmse)
    }
    
    // Show matrix with stored metrics
    mat li res 
    /*
    res[5,4]
         R2_fold   MSE_fold    R2_hold   MSE_hold
    1  .50949187  1.2877728  .74155365  1.0070531
    2  .89942838  .71776458  .66401888   1.089422
    3  .75542004  1.0870525  .68884359  1.0517139
    4  .68140328  1.1103964  .71990589  1.0329239
    5  .68816084  1.0017175  .71229925  1.0596865
    */
    
    // some matrix algebra to obtain the mean of the metrics
    mat U = J(rowsof(res),1,1)
    mat sum = U'*res
    /* create vector of column (variable) means */
    mat mean_res = sum/rowsof(res)
    // show the average of the metrics across the folds
    mat li mean_res
    /*
    mean_res[1,4]
          R2_fold   MSE_fold    R2_hold   MSE_hold
    c1  .70678088  1.0409408  .70532425  1.0481599
    */
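
    To get a genuinely out-of-sample metric, as discussed above, you can fit the model on the other four folds and then use predict to score the held-out fold. Here is a sketch reusing the simulated y, x1, x2 and randint from above, storing the test MSE of each fold in a new matrix:

    // matrix to store the out-of-sample (test) MSE of each fold
    matrix cv = J(5,1,.)
    matrix colnames cv = "MSE_test"
    matrix rownames cv = "1" "2" "3" "4" "5"

    forvalues b = 1/5 {
        // fit on the training folds only
        qui reg y x1 x2 if randint != `b'
        // out-of-sample predictions for the held-out fold
        qui predict double yhat if randint == `b'
        qui gen double sqerr = (y - yhat)^2 if randint == `b'
        qui su sqerr
        matrix cv[`b', 1] = r(mean)   // test MSE for fold `b'
        drop yhat sqerr
    }

    // show and average the test MSEs across the five folds
    mat li cv
    mat U2 = J(rowsof(cv),1,1)
    mat sum_cv = U2'*cv
    mat mean_cv = sum_cv/rowsof(cv)
    mat li mean_cv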