c++ rcpp armadillo

Transform arma::cube subview into NumericVector to use sugar


I am passing a 3D array from R into C++ and have run into type conversion issues. How can I convert arma::cube subviews from RcppArmadillo into NumericVectors so that I can apply Rcpp sugar functions such as which_min to them?

Say you have a 3D cube Q with some numeric entries. My goal is to get the index of the minimum value of the column entries for each row i and for each third dimension k. In R syntax this is which.min(Q[i,,k]).

For example, for i = 1 and k = 1 (zero-based C++ indices):

cube Q = randu<cube>(3,3,3);
which_min(Q.slice(1).row(1)); // this fails

I thought converting to NumericVector would do the trick, but the conversion fails as well:

which_min(as<NumericVector>(Q.slice(1).row(1))); // conversion failed

How can I get this to work? Thank you for your help.


Solution

  • You have a couple of options here:

    1. You can use Armadillo's own member function for this, .index_min() (see the Armadillo documentation).
    2. You can use Rcpp::wrap(), which "transforms an arbitrary object into a SEXP", to turn the arma::cube subview into an Rcpp::NumericVector and then apply the sugar function Rcpp::which_min(). Rcpp::as<>() did not work in your attempt because it converts in the opposite direction, from a SEXP to a C++ type.

    Initially I gave only the first option, since it is the more straightforward way to accomplish your objective. I added the second option in an update because arbitrary conversions may be part of what you are curious about.

    I put the following C++ code in a file so-answer.cpp:

    // [[Rcpp::depends(RcppArmadillo)]]
    #include <RcppArmadillo.h>
    
    // [[Rcpp::export]]
    Rcpp::List index_min_test() {
        arma::cube Q = arma::randu<arma::cube>(3, 3, 3);
        // Zero-based position of the minimum in row 1 of slice 1 (Q[2, , 2] in R)
        int whichmin = Q.slice(1).row(1).index_min();
        Rcpp::List result = Rcpp::List::create(Rcpp::Named("Q") = Q,
                                               Rcpp::Named("whichmin") = whichmin);
        return result;
    }
    
    // [[Rcpp::export]]
    Rcpp::List which_min_test() {
        arma::cube Q = arma::randu<arma::cube>(3, 3, 3);
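        // wrap() turns the subview into a SEXP, which NumericVector can take over
        Rcpp::NumericVector x = Rcpp::wrap(Q.slice(1).row(1));
        // Sugar which_min() also returns a zero-based position
        int whichmin = Rcpp::which_min(x);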
        Rcpp::List result = Rcpp::List::create(Rcpp::Named("Q") = Q,
                                               Rcpp::Named("whichmin") = whichmin);
        return result;
    }
    

    One function uses Armadillo's .index_min() and the other uses Rcpp::wrap() to enable Rcpp::which_min(). Both return a zero-based index, so add 1 to get the value R's which.min() would report.

    Then I use Rcpp::sourceCpp() to compile the file and make the functions available to R, call them with a couple of different seeds, and benchmark them. Resetting the seed before each call makes both functions draw the same cube:

    Rcpp::sourceCpp("so-answer.cpp")
    set.seed(1)
    arma <- index_min_test()
    set.seed(1)
    wrap <- which_min_test()
    arma$Q[2, , 2]
    #> [1] 0.2059746 0.3841037 0.7176185
    wrap$Q[2, , 2]
    #> [1] 0.2059746 0.3841037 0.7176185
    arma$whichmin
    #> [1] 0
    wrap$whichmin
    #> [1] 0
    set.seed(2)
    arma <- index_min_test()
    set.seed(2)
    wrap <- which_min_test()
    arma$Q[2, , 2]
    #> [1] 0.5526741 0.1808201 0.9763985
    wrap$Q[2, , 2]
    #> [1] 0.5526741 0.1808201 0.9763985
    arma$whichmin
    #> [1] 1
    wrap$whichmin
    #> [1] 1
    library(microbenchmark)
    microbenchmark(arma = index_min_test(), wrap = which_min_test())
    #> Unit: microseconds
    #>  expr    min      lq     mean  median      uq    max neval cld
    #>  arma 12.981 13.7105 15.09386 14.1970 14.9920 62.907   100   a
    #>  wrap 13.636 14.3490 15.66753 14.7405 15.5415 64.189   100   a
    

    Created on 2018-12-21 by the reprex package (v0.2.1)
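
The functions above handle a single (i, k) pair, while the question asks for which.min(Q[i,,k]) for every i and k. The following is a minimal, dependency-free sketch of that double loop and is not part of the original answer. It assumes the cube is held as a flat column-major buffer, which is the layout both R arrays and arma::cube use, and the helper names row_index_min and all_row_index_min are hypothetical.

```cpp
#include <cstddef>
#include <vector>

// Zero-based index of the smallest entry in row i of slice k, for a cube
// stored as a flat column-major buffer: element (i, j, k) lives at
// i + j * n_rows + k * n_rows * n_cols.
std::size_t row_index_min(const std::vector<double>& q,
                          std::size_t n_rows, std::size_t n_cols,
                          std::size_t i, std::size_t k) {
    const std::size_t base = i + k * n_rows * n_cols;
    std::size_t best = 0;
    for (std::size_t j = 1; j < n_cols; ++j) {
        // Strict '<' keeps the first minimum on ties, like R's which.min()
        if (q[base + j * n_rows] < q[base + best * n_rows]) {
            best = j;
        }
    }
    return best;
}

// Result for every (i, k) pair, stored as an n_rows x n_slices column-major
// matrix: entry (i, k) sits at i + k * n_rows.
std::vector<std::size_t> all_row_index_min(const std::vector<double>& q,
                                           std::size_t n_rows,
                                           std::size_t n_cols,
                                           std::size_t n_slices) {
    std::vector<std::size_t> out(n_rows * n_slices);
    for (std::size_t k = 0; k < n_slices; ++k) {
        for (std::size_t i = 0; i < n_rows; ++i) {
            out[i + k * n_rows] = row_index_min(q, n_rows, n_cols, i, k);
        }
    }
    return out;
}
```

With RcppArmadillo available, the body of the inner loop could instead be Q.slice(k).row(i).index_min(), as in index_min_test() above. Either way the indices are zero-based, so add 1 before comparing them with which.min() in R.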