
How to speed up reshape in higher rank tensor contraction by BLAS in Fortran?


Related question: Fortran: Which method is faster to change the rank of arrays? (Reshape vs. Pointer)

Suppose I have a tensor contraction A[a,b] * B[b,c,d] = C[a,c,d]. If I use BLAS, I think I need DGEMM (assuming real values), so I can

  1. first reshape tensor B[b,c,d] as D[b,e] where e = c*d,
  2. DGEMM, A[a,b] * D[b,e] = E[a,e]
  3. reshape E[a,e] into C[a,c,d]
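The three steps above rely on Reshape keeping Fortran's column-major element order. The small test program below (a hypothetical sketch, not part of the original question, with arbitrary small sizes) checks that, after d = Reshape( b, Shape( d ) ), element b(ib, ic, id) lands in d(ib, ie) with the fused index ie = ic + (id-1)*nc.

```fortran
program reshape_index_map
  ! Checks that Reshape( b, [ nb, nc*nd ] ) keeps the column-major element
  ! order, i.e. d(ib, ie) == b(ib, ic, id) with ie = ic + (id-1)*nc.
  use, intrinsic :: iso_fortran_env, only : wp => real64
  implicit none

  integer, parameter :: nb = 3, nc = 4, nd = 5, ne = nc * nd
  real(wp) :: b(nb, nc, nd), d(nb, ne)
  integer  :: ib, ic, id, ie

  call random_number(b)
  d = reshape(b, shape(d))

  do id = 1, nd
    do ic = 1, nc
      ie = ic + (id - 1) * nc          ! fused column index e
      do ib = 1, nb
        if (d(ib, ie) /= b(ib, ic, id)) error stop 'index mapping mismatch'
      end do
    end do
  end do

  write(*, '(a)') 'd(ib, ic + (id-1)*nc) == b(ib, ic, id) for all indices'
end program reshape_index_map
```

Because Reshape is a pure element copy, exact equality is the right test here; the same mapping applies in reverse when E[a,e] is reshaped into C[a,c,d].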

The problem is that reshape is not that fast :( I saw the discussion in "Fortran: Which method is faster to change the rank of arrays? (Reshape vs. Pointer)", but there the author ran into error messages with the alternatives to reshape.

So I am asking whether there is a convenient solution.


Solution

  • [I have prefixed the dimension sizes with the letter n (na, nb, ...) to avoid confusion below between a tensor and the sizes of its dimensions]

    As discussed in the comments, there is no need to reshape. Dgemm has no concept of tensors; it only knows about arrays, and all it cares about is that those arrays are laid out in the correct order in memory. Because Fortran is column major, a 3-dimensional array used to represent the 3-dimensional tensor B in the question is laid out in memory exactly the same way as a 2-dimensional array used to represent the 2-dimensional tensor D, and the same holds for C and E. As far as the matrix multiply is concerned, all you need to do is make the dot products that form the result the right length. So if you tell dgemm that B has a leading dimension of nb and a second dimension of nc*nd (and that C has a leading dimension of na), you get the right result without any copying. This leads to:
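To see why passing b directly works without needing a BLAS library, here is a small self-contained sketch. The subroutine naive_gemm is hypothetical: a loop-based stand-in for dgemm( 'N', 'N', ... ) with alpha = 1 and beta = 0. Like the reference BLAS dgemm, it declares its matrix arguments assumed-size (x(ldx, *)), which is what allows a rank-3 actual argument to be passed by sequence association. Its result is compared against an explicit loop over the tensor indices.

```fortran
program contraction_without_reshape
  ! Computes C(a,c,d) = sum_b A(a,b) * B(b,c,d) twice and compares:
  !   c1: explicit loops over the tensor indices (reference)
  !   c2: the rank-3 arrays b and c2 passed straight to a GEMM-style routine
  use, intrinsic :: iso_fortran_env, only : wp => real64
  implicit none

  integer, parameter :: na = 4, nb = 5, nc = 6, nd = 7
  real(wp) :: a(na, nb), b(nb, nc, nd), c1(na, nc, nd), c2(na, nc, nd), err
  integer  :: ia, ib, ic, id

  call random_number(a)
  call random_number(b)

  c1 = 0.0_wp
  do id = 1, nd
    do ic = 1, nc
      do ib = 1, nb
        do ia = 1, na
          c1(ia, ic, id) = c1(ia, ic, id) + a(ia, ib) * b(ib, ic, id)
        end do
      end do
    end do
  end do

  ! Same argument pattern as the "Direct" dgemm call: b is read as an
  ! nb x (nc*nd) matrix and c2 is written as an na x (nc*nd) matrix.
  call naive_gemm(na, nc * nd, nb, a, na, b, nb, c2, na)

  err = maxval(abs(c1 - c2))
  write(*, *) 'max |c1 - c2| =', err
  if (err > 1.0e-12_wp) error stop 'results differ'

contains

  ! Hypothetical loop-based stand-in for dgemm('N', 'N', m, n, k, 1.0, x, ldx,
  ! y, ldy, 0.0, z, ldz). The assumed-size dummies x(ldx, *) etc. accept an
  ! actual array of any rank by sequence association, as in reference BLAS.
  subroutine naive_gemm(m, n, k, x, ldx, y, ldy, z, ldz)
    integer,  intent(in)  :: m, n, k, ldx, ldy, ldz
    real(wp), intent(in)  :: x(ldx, *), y(ldy, *)
    real(wp), intent(out) :: z(ldz, *)
    integer :: j, l
    do j = 1, n
      z(1:m, j) = 0.0_wp
      do l = 1, k
        z(1:m, j) = z(1:m, j) + x(1:m, l) * y(l, j)
      end do
    end do
  end subroutine naive_gemm

end program contraction_without_reshape
```

The same reasoning is what makes the real dgemm call safe: it never sees the rank of b or c2, only their first element, their leading dimensions and the counts m, n, k.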

    ian@eris:~/work/stack$ gfortran --version
    GNU Fortran (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
    Copyright (C) 2017 Free Software Foundation, Inc.
    This is free software; see the source for copying conditions.  There is NO
    warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
    
    ian@eris:~/work/stack$ cat reshape.f90
    Program reshape_for_blas
    
      Use, Intrinsic :: iso_fortran_env, Only :  wp => real64, li => int64
    
      Implicit None
    
      Real( wp ), Dimension( :, :    ), Allocatable :: a
      Real( wp ), Dimension( :, :, : ), Allocatable :: b
      Real( wp ), Dimension( :, :, : ), Allocatable :: c1, c2
      Real( wp ), Dimension( :, :    ), Allocatable :: d
      Real( wp ), Dimension( :, :    ), Allocatable :: e
    
      Integer :: na, nb, nc, nd, ne
      
      Integer( li ) :: start, finish, rate
    
      Write( *, * ) 'na, nb, nc, nd ?'
      Read( *, * ) na, nb, nc, nd
      ne = nc * nd
      Allocate( a ( 1:na, 1:nb ) ) 
      Allocate( b ( 1:nb, 1:nc, 1:nd ) ) 
      Allocate( c1( 1:na, 1:nc, 1:nd ) ) 
      Allocate( c2( 1:na, 1:nc, 1:nd ) ) 
      Allocate( d ( 1:nb, 1:ne ) ) 
      Allocate( e ( 1:na, 1:ne ) ) 
    
      ! Set up some data
      Call Random_number( a )
      Call Random_number( b )
    
      ! With reshapes
      Call System_clock( start, rate )
      d = Reshape( b, Shape( d ) )
      Call dgemm( 'N', 'N', na, ne, nb, 1.0_wp, a, Size( a, Dim = 1 ), &
                                                d, Size( d, Dim = 1 ), &
                                        0.0_wp, e, Size( e, Dim = 1 ) )
      c1 = Reshape( e, Shape( c1 ) )
      Call System_clock( finish, rate )
      Write( *, * ) 'Time for reshaping method ', Real( finish - start, wp ) / rate
      
      ! Direct
      Call System_clock( start, rate )
      Call dgemm( 'N', 'N', na, ne, nb, 1.0_wp, a , Size( a , Dim = 1 ), &
                                                b , Size( b , Dim = 1 ), &
                                                0.0_wp, c2, Size( c2, Dim = 1 ) )
      Call System_clock( finish, rate )
      Write( *, * ) 'Time for straight  method ', Real( finish - start, wp ) / rate
    
      Write( *, * ) 'Difference between result matrices ', Maxval( Abs( c1 - c2 ) )
    
    End Program reshape_for_blas
    ian@eris:~/work/stack$ cat in
    40 50 60 70
    ian@eris:~/work/stack$ gfortran -std=f2008 -Wall -Wextra -fcheck=all reshape.f90  -lblas
    ian@eris:~/work/stack$ ./a.out < in
     na, nb, nc, nd ?
     Time for reshaping method    1.0515256000000001E-002
     Time for straight  method    5.8608790000000003E-003
     Difference between result matrices    0.0000000000000000     
    ian@eris:~/work/stack$ gfortran -std=f2008 -Wall -Wextra  reshape.f90  -lblas
    ian@eris:~/work/stack$ ./a.out < in
     na, nb, nc, nd ?
     Time for reshaping method    1.3585931000000001E-002
     Time for straight  method    1.6730429999999999E-003
     Difference between result matrices    0.0000000000000000     
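If you want a genuine rank-2 object in Fortran itself, for example to use the intrinsic Matmul instead of dgemm, one option (not used in this answer, and in the spirit of the Pointer approach from the linked question) is to give the contiguous rank-3 arrays the target attribute and build rank-2 pointer views with c_loc and c_f_pointer from iso_c_binding. This is a sketch assuming the arrays are contiguous; no elements are copied to create the views, although the compiler may still use a temporary for the Matmul assignment.

```fortran
program pointer_view
  ! Sketch: contract A(a,b) * B(b,c,d) = C(a,c,d) with the intrinsic Matmul
  ! by viewing the rank-3 arrays b and c as rank-2 matrices, without Reshape.
  use, intrinsic :: iso_fortran_env, only : wp => real64
  use, intrinsic :: iso_c_binding,   only : c_loc, c_f_pointer
  implicit none

  integer, parameter :: na = 4, nb = 5, nc = 6, nd = 7, ne = nc * nd
  real(wp)           :: a(na, nb), c_ref(na, nc, nd), err
  real(wp), target   :: b(nb, nc, nd), c(na, nc, nd)   ! target: needed for c_loc
  real(wp), pointer  :: bmat(:, :), cmat(:, :)

  call random_number(a)
  call random_number(b)

  ! Reinterpret the same memory with a new shape; no elements are copied here.
  call c_f_pointer(c_loc(b), bmat, [nb, ne])
  call c_f_pointer(c_loc(c), cmat, [na, ne])

  ! Writing through cmat fills c(1:na, 1:nc, 1:nd) directly.
  cmat = matmul(a, bmat)

  ! Compare against the reshape-based version from the question.
  c_ref = reshape(matmul(a, reshape(b, [nb, ne])), shape(c_ref))
  err = maxval(abs(c - c_ref))
  write(*, *) 'max |c - c_ref| =', err
  if (err > 1.0e-12_wp) error stop 'pointer view gave a different result'
end program pointer_view
```

The views bmat and cmat could equally be passed to dgemm, but as shown above dgemm does not need them.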
    

    That said, I think it is worth noting that the overhead of reshaping is O(N^2) while the time for the matrix multiply is O(N^3) (for N×N matrices). Thus for large matrices the percentage overhead due to the reshape tends to zero. Code performance is not the only consideration, though; code readability and maintainability are also very important. So if you find the reshape method much more readable, and the matrices you use are large enough that the overhead does not matter, you may well use the reshapes, as in that case readability might be more important than performance. Your call.
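For the tensor case in the question, the same argument can be made concrete by counting copied elements against floating-point operations. This is a rough model that, as an assumption, ignores cache and memory-bandwidth effects as well as the temporaries the reshapes may allocate:

```latex
\begin{aligned}
\text{elements copied by the two reshapes} &= n_b n_c n_d + n_a n_c n_d = (n_a + n_b)\, n_e, \\
\text{flops in the DGEMM} &= 2\, n_a n_b n_e, \\
\frac{\text{flops}}{\text{elements copied}} &= \frac{2\, n_a n_b}{n_a + n_b}
  \qquad \Bigl(= N \text{ when } n_a = n_b = N\Bigr).
\end{aligned}
```

With the input used above (na = 40, nb = 50) the ratio is 4000/90 ≈ 44 flops per copied element. The ratio does not depend on ne = nc*nd, so making c and d larger does not shrink the relative overhead, while making a and b larger does.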