I'm learning the features of the Rcpp package and have no previous experience with C++. I have tried:
#include <RcppArmadillo.h>
// [[Rcpp::depends("RcppArmadillo")]]
// [[Rcpp::export]]
arma::mat VtoMatCpp(int n, arma::vec x) {
    arma::mat V = arma::eye<arma::mat>(n, n);
    V.elem(find(trimatu(V))) = x;
    return(V);
}
When I run sourceCpp('fun.cpp') in R and then call VtoMatCpp(2,1:3), I get Error: Mat::elem(): size mismatch. It seems that the trimatu function is not picking up the indices of the diagonal.
You get the error because find returns the indices of the non-zero elements only. Applying trimatu to an identity matrix still gives the identity, so the only non-zero elements are on the diagonal. For your VtoMatCpp(2,1:3) call that yields just 2 indices, and the 3 elements of x cannot be assigned to 2 positions.
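To see why the sizes disagree, the sketch below mimics eye, trimatu and find in plain C++ on a column-major std::vector<double> (the storage order Armadillo and R both use). The helper names identity, upper_part and find_nonzero are hypothetical stand-ins written for illustration; they are not Armadillo functions.

```cpp
#include <cstddef>
#include <vector>

// Column-major n x n matrix: element (row i, col j) lives at index i + j * n.
// Stand-in for eye<mat>(n, n).
std::vector<double> identity(std::size_t n) {
    std::vector<double> m(n * n, 0.0);
    for (std::size_t k = 0; k < n; ++k) m[k + k * n] = 1.0;
    return m;
}

// Stand-in for trimatu(): copy the matrix and zero everything below the diagonal.
std::vector<double> upper_part(const std::vector<double>& m, std::size_t n) {
    std::vector<double> u(m);
    for (std::size_t j = 0; j < n; ++j)
        for (std::size_t i = j + 1; i < n; ++i) u[i + j * n] = 0.0;
    return u;
}

// Stand-in for find(): linear indices of the NON-ZERO elements only.
std::vector<std::size_t> find_nonzero(const std::vector<double>& m) {
    std::vector<std::size_t> idx;
    for (std::size_t k = 0; k < m.size(); ++k)
        if (m[k] != 0.0) idx.push_back(k);
    return idx;
}
```

For n = 2, find_nonzero(upper_part(identity(2), 2)) returns {0, 3}, i.e. only the two diagonal positions, whereas the upper triangle including the diagonal has n(n+1)/2 = 3 positions, one for each element of 1:3.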
This is somewhat similar to my question here, where I actually wanted to exclude the diagonal elements. Unfortunately, the best I could come up with for now is to copy how R does it with upper.tri. Here is a working example with RcppArmadillo:
library(inline)
src <- '
  using namespace arma;
  using namespace Rcpp;
  vec x = as<vec>(x_);
  int n = as<int>(n_);
  mat V = eye<mat>(n, n);
  // make empty matrices
  mat Z(n, n, fill::zeros);
  mat X(n, n, fill::zeros);
  // fill matrices with integers: X holds row numbers, Z holds column numbers
  vec idx = linspace<vec>(1, n, n);
  X.each_col() += idx;
  Z.each_row() += trans(idx);
  // assign upper triangular elements (column >= row)
  // the >= allows inclusion of diagonal elements
  V.elem(find(Z >= X)) = x;
  return wrap(V);
'
fun <- cxxfunction(signature(n_ = "integer", x_ = "vector"),
                   body = src, plugin = "RcppArmadillo")
fun(2, 1:3)
     [,1] [,2]
[1,]    1    2
[2,]    0    3
This is exactly the same result as base R:
fun2 <- function(n, x){
  dm <- diag(n)
  dm[upper.tri(dm, diag = TRUE)] <- x
  dm
}
fun2(2, 1:3)
     [,1] [,2]
[1,]    1    2
[2,]    0    3
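The Z >= X trick works because X holds each element's row number and Z its column number, so Z >= X is true exactly where column >= row, i.e. on and above the diagonal. find then returns those positions in column-major order, the same order in which R's upper.tri assignment consumes x. Below is a minimal plain C++ sketch of that logic; upper_tri_indices and vto_mat are hypothetical names, and a column-major std::vector<double> stands in for arma::mat.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Column-major linear indices (i + j * n) where column >= row: the upper
// triangle including the diagonal, i.e. the positions selected by Z >= X.
std::vector<std::size_t> upper_tri_indices(std::size_t n) {
    std::vector<std::size_t> idx;
    for (std::size_t j = 0; j < n; ++j)        // columns in the outer loop...
        for (std::size_t i = 0; i < n; ++i)    // ...rows inner: column-major order
            if (j >= i) idx.push_back(i + j * n);
    return idx;
}

// Start from the identity (like eye<mat>(n, n)) and overwrite the upper
// triangle with x in column-major order, as V.elem(find(Z >= X)) = x does.
std::vector<double> vto_mat(std::size_t n, const std::vector<double>& x) {
    std::vector<double> v(n * n, 0.0);
    for (std::size_t k = 0; k < n; ++k) v[k + k * n] = 1.0;
    const std::vector<std::size_t> idx = upper_tri_indices(n);
    // Mirrors Armadillo's "size mismatch" check on .elem() assignment.
    if (idx.size() != x.size())
        throw std::invalid_argument("x must have n*(n+1)/2 elements");
    for (std::size_t k = 0; k < idx.size(); ++k) v[idx[k]] = x[k];
    return v;
}
```

vto_mat(2, {1, 2, 3}) returns {1, 0, 2, 3}, which read column by column is the same 2 x 2 matrix that fun(2, 1:3) and fun2(2, 1:3) print.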
A quick benchmark with n = 100 (so x needs 100 * 101 / 2 = 5050 elements) shows that this implementation is faster than base R, roughly 3.3 times by median time. Here fun2 is the base R solution above.
library(microbenchmark)
microbenchmark(fun(100, seq(5050)), fun2(100, seq(5050)))
Unit: microseconds
                 expr     min      lq     mean   median       uq      max neval
  fun(100, seq(5050)) 117.823 154.106 241.2361 188.2575 242.0360 3392.611   100
 fun2(100, seq(5050)) 545.042 592.988 736.6958 622.7405 650.7475 4057.011   100
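Since the helper matrices X and Z exist only to build the mask, a possible alternative (not benchmarked here) is to walk the upper triangle directly, column by column, with a running counter into x. This avoids allocating the two n x n helper matrices and the index vector returned by find. The sketch below uses plain C++ with a column-major std::vector<double>; vto_mat_loop is a hypothetical name, and in RcppArmadillo the inner assignment would become V(i, j) = x(k++).

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Fill the upper triangle (diagonal included) of a column-major n x n matrix
// from x, column by column, without building a mask or an index vector.
std::vector<double> vto_mat_loop(std::size_t n, const std::vector<double>& x) {
    if (x.size() != n * (n + 1) / 2)
        throw std::invalid_argument("x must have n*(n+1)/2 elements");
    std::vector<double> v(n * n, 0.0);   // below-diagonal entries stay 0
    std::size_t k = 0;                   // running position in x
    for (std::size_t j = 0; j < n; ++j)        // each column...
        for (std::size_t i = 0; i <= j; ++i)   // ...rows 0..j only
            v[i + j * n] = x[k++];
    return v;
}
```

Because every diagonal entry is overwritten from x, starting from zeros gives the same matrix as starting from the identity.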