Fortran nested do loops to calculate tensor product


As an exercise, I'm trying to calculate a tensor product using nested do loops in Fortran. My code uses the same notation as Wolfram MathWorld to evaluate the tensor (Kronecker) product of matrices A and B:

program test

implicit none
real, dimension (3,3) :: A
real, dimension (4,2) :: B
real, dimension (:,:), allocatable :: C
integer :: i, j, k, l, m, n, p, q, alpha, beta
integer :: Ccols, Crows

A = reshape( (/ 1, 0, 2, 1, 9, 5, 1, 4, -1 /), shape(A) )
B = reshape( (/ 1, -2, 9, 5, 4, -1, 3, 9 /), shape( B ) )

m = size(A,1) ! no. A rows
n = size(A,2) ! no. A cols
p = size(B,1) ! no. B rows
q = size(B,2) ! no. B cols
allocate(C(m*p, n*q)) ! m*p columns, n*q rows
C = 0

do i = 1, n  ! iterate over A cols
    do j = 1, m  ! iterate over A rows
        do k = 1, q  ! iterate over B cols
            do l = 1, p  ! iterate over B rows
                alpha = p*(i-1) + k  ! C row index
                beta = q*(j-1) + l  ! C col index
                C(beta, alpha) = A(i,j) * B(k,l)
            end do
        end do
    end do
end do

print *, C

end program test

The resulting matrix C does not agree with the result from Python's numpy.kron function, which gives a 12x6 matrix:

>>> import numpy as np
>>> A = np.array([[1,0,2],[1,9,5],[1,4,-1]])
>>> B = np.array([[1,-2],[9,5],[4,-1],[3,9]])
>>> np.kron(A,B)

   [[  1,  -2,   0,   0,   2,  -4],
    [  9,   5,   0,   0,  18,  10],
    [  4,  -1,   0,   0,   8,  -2],
    [  3,   9,   0,   0,   6,  18],
    [  1,  -2,   9, -18,   5, -10],
    [  9,   5,  81,  45,  45,  25],
    [  4,  -1,  36,  -9,  20,  -5],
    [  3,   9,  27,  81,  15,  45],
    [  1,  -2,   4,  -8,  -1,   2],
    [  9,   5,  36,  20,  -9,  -5],
    [  4,  -1,  16,  -4,  -4,   1],
    [  3,   9,  12,  36,  -3,  -9]]

What am I doing wrong in the Fortran code?


Solution

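  • As @PierU has already pointed out, matrices A and B aren't actually initialised the way you think. Fortran's reshape fills arrays in column-major order (down each column first), whilst numpy arrays are row-major by default, so the same list of values gives different matrices in the two languages (your A, for example, comes out as the transpose of the numpy A). You would need to fix that before even considering the rest. I often use transpose() here: reshape into the swapped shape [cols, rows], then transpose. The rest is the loop nest itself: each loop bound is paired with the wrong dimension (i runs over the columns of A but is used as its row index, and likewise for k and l in B(k,l), which reads outside the bounds of B), and the result is stored as C(beta,alpha), i.e. with the row and column indices of C swapped. The version below pairs each index with its own dimension.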

    I (personally) think debugging would be easier with variables named for what they stand for (rows and columns) rather than m, n, p, q. Also, l is an awkward variable name, because it can be misread as 1 or I in some fonts, or by ageing programmers without their reading glasses on. But I've left it for now!

    program test
       implicit none
       integer, parameter :: Arows = 3, Acols = 3
       integer, parameter :: Brows = 4, Bcols = 2
       real, allocatable :: A(:,:), B(:,:), C(:,:)
       integer i, j, k, l, alpha, beta
       character(len=*), parameter :: fmt = "f6.1"
    
       A = transpose(   reshape( [ 1, 0, 2, 1, 9, 5, 1, 4, -1 ], [ Acols, Arows ] )   )
       B = transpose(   reshape( [ 1, -2, 9, 5, 4, -1, 3, 9 ]  , [ Bcols, Brows ] )   )
       allocate( C(Arows*Brows,Acols*Bcols) )
    
       do i = 1, Arows
          do j = 1, Acols
             do k = 1, Brows
                do l = 1, Bcols
                   alpha = Brows * ( i - 1 ) + k      ! row index of C
                   beta  = Bcols * ( j - 1 ) + l      ! column index of C
                   C(alpha,beta) = A(i,j) * B(k,l)
                end do
             end do
          end do
       end do
    
       call printMatrix( "A:", A, fmt )
       call printMatrix( "B:", B, fmt )
       call printMatrix( "C:", C, fmt )
    
    contains
    
       subroutine printMatrix( title, M, fm )
          character(len=*), intent(in) :: title
          real, intent(in) :: M(:,:)
          character(len=*), intent(in) :: fm
          integer row
    
          write( *, * ) title
          do row = 1, size( M, 1 )
             write( *, "( *( 1x, " // fm // " ) )"  ) M(row,:)
          end do
          write( *, * )
    
       end subroutine printMatrix
    
    end program test
    

    Output:

     A:
         1.0     0.0     2.0
         1.0     9.0     5.0
         1.0     4.0    -1.0
    
     B:
         1.0    -2.0
         9.0     5.0
         4.0    -1.0
         3.0     9.0
    
     C:
         1.0    -2.0     0.0    -0.0     2.0    -4.0
         9.0     5.0     0.0     0.0    18.0    10.0
         4.0    -1.0     0.0    -0.0     8.0    -2.0
         3.0     9.0     0.0     0.0     6.0    18.0
         1.0    -2.0     9.0   -18.0     5.0   -10.0
         9.0     5.0    81.0    45.0    45.0    25.0
         4.0    -1.0    36.0    -9.0    20.0    -5.0
         3.0     9.0    27.0    81.0    15.0    45.0
         1.0    -2.0     4.0    -8.0    -1.0     2.0
         9.0     5.0    36.0    20.0    -9.0    -5.0
         4.0    -1.0    16.0    -4.0    -4.0     1.0
         3.0     9.0    12.0    36.0    -3.0    -9.0
    

    For reference, the (non-interactive) Python version is

    import numpy as np
    A = np.array( [ [1,0,2], [1,9,5] , [1,4,-1] ] )
    B = np.array( [ [1,-2], [9,5] , [4,-1], [3,9] ] )
    print( np.kron( A, B ) )
    

    with output

    [[  1  -2   0   0   2  -4]
     [  9   5   0   0  18  10]
     [  4  -1   0   0   8  -2]
     [  3   9   0   0   6  18]
     [  1  -2   9 -18   5 -10]
     [  9   5  81  45  45  25]
     [  4  -1  36  -9  20  -5]
     [  3   9  27  81  15  45]
     [  1  -2   4  -8  -1   2]
     [  9   5  36  20  -9  -5]
     [  4  -1  16  -4  -4   1]
     [  3   9  12  36  -3  -9]]
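
If you want to check both points from the Python side, the sketch below may help. It is only an illustration, not part of the Fortran fix: it assumes numpy's order="F" option behaves like Fortran's column-first fill, and the names A_as_filled, B_as_filled, B_fixed and kron_loops are made up for this example. The loop uses the same index formulas as the Fortran code above, shifted to 0-based indexing.

```python
import numpy as np

vals_A = [1, 0, 2, 1, 9, 5, 1, 4, -1]
vals_B = [1, -2, 9, 5, 4, -1, 3, 9]

# Row-major fill (numpy default): the matrices the question intended.
A = np.reshape(vals_A, (3, 3))
B = np.reshape(vals_B, (4, 2))

# Column-major fill, mimicking Fortran's reshape(values, shape(A)).
A_as_filled = np.reshape(vals_A, (3, 3), order="F")
B_as_filled = np.reshape(vals_B, (4, 2), order="F")

# Mirror of transpose(reshape(values, [cols, rows])) from the Fortran answer.
B_fixed = np.reshape(vals_B, (2, 4), order="F").T


def kron_loops(A, B):
    """Kronecker product with explicit loops (hypothetical helper, 0-based)."""
    a_rows, a_cols = A.shape
    b_rows, b_cols = B.shape
    C = np.zeros((a_rows * b_rows, a_cols * b_cols), dtype=np.result_type(A, B))
    for i in range(a_rows):
        for j in range(a_cols):
            for k in range(b_rows):
                for col in range(b_cols):
                    alpha = b_rows * i + k    # row index of C
                    beta = b_cols * j + col   # column index of C
                    C[alpha, beta] = A[i, j] * B[k, col]
    return C


print(np.array_equal(A_as_filled, A.T))                # column-major fill transposes A
print(np.array_equal(B_as_filled, B))                  # B is scrambled, not just transposed
print(np.array_equal(B_fixed, B))                      # swapped shape + transpose recovers B
print(np.array_equal(kron_loops(A, B), np.kron(A, B)))
```

For the square matrix A the column-major fill simply gives the transpose, but for the 4x2 matrix B it gives a different matrix altogether, which is why the original Fortran output looked unrelated to the numpy result.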