I want to compute batched matrix products and matrix-vector products for which the batch indices are different. With A as a (m x n x N_a) array and B as a (n x p x N_b) array, the batched matrix product is
C[:,:,ia,ib] == A[:,:,ia] * B[:,:,ib]
and similarly the batched matrix-vector product, in which case B is seen as a collection of N_b vectors of length n, i.e. a (n x N_b) array:
D[:,ia,ib] == A[:,:,ia] * B[:,ib]
Both of these operations are possible with Tullio.jl:
using Tullio

A = rand(10, 20, 5)   # m x n x N_a
B = rand(20, 15, 3)   # n x p x N_b
# batched matrix product: contract over the shared inner index ik
@tullio C[im, il, ia, ib] := A[im, ik, ia] * B[ik, il, ib]

b = rand(20, 3)       # n x N_b, a batch of N_b vectors
# batched matrix-vector product
@tullio D[im, ia, ib] := A[im, ik, ia] * b[ik, ib]
However, Tullio.jl is unfortunately outdated: it's not compatible with the recent versions of CUDA.jl that I require for other reasons. The solution should work on both GPUs and CPUs.
I tried to reproduce this behaviour with NNlib.jl's batched_mul and batched_vec. The only way I could see to make this work is to use repeat heavily to make the batch indices the same. However, this causes a lot of unnecessary computation and allocation. Is there any way to do this more efficiently using NNlib or another library, maybe by directly using batched_gemm?
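For reference, here is a minimal sketch of the repeat-based workaround I mean, using NNlib's batched_mul (the shapes and the inner/outer repetition pattern are just my illustration):

using NNlib

A = rand(10, 20, 5)   # m x n x N_a
B = rand(20, 15, 3)   # n x p x N_b
# blow both batch dimensions up to all N_a*N_b pairs: A's batch index
# varies fastest, B's slowest, so batch k pairs A[:,:,ia] with B[:,:,ib]
Arep = repeat(A, outer = (1, 1, 3))   # m x n x (N_a*N_b)
Brep = repeat(B, inner = (1, 1, 5))   # n x p x (N_a*N_b)
C = reshape(batched_mul(Arep, Brep), 10, 15, 5, 3)

This materializes repeated copies of both inputs, which is exactly the unnecessary allocation I'd like to avoid.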
The other alternative I see would be to use KernelAbstractions.jl, but ideally I'd like to avoid that, as I would then also have to define the gradients separately.
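For concreteness, a rough sketch of what such a kernel could look like for the matrix-vector case (my own hypothetical code; the gradient would still have to be written by hand):

using KernelAbstractions

@kernel function batched_matvec!(D, @Const(A), @Const(b))
    # one work-item per output element: D[i, ia, ib] = sum_k A[i, k, ia] * b[k, ib]
    i, ia, ib = @index(Global, NTuple)
    acc = zero(eltype(D))
    for k in axes(A, 2)
        acc += A[i, k, ia] * b[k, ib]
    end
    D[i, ia, ib] = acc
end

A = rand(10, 20, 5)
b = rand(20, 3)
D = similar(A, size(A, 1), size(A, 3), size(b, 2))
backend = KernelAbstractions.get_backend(A)
batched_matvec!(backend)(D, A, b; ndrange = size(D))
KernelAbstractions.synchronize(backend)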
Edit: TensorOperations.jl also doesn't directly support unequal batch indices, but maybe there's a way to rewrite the problem so that it can be computed with it.
After I posted my question, the Tullio.jl library was updated (thanks a lot!). So to answer the question "How to compute batched matrix products and matrix-vector products for which the batch indices are different?": the easiest way is to just use this library.
The example from above now also works with CUDA.jl >= v4:
using Tullio

A = rand(10, 20, 5)   # m x n x N_a
B = rand(20, 15, 3)   # n x p x N_b
@tullio C[im, il, ia, ib] := A[im, ik, ia] * B[ik, il, ib]   # batched matrix product

b = rand(20, 3)       # n x N_b
@tullio D[im, ia, ib] := A[im, ik, ia] * b[ik, ib]           # batched matrix-vector product
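To run it on the GPU, my understanding is that loading KernelAbstractions alongside CUDA lets Tullio emit a GPU kernel; a minimal sketch, assuming the same shapes as above:

using Tullio, CUDA, KernelAbstractions   # KernelAbstractions enables Tullio's GPU path

Agpu = CUDA.rand(10, 20, 5)
Bgpu = CUDA.rand(20, 15, 3)
@tullio Cgpu[im, il, ia, ib] := Agpu[im, ik, ia] * Bgpu[ik, il, ib]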