Since Armadillo (afaik) doesn't have a triangular solver, I'd like to use the LAPACK triangular solver available in dtrtrs. I have looked at the following two (first, second) SO threads and pieced something together, but it isn't working.
I have created a fresh package using RStudio while also enabling RcppArmadillo. I have a header file header.h:
#include <RcppArmadillo.h>
#ifdef ARMA_USE_LAPACK
#if !defined(ARMA_BLAS_CAPITALS)
#define arma_dtrtrs dtrtrs
#else
#define arma_dtrtrs DTRTRS
#endif
#endif
extern "C" {
    void arma_fortran(arma_dtrtrs)(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
                                   double* A, int* LDA, double* B, int* LDB, int* INFO);
}
int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb);
static int trisolve(const arma::mat &in_A, const arma::mat &in_b, arma::mat &out_x);
This is essentially the answer to the first linked question, plus a wrapper function and the main function. The meat of the functions goes in trisolve.cpp and is as follows:
#include "header.h"

int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb) {
    int info = 0;
    wrapper_dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, A, &lda, B, &ldb, &info);
    return info;
}

static int trisolve(const arma::mat &in_A, const arma::mat &in_b, arma::mat &out_x) {
    size_t rows = in_A.n_rows;
    size_t cols = in_A.n_cols;
    double *A = new double[rows*cols];
    double *b = new double[in_b.size()];
    //Lapack has column-major order
    for(size_t col=0, D1_idx=0; col<cols; ++col)
    {
        for(size_t row = 0; row<rows; ++row)
        {
            // Lapack uses column major format
            A[D1_idx++] = in_A(row, col);
        }
        b[col] = in_b(col);
    }
    for(size_t row = 0; row<rows; ++row)
    {
        b[row] = in_b(row);
    }
    int info = trtrs('U', 'N', 'N', cols, 1, A, rows, b, rows);
    for(size_t col=0; col<cols; col++) {
        out_x(col)=b[col];
    }
    delete[] A;
    delete[] b;
    return 0;
}

// [[Rcpp::export]]
arma::mat RtoRcpp(arma::mat A, arma::mat b) {
    arma::uword n = A.n_rows;
    arma::mat x = arma::mat(n, 1, arma::fill::zeros);
    int info = trisolve(A, b, x);
    return x;
}
There are (at least) two problems for me:
1. I get conflicting types for 'dtrtrs_' from the header file. However, I don't see what is wrong with the inputs (and this is literally copied from the second linked thread).
2. wrapper_dtrtrs_ is not correct. But from what I can tell from Armadillo's compiler_setup.hpp, arma_fortran should create a function called wrapper_dtrtrs_ for me. What is the name I should use here in the main cpp file?

Armadillo already uses dtrtrs for solving triangular problems. Some code references:
- dtrtrs being called in lapack::trtrs: https://gitlab.com/conradsnicta/armadillo-code/blob/9.200.x/include/armadillo_bits/wrapper_lapack.hpp#L908
- lapack::trtrs being called in auxlib::solve_tri with a nice debug statement: https://gitlab.com/conradsnicta/armadillo-code/blob/9.200.x/include/armadillo_bits/auxlib_meat.hpp#L3983

So if we can trigger this debug statement, we can be sure that dtrtrs is indeed used:
#define ARMA_EXTRA_DEBUG
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

// [[Rcpp::export]]
void testTrisolve() {
    arma::mat A = arma::randu<arma::mat>(5,5);
    arma::mat B = arma::randu<arma::mat>(5,5);
    // general dense solve
    arma::mat X1 = arma::solve(A, B);
    // trimatu() marks A as upper triangular, so solve() can take the triangular path
    arma::mat X3 = arma::solve(arma::trimatu(A), B);
}

/*** R
testTrisolve()
*/
This produces a lot of debug messages, among them:
lapack::gesvx()
[...]
lapack::trtrs()
So we clearly see that dtrtrs is used in the triangular case.
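For reference, with uplo = 'U', trans = 'N' and diag = 'N', the computation dtrtrs performs amounts to back substitution on a column-major matrix, with the right-hand side overwritten by the solution. The following is only a minimal sketch of that idea in plain C++ for a single right-hand side; back_substitute is a hypothetical helper name, not a LAPACK or Armadillo function:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of LAPACK or Armadillo).
// Solves U * x = b for an n-by-n upper-triangular U stored in column-major
// order, i.e. element (i, j) lives at U[i + j * n]. The solution overwrites b.
// Returns 0 on success, or i (1-based) if the i-th diagonal entry is exactly
// zero, which mirrors the INFO convention of dtrtrs.
int back_substitute(const std::vector<double>& U, std::vector<double>& b, std::size_t n) {
    assert(U.size() == n * n && b.size() == n);

    // Check for an exactly singular matrix before touching b.
    for (std::size_t i = 0; i < n; ++i) {
        if (U[i + i * n] == 0.0) return static_cast<int>(i + 1);
    }

    // Walk the columns from last to first; only the upper triangle is read.
    for (std::size_t k = n; k-- > 0;) {
        b[k] /= U[k + k * n];
        for (std::size_t i = 0; i < k; ++i) {
            b[i] -= U[i + k * n] * b[k];
        }
    }
    return 0;
}
```

Entries below the diagonal are never read, which is why passing a full matrix wrapped in arma::trimatu() is enough to select the upper triangle.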
As for your original questions:
1. The conflicting types arise because Armadillo already declares dtrtrs, but with a slightly different signature (A is const).
2. The name that arma_fortran produces depends on ARMA_BLAS_UNDERSCORE and ARMA_USE_WRAPPER. I am not sure if that is always the case, but for me the former is defined and the latter is not (cf. config.hpp), leading to dtrtrs_ as the name.

Indeed, if I add a const where Armadillo uses it and call the function as dtrtrs_, your code compiles without errors or warnings (with the exception of an unused variable ...):
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

extern "C" {
    void arma_fortran(dtrtrs)(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
                              const double* A, int* LDA, double* B, int* LDB, int* INFO);
}

int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb) {
    int info = 0;
    dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, A, &lda, B, &ldb, &info);
    return info;
}
[...]
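To illustrate why the symbol ends up as dtrtrs_ in one configuration and wrapper_dtrtrs_ in another, here is a small self-contained sketch of token pasting. The DEMO_* macros and mangled_name are hypothetical and only imitate the idea in simplified form; the actual definitions in Armadillo's compiler_setup.hpp handle more cases (for example capitalised names) and may differ between versions:

```cpp
#include <cassert>
#include <string>

// Hypothetical DEMO_* macros, a simplified imitation of assembling a
// Fortran symbol name by token pasting. These are NOT Armadillo's macros.
#define DEMO_SANS_PREFIX(function) function##_
#define DEMO_WITH_PREFIX(function) wrapper_##function##_

// Two-step stringification so the argument is macro-expanded before quoting.
#define DEMO_STR2(x) #x
#define DEMO_STR(x) DEMO_STR2(x)

// Returns the symbol name that each variant would generate for dtrtrs.
std::string mangled_name(bool use_wrapper) {
    if (use_wrapper) {
        return DEMO_STR(DEMO_WITH_PREFIX(dtrtrs));
    }
    return DEMO_STR(DEMO_SANS_PREFIX(dtrtrs));
}
```

Under these assumptions, mangled_name(false) corresponds to the situation described above (underscore suffix, no wrapper prefix), which is why calling dtrtrs_ directly links while wrapper_dtrtrs_ does not.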