Tags: c++, r, rcpp, rcpparmadillo

How do I find the indices of elements in a vector which are also in another vector using RcppArmadillo?


I am stuck trying to find, using RcppArmadillo, the indices of the elements of a vector x that also appear in another vector vals. Both x and vals are of type arma::uvec.

In R, this would be straightforward:

x <- c(1,1,1,4,2,4,4)
vals <- c(1,4)
which(x %in% vals)

I've scanned the Armadillo docs, and find() was my obvious first try, but it didn't work since vals is a vector rather than a single value. I've also tried intersect(), but it returns only one index per unique common value rather than every matching position.

What would be a good, efficient way to do this using Armadillo? Do I have to iterate through the elements of vals, calling find() for each one?


Solution

  • A quick-and-dirty way:

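    Rcpp::cppFunction("
      arma::uvec ind(arma::uvec x, arma::uvec y){
       // a[k] counts how many elements of y equal x[k]
       arma::vec a(x.size(), arma::fill::zeros);
       // x == i yields a 0/1 vector: one vectorised comparison per element of y
       for (auto i:y) a = a +  (x==i);
       // indices of the nonzero entries, shifted by 1 for 1-based R indexing
       return arma::find(a) + 1;
      }
     ", 'RcppArmadillo')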
    
    c(ind(x, vals))
    [1] 1 2 3 4 6 7