Is it possible to speed up the calculation of a special 3-dimensional array?

I would like to compute a 3-dimensional array R with dimensions (K, d, d) from two matrices, A with dimensions (K, N) and X with dimensions (d, N), where K is small, d is moderate and N is large (see the code example below for typical values). The array is defined element-wise by

R[k, i, j] = sum( A[k, ] * X[i, ] * X[j, ] ).
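In index notation, each slice R[k, , ] is the Gram matrix of the rows of X weighted by A[k, ], which is also why it is symmetric in i and j:

```latex
R_{kij} = \sum_{n=1}^{N} A_{kn}\, X_{in}\, X_{jn},
\qquad
R_{k\cdot\cdot} = X \operatorname{diag}(A_{k1}, \dots, A_{kN})\, X^{\top},
\qquad
R_{kij} = R_{kji}.
```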

This array has to be computed many times, so speed is essential. What is the most efficient way to compute it in R?

My current approach

My current approach is shown below as "current", along with the "naive" approach, which is, unsurprisingly, considerably slower.


library(microbenchmark)

K = 3
d = 20
N = 1e5

tt = microbenchmark(
  current = {
    for(krow in 1:K){
      tmp = X * matrix(A[krow,], d, N, byrow = TRUE)
      R[krow,,] = tmp %*% t(X)
    }
    R  # return the result so that check = "equal" has something to compare
  },
  naive = {
    for(krow in 1:K){
      for(irow in 1:d){
        for(jrow in 1:d){
          Ralt[krow, irow, jrow] = sum(A[krow,] * X[irow, ] * X[jrow,])
        }
      }
    }
    Ralt
  },
  check = "equal",
  setup = {
    set.seed(1)  # same random data for every expression
    A = matrix(runif(K*N), K, N)
    X = matrix(runif(d*N), d, N)
    R = array(0, dim = c(K, d, d))
    Ralt = array(0, dim = c(K, d, d))
  },
  times = 5
)



  • Do you see any way to improve upon this? For example, is it possible to use the fact that R is symmetric in the last two indices?
  • Could I expect a substantial (>30%) improvement from implementing this in Rcpp?


  • You can transpose the matrices with t() so that you subset columns, which is faster than subsetting rows. This also lets R recycle the weight vector automatically instead of creating a new matrix.

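    tX <- t(X)   # N x d: the columns of tX are the rows of X
    tA <- t(A)   # N x K
    for(krow in 1:K){
        . <- tX * tA[,krow]   # length-N weights recycled down each column
        R[krow,,] <- t(.) %*% tX
    }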
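    Why no repetition matrix is needed: R stores matrices column-major, so tX * tA[,krow] recycles the length-N vector down each of the d columns of the N x d matrix tX, scaling row n by A[krow, n]. A quick check on toy sizes (the sizes and the seed are arbitrary, chosen only for illustration) that this matches the byrow construction from the question up to a transpose:

    ```r
    # Toy sizes, chosen only to illustrate recycling
    set.seed(42)
    K <- 2; d <- 3; N <- 5
    A <- matrix(runif(K * N), K, N)
    X <- matrix(runif(d * N), d, N)
    tX <- t(X)   # N x d
    tA <- t(A)   # N x K

    krow <- 1
    # Question's approach: an explicit d x N matrix of repeated weights
    w_rows <- X * matrix(A[krow, ], d, N, byrow = TRUE)
    # Transposed approach: the length-N vector tA[, krow] is recycled down
    # each column of tX, so row n of tX is multiplied by A[krow, n]
    w_cols <- tX * tA[, krow]

    all.equal(t(w_rows), w_cols)   # TRUE: same numbers, transposed layout
    ```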

    A variant might look like:

    tX <- t(X)
    tA <- t(A)
    for(krow in 1:K) R[krow,,] <- crossprod(tX * tA[,krow], tX)

    Here crossprod can be sped up further, e.g. with Rfast::Crossprod (thanks to @jblood94 for the comment).
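    On the symmetry question: if A is non-negative (as with the runif() data here), each slice can be written as crossprod(W) with W = tX * sqrt(tA[,krow]), because t(W) %*% W sums A[krow, n] * X[i, n] * X[j, n] over n. As far as I know, R evaluates the one-argument crossprod(x) with the BLAS symmetric rank-k routine (dsyrk), which computes only one triangle and mirrors it. A minimal sketch under that non-negativity assumption; the function name sym_slices is hypothetical, and whether it is faster depends on your BLAS, so it needs benchmarking:

    ```r
    # Sketch under the ASSUMPTION that all entries of A are non-negative:
    # sqrt(A[k, ]) is used as a per-observation weight, so that
    # crossprod(W) = sum over n of A[k, n] * X[, n] %o% X[, n] for W = tX * sqrt(a).
    sym_slices <- function(A, X) {   # hypothetical helper name
      stopifnot(all(A >= 0), ncol(A) == ncol(X))
      K <- nrow(A); d <- nrow(X)
      tX <- t(X)          # N x d, so a length-N weight vector recycles down each column
      sA <- sqrt(t(A))    # N x K square-root weights
      R <- array(0, dim = c(K, d, d))
      for (krow in seq_len(K)) {
        # one-argument crossprod: t(W) %*% W, symmetric by construction
        R[krow, , ] <- crossprod(tX * sA[, krow])
      }
      R
    }

    # Usage on toy sizes, compared with the two-argument crossprod variant
    set.seed(1)
    K <- 3; d <- 4; N <- 1000
    A <- matrix(runif(K * N), K, N)
    X <- matrix(runif(d * N), d, N)
    R1 <- sym_slices(A, X)
    R2 <- array(0, dim = c(K, d, d))
    for (krow in 1:K) R2[krow, , ] <- crossprod(t(X) * A[krow, ], t(X))
    all.equal(R1, R2)   # TRUE
    ```

    For A with negative entries this factorisation does not apply, and the two-argument variants above remain the general solution.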

    An Rcpp variant could look like this (but it is currently slower than the others):

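    Rcpp::cppFunction(r"(void mmul(Rcpp::NumericMatrix A, Rcpp::NumericMatrix X, Rcpp::NumericVector R, int K, int d) {
      // A = t(A) is N x K, X = t(X) is N x d, R is the K x d x d result array
      int KD = d*K;
      for(int i=0; i < d; ++i) {
        for(int j=0; j < d; ++j) {
          Rcpp::NumericVector tmp = X(Rcpp::_,i) * X(Rcpp::_,j);
          for(int k=0; k < K; ++k) {
            // column-major index of R[k, i, j]
            R[k + i*K + j*KD] = Rcpp::sum(A(Rcpp::_,k) * tmp);
          }
        }
      }
    })")
    mmul(t(A), t(X), R, K, d)  # fills R in place; R must be a double array of dim c(K, d, d)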
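    A sketch of an Rcpp kernel that uses the symmetry in the last two indices directly: it only visits pairs with j >= i, writes each result to both R[k, i, j] and R[k, j, i], and accumulates with plain loops instead of allocating Rcpp sugar temporaries. The name mmul_sym is hypothetical, and it returns a new array instead of modifying R in place. Plain loops do not get the cache blocking of an optimised BLAS, so it is not guaranteed to beat the crossprod-based variants:

    ```r
    # Hypothetical symmetric kernel: only j >= i is computed, then mirrored.
    # Expects tA = t(A) (N x K) and tX = t(X) (N x d); returns a K x d x d array.
    Rcpp::cppFunction(r"(Rcpp::NumericVector mmul_sym(Rcpp::NumericMatrix tA, Rcpp::NumericMatrix tX) {
      const int N = tX.nrow(), d = tX.ncol(), K = tA.ncol();
      if (tA.nrow() != N) Rcpp::stop("tA and tX must have the same number of rows");
      Rcpp::NumericVector R(K * d * d);          // zero-initialized result
      for (int i = 0; i < d; ++i) {
        for (int j = i; j < d; ++j) {            // upper triangle including the diagonal
          for (int k = 0; k < K; ++k) {
            double acc = 0.0;
            for (int n = 0; n < N; ++n) acc += tA(n, k) * tX(n, i) * tX(n, j);
            R[k + K * i + K * d * j] = acc;      // R[k, i, j] in column-major order
            R[k + K * j + K * d * i] = acc;      // mirrored R[k, j, i]
          }
        }
      }
      R.attr("dim") = Rcpp::IntegerVector::create(K, d, d);
      return R;
    })")

    # Usage on toy sizes, checked against the crossprod variant
    set.seed(1)
    K <- 3; d <- 5; N <- 200
    A <- matrix(runif(K * N), K, N)
    X <- matrix(runif(d * N), d, N)
    R_sym <- mmul_sym(t(A), t(X))
    R_ref <- array(0, dim = c(K, d, d))
    for (krow in 1:K) R_ref[krow, , ] <- crossprod(t(X) * A[krow, ], t(X))
    all.equal(R_sym, R_ref)   # TRUE
    ```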

    And one using Eigen:

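    // [[Rcpp::depends(RcppEigen)]]
    // [[Rcpp::plugins(openmp)]]
    #include <omp.h>
    #include <RcppEigen.h>
    // [[Rcpp::export]]
    void mmulE(Eigen::MatrixXd A, Eigen::MatrixXd X, Rcpp::NumericVector R, int n_cores) {
      Eigen::setNbThreads(n_cores);  // number of threads Eigen may use for the products
      for(int k=0; k < A.cols(); ++k) {
        // scale row n of X (= t(X), N x d) by A(n, k) (A = t(A), N x K)
        Eigen::MatrixXd C = X.cwiseProduct(A.col(k).replicate(1, X.cols() ));
        Eigen::MatrixXd D = C.transpose() * X;   // d x d slice
        for(int i=0; i<D.size(); ++i) {
          R[i*A.cols()+k] = D(i);   // column-major position of R[k, , ]
        }
      }
    }

    Compiled e.g. with Rcpp::sourceCpp() and called from R (this also fills R in place):

    mmulE(t(A), t(X), R, 1)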
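    mmulE takes its Eigen::MatrixXd arguments by value, so t(A) and t(X) are copied into Eigen matrices on every call. A hedged sketch of a variant that views R's memory through Eigen::Map instead and writes the row scaling as a diagonal product, which reads directly as X diag(A[k, ]) t(X). The name mmulE_map is hypothetical; it omits the OpenMP thread setting, returns a new array rather than filling R in place, and its speed relative to mmulE has not been measured:

    ```r
    # Hypothetical Eigen variant: Eigen::Map views R's memory instead of copying
    # the inputs (both must be double matrices); requires the RcppEigen package.
    Rcpp::cppFunction(r"(Rcpp::NumericVector mmulE_map(const Eigen::Map<Eigen::MatrixXd> tA,
                                         const Eigen::Map<Eigen::MatrixXd> tX) {
      if (tA.rows() != tX.rows()) Rcpp::stop("tA and tX must have the same number of rows");
      const int K = static_cast<int>(tA.cols()), d = static_cast<int>(tX.cols());
      Rcpp::NumericVector R(K * d * d);
      for (int k = 0; k < K; ++k) {
        // d x d slice X diag(A[k, ]) t(X), computed from tX = t(X) (N x d)
        Eigen::MatrixXd D = tX.transpose() * (tA.col(k).asDiagonal() * tX);
        for (int i = 0; i < D.size(); ++i) R[i * K + k] = D(i);  // column-major R[k, , ]
      }
      R.attr("dim") = Rcpp::IntegerVector::create(K, d, d);
      return R;
    })", depends = "RcppEigen")

    # Usage on toy sizes, checked against the crossprod variant
    set.seed(1)
    K <- 3; d <- 4; N <- 500
    A <- matrix(runif(K * N), K, N)
    X <- matrix(runif(d * N), d, N)
    R_ref <- array(0, dim = c(K, d, d))
    for (krow in 1:K) R_ref[krow, , ] <- crossprod(t(X) * A[krow, ], t(X))
    R_map <- mmulE_map(t(A), t(X))
    all.equal(R_map, R_ref)   # TRUE
    ```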

    library(microbenchmark)
    K = 3
    d = 20
    N = 1e5
    tt = microbenchmark(
      current = {
        for(krow in 1:K){
          tmp = X * matrix(A[krow,], d, N, byrow = TRUE)
          R[krow,,] = tmp %*% t(X)
        }
        R
      },
      GKi = {
          tX <- t(X)
          tA <- t(A)
          for(krow in 1:K){
              . <- tX * tA[,krow]
              R[krow,,] <- t(.) %*% tX
          }
          R
      },
      crossp = {
          tX <- t(X)
          tA <- t(A)
          for(krow in 1:K) R[krow,,] <- crossprod(tX * tA[,krow], tX)
          R
      },
      Rfast = {
        tX <- t(X)
        tA <- t(A)
        for(krow in 1:K) R[krow,,] <- Rfast::Crossprod(tX*tA[,krow], tX)
        R
      },
      # the Rcpp functions return nothing and fill R in place, so return R explicitly
      Rcpp = {mmul(t(A), t(X), R, K, d); R},
      RcppE1C = {mmulE(t(A), t(X), R, 1); R},
      RcppE2C = {mmulE(t(A), t(X), R, 2); R},
      RcppE4C = {mmulE(t(A), t(X), R, 4); R},
      check = "equal",
      setup = {
        set.seed(1)  # same random data for every expression
        A = matrix(runif(K*N), K, N)
        X = matrix(runif(d*N), d, N)
        R = array(0, dim = c(K, d, d))
      },
      times = 5
    )

    Unit: milliseconds
        expr       min        lq      mean    median        uq       max neval
     current 106.44215 108.73900 161.66269 159.30184 216.37502 217.45546     5
         GKi  84.56926  87.98166 111.04126  90.18420  97.30869 195.16249     5
      crossp 112.02929 113.01796 113.67749 113.93593 114.49450 114.90976     5
       Rfast  39.12859  42.11124  45.42296  46.83398  49.46175  49.57924     5
        Rcpp 156.28284 156.38025 182.19358 157.05552 159.86193 281.38735     5
     RcppE1C  38.94770  40.49375  42.71140  40.69852  46.57995  46.83707     5
     RcppE2C  35.03088  35.67732  36.73970  36.52070  36.64065  39.82895     5
     RcppE4C  31.40532  33.94128  34.53725  34.40168  34.64187  38.29608     5

