I'm trying to multiply N-dimensional (N>=3) arrays in Julia as batches of matrices, i.e. perform matrix multiplication along the last two dimensions, keeping the other dimensions intact.
For example, if x has dimensions (d1,d2,4,3) and y has dimensions (d1,d2,3,2), the result of the multiplication should have dimensions (d1,d2,4,2), i.e. a batch of matrix multiplications should be performed.
This is exactly what happens in Python's numpy.matmul:
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
np.matmul(np.random.randn(10,10,4,3), np.random.randn(10,10,3,2)).shape
(10, 10, 4, 2)
Is there a way to reproduce the behaviour of numpy.matmul in Julia?
I hoped .* would work, but:
julia> randn(10,10,4,3) .* randn(10,10,3,2)
ERROR: DimensionMismatch("arrays could not be broadcast to a common size")
Stacktrace:
[1] _bcs1 at ./broadcast.jl:485 [inlined]
[2] _bcs at ./broadcast.jl:479 [inlined] (repeats 3 times)
[3] broadcast_shape at ./broadcast.jl:473 [inlined]
[4] combine_axes at ./broadcast.jl:468 [inlined]
[5] instantiate at ./broadcast.jl:256 [inlined]
[6] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{4},Nothing,typeof(*),Tuple{Array{Float64,4},Array{Float64,4}}}) at ./broadcast.jl:798
[7] top-level scope at REPL[80]:1
I understand a list comprehension might work in 3-D, but this would get really messy in higher dimensions. Is the best solution to reshape (or view) the arrays so that all but the last two dimensions are merged into one, use a list comprehension, and reshape the result back? Or is there a better way?
P.S. The closest thing I could find was this, but it's not quite the same. New to Julia, so might be missing something obvious to Julia users.
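For comparison, the reshape-and-comprehension approach mentioned in the question could be sketched as follows. This is a minimal sketch, not a library function: matmul_lastdims is a hypothetical name, and it assumes both arrays have identical batch dimensions (it does not broadcast batch dimensions the way numpy.matmul does).

```julia
# Hypothetical helper: numpy-style batched matmul for arrays whose matrices
# live in the LAST two dimensions. All leading (batch) dimensions are merged
# into one, the matrices are multiplied one by one, and the shape is restored.
function matmul_lastdims(x::AbstractArray, y::AbstractArray)
    bsz = size(x)[1:end-2]                      # batch shape, e.g. (d1, d2)
    bsz == size(y)[1:end-2] ||
        throw(DimensionMismatch("batch dimensions must match"))
    m, k = size(x)[end-1:end]
    k2, n = size(y)[end-1:end]
    k == k2 || throw(DimensionMismatch("inner matrix dimensions must match"))
    nb = prod(bsz)                              # total number of matrices
    # Column-major reshape: the single index b runs over all batch dimensions
    X = reshape(x, nb, m, k)
    Y = reshape(y, nb, k, n)
    # One ordinary matrix product per batch index (each slice is a copy)
    Z = [X[b, :, :] * Y[b, :, :] for b in 1:nb]
    # cat gives an (m, n, nb) array; move the batch axis first, then unmerge it
    out = permutedims(cat(Z...; dims = 3), (3, 1, 2))
    return reshape(out, bsz..., m, n)
end

x = rand(1:5, 10, 10, 4, 3)
y = rand(1:5, 10, 10, 3, 2)
size(matmul_lastdims(x, y))   # (10, 10, 4, 2)
```

Because Julia is column-major, each slice X[b, :, :] is strided rather than contiguous and gets copied, so this is likely slower than storing the matrices in the leading dimensions.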
I'm not aware of any such functionality, but there may well be in some package. I think that in Julia it's more natural to organize the data as arrays of matrices, and broadcast the matrix multiplication over them:
D = [rand(50, 60) for i in 1:4, j in 1:3]
E = [rand(60, 70) for i in 1:4, j in 1:3]
D .* E # now you can use dot broadcasting!
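If the data already sits in an N-D array with the matrices in the first two dimensions, one way to get such an array of matrices is a comprehension over the trailing indices. This is a sketch: tomatrices is a hypothetical helper that copies every slice, and the same .* then works for any number of batch dimensions.

```julia
# Hypothetical helper: turn an array whose matrices live in the FIRST two
# dimensions into an array of matrices with the remaining (batch) shape.
# Each element is a copy of one slice.
tomatrices(A::AbstractArray) =
    [A[:, :, Tuple(I)...] for I in CartesianIndices(axes(A)[3:end])]

A = rand(1:5, 2, 3, 4, 5, 6)          # 2×3 matrices, batch shape (4, 5, 6)
B = rand(1:5, 3, 7, 4, 5, 6)          # 3×7 matrices, same batch shape
C = tomatrices(A) .* tomatrices(B)    # `*` is applied to each pair of matrices
size(C)      # (4, 5, 6)
size(C[1])   # (2, 7)
```

On Julia 1.9 and later, eachslice(A; dims = (3, 4, 5)) should give a lazy array of views instead of copies.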
That said, it's easy to make your own. I would make one change, though. Julia is column-major, while numpy is row-major ("last-dimension-major") by default, so you should let the matrices reside along the first two dimensions, not the last two. That way each matrix occupies a contiguous block of memory.
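To make the memory-layout argument concrete, strides reports how far apart (in elements) consecutive entries along each dimension of a view are. The sketch assumes a 4-D Float64 array; tojulia is a hypothetical helper for moving numpy-style data (matrices in the last two dimensions) into the layout suggested here, and permutedims makes a copy.

```julia
A = zeros(10, 10, 4, 3)

# Matrices in the first two dimensions: each 10×10 slice is one contiguous block
strides(view(A, :, :, 2, 3))   # (1, 10)

# Matrices in the last two dimensions: consecutive entries of a column are
# 100 elements apart in memory
strides(view(A, 2, 3, :, :))   # (100, 400)

# Hypothetical helper: move the last two (matrix) dimensions to the front,
# e.g. (d1, d2, m, n) -> (m, n, d1, d2). This copies the data.
tojulia(x) = permutedims(x, (ndims(x) - 1, ndims(x), (1:ndims(x)-2)...))

size(tojulia(zeros(5, 6, 4, 3)))   # (4, 3, 5, 6)
```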
First, I'll define an in-place method that multiplies into an array C, and then a non-in-place method that calls the in-place version (I'll skip dimension checking etc.):
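# In-place version. The @views macro is essential: without it, a slice such
# as C[:, :, i, j] would be a copy, and mul! would write into that copy
# instead of into C (the A and B slices would also be allocated).
using LinearAlgebra: mul! # fast in-place matrix multiply
function batchmul!(C, A, B)
    for j in axes(A, 4), i in axes(A, 3)
        @views mul!(C[:, :, i, j], A[:, :, i, j], B[:, :, i, j])
    end
    return C
end
# The non-in-place version: allocate C with the promoted element type and
# size (size(A, 1), size(B, 2), batch dimensions...), then fill it in place
function batchmul(A, B)
    T = promote_type(eltype(A), eltype(B))
    C = Array{T}(undef, size(A, 1), size(B)[2:end]...)
    return batchmul!(C, A, B)
end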
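Since the code above skips dimension checking, here is a sketch of what the checks might look like for the 4-D case, followed by a small usage example. batchmul_checked! is a hypothetical name; integer arrays are used only so that the slice-by-slice comparison is exact.

```julia
using LinearAlgebra: mul!

# Hypothetical variant of batchmul! with the size checks skipped above.
# Same layout: matrices in dimensions 1-2, batch in dimensions 3-4.
function batchmul_checked!(C, A, B)
    size(A)[3:end] == size(B)[3:end] == size(C)[3:end] ||
        throw(DimensionMismatch("batch dimensions of A, B and C must agree"))
    size(A, 2) == size(B, 1) ||
        throw(DimensionMismatch("inner matrix dimensions must agree"))
    (size(C, 1), size(C, 2)) == (size(A, 1), size(B, 2)) ||
        throw(DimensionMismatch("each slice of C must be $(size(A, 1))×$(size(B, 2))"))
    for j in axes(A, 4), i in axes(A, 3)
        @views mul!(C[:, :, i, j], A[:, :, i, j], B[:, :, i, j])
    end
    return C
end

A = rand(1:5, 4, 3, 10, 10)   # 10×10 batch of 4×3 matrices
B = rand(1:5, 3, 2, 10, 10)   # 10×10 batch of 3×2 matrices
C = batchmul_checked!(zeros(Int, 4, 2, 10, 10), A, B)
C[:, :, 7, 1] == A[:, :, 7, 1] * B[:, :, 7, 1]   # true
```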
You could also make it multi-threaded (Julia must then be started with several threads, e.g. julia --threads 4). On my computer, 4 threads give a 2.5x speedup (for larger values of the last two dimensions, I actually get a 3.5x speedup). How much of a speedup you get depends on the sizes and shapes of the arrays involved:
function batchmul!(C, A, B)
    Threads.@threads for j in axes(A, 4)
        for i in axes(A, 3)
            @views mul!(C[:, :, i, j], A[:, :, i, j], B[:, :, i, j])
        end
    end
    return C
end
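To see what speedup you get for your own sizes, a rough timing comparison might look like the sketch below. The function names are hypothetical, the timings are machine-dependent, and Threads.nthreads() should report more than 1 for the comparison to mean anything. This version only threads over the last dimension, so at most size(A, 4) threads are busy; BLAS may also run its own threads inside mul!, and calling LinearAlgebra.BLAS.set_num_threads(1) is one thing worth trying if the numbers look off.

```julia
using LinearAlgebra: mul!

# Serial and threaded loops from above, under hypothetical distinct names
# so that both can be timed in the same session.
function batchmul_serial!(C, A, B)
    for j in axes(A, 4), i in axes(A, 3)
        @views mul!(C[:, :, i, j], A[:, :, i, j], B[:, :, i, j])
    end
    return C
end

function batchmul_threaded!(C, A, B)
    Threads.@threads for j in axes(A, 4)
        for i in axes(A, 3)
            @views mul!(C[:, :, i, j], A[:, :, i, j], B[:, :, i, j])
        end
    end
    return C
end

A = randn(50, 60, 10, 10)
B = randn(60, 70, 10, 10)
C1 = zeros(50, 70, 10, 10)
C2 = similar(C1)

# Warm up first so that compilation time is not included in the timings
batchmul_serial!(C1, A, B); batchmul_threaded!(C2, A, B)
t_serial   = @elapsed batchmul_serial!(C1, A, B)
t_threaded = @elapsed batchmul_threaded!(C2, A, B)
println("threads = ", Threads.nthreads(),
        ", speedup ≈ ", round(t_serial / t_threaded, digits = 2))
```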
Edit: I noticed just now that you want general N-D, not just 4-D. Shouldn't be too hard to generalize, though. Anyway, all the more reason to go for arrays of matrices, where broadcasting will automatically work for all dimensionalities.
Edit2: Couldn't leave it, so here's one for the N-D case (there's still more to do, such as handling non-1-based indexing; update: using axes should fix this):
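function batchmul!(C, A, B)
    # Iterate over every combination of the trailing (batch) indices;
    # Tuple(I)... splats the CartesianIndex I into those index positions
    Threads.@threads for I in CartesianIndices(axes(A)[3:end])
        @views mul!(C[:, :, Tuple(I)...], A[:, :, Tuple(I)...], B[:, :, Tuple(I)...])
    end
    return C
end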
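A quick check of the N-D version on 5-D arrays, reusing the non-in-place batchmul wrapper unchanged (its size(B)[2:end] already covers any number of batch dimensions). This is a self-contained sketch; the sizes are arbitrary, and integer arrays keep the comparison exact.

```julia
using LinearAlgebra: mul!

# N-D in-place version: matrices in dimensions 1-2, any number of
# trailing batch dimensions
function batchmul!(C, A, B)
    Threads.@threads for I in CartesianIndices(axes(A)[3:end])
        @views mul!(C[:, :, Tuple(I)...], A[:, :, Tuple(I)...], B[:, :, Tuple(I)...])
    end
    return C
end

# Non-in-place wrapper, identical to the 4-D one
function batchmul(A, B)
    T = promote_type(eltype(A), eltype(B))
    C = Array{T}(undef, size(A, 1), size(B)[2:end]...)
    return batchmul!(C, A, B)
end

A = rand(1:5, 4, 3, 2, 5, 6)   # 4×3 matrices, batch shape (2, 5, 6)
B = rand(1:5, 3, 2, 2, 5, 6)   # 3×2 matrices, same batch shape
C = batchmul(A, B)
size(C)   # (4, 2, 2, 5, 6)
```

Indexing with a CartesianIndex mixed with colons, as in A[:, :, I], should also work and would avoid the Tuple(I)... splat.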