
Parallelized Matrix Multiplication


I am trying to parallelize the multiplication of two matrices A and B.

Unfortunately, the serial implementation is still faster than the parallel one, or the speedup is too low (with matrix dimension 512 the speedup is only about 1.3). Probably something is fundamentally wrong. Can someone give me a tip?

double[][] matParallel2(final double[][] matrixA,
                        final double[][] matrixB,
                        final boolean parallel) {
    int rows = matrixA.length;
    int columnsA = matrixA[0].length;
    int columnsB = matrixB[0].length;

    Runnable task;
    List<Thread> pool = new ArrayList<>();

    double[][] returnMatrix = new double[rows][columnsB];
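    // One Runnable (and one Thread) per row of the result matrix
    for (int i = 0; i < rows; i++) {
        int finalI = i; // effectively final copy of i for the lambda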
        task = () -> {
            for (int j = 0; j < columnsB; j++) {
                // no reset needed: new double[][] is zero-initialized
                for (int k = 0; k < columnsA; k++) {
                    returnMatrix[finalI][j] +=
                            matrixA[finalI][k] * matrixB[k][j];
                }
            }
        };
        pool.add(new Thread(task));
    }
    if (parallel) {
        for (Thread trd : pool) {
            trd.start();
        }
    } else { // run() executes each task on the calling thread, one after another
        for (Thread trd : pool) {
            trd.run();
        }
    }
    try {
        for (Thread trd : pool) {
            trd.join();
        }
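    } catch (InterruptedException e) { // join() only throws InterruptedException
        Thread.currentThread().interrupt(); // restore the interrupt flag
        e.printStackTrace();
    }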
    return returnMatrix;
}
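One plausible contributor to the low speedup is that the code above starts one `Thread` per row, i.e. 512 threads for a 512-row matrix, far more than there are cores. The sketch below keeps the explicit-threads approach but uses a fixed pool from `java.util.concurrent` sized by the caller (for example to the number of processors), with each task computing one contiguous block of rows. The class name `ChunkedMatMul` and the chunking scheme are illustrative assumptions, not the asker's code; the inner loops are unchanged apart from accumulating into a local `sum`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class ChunkedMatMul {

    // Multiplies a (rows x inner) by b (inner x cols) using nThreads pooled
    // threads; each submitted task fills one contiguous block of result rows.
    static double[][] multiply(double[][] a, double[][] b, int nThreads) {
        int rows = a.length;
        int inner = b.length;   // columns of a == rows of b
        int cols = b[0].length;
        double[][] c = new double[rows][cols];

        ExecutorService pool = Executors.newFixedThreadPool(nThreads);
        try {
            List<Future<?>> futures = new ArrayList<>();
            int chunk = (rows + nThreads - 1) / nThreads; // ceil(rows / nThreads)
            for (int start = 0; start < rows; start += chunk) {
                final int from = start;
                final int to = Math.min(rows, start + chunk);
                futures.add(pool.submit(() -> {
                    // Rows [from, to) belong to this task only, so no two
                    // tasks ever write the same element of c.
                    for (int i = from; i < to; i++) {
                        for (int j = 0; j < cols; j++) {
                            double sum = 0;
                            for (int k = 0; k < inner; k++) {
                                sum += a[i][k] * b[k][j];
                            }
                            c[i][j] = sum;
                        }
                    }
                }));
            }
            for (Future<?> f : futures) {
                f.get(); // blocks until the task has finished
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while multiplying", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("a row task failed", e.getCause());
        } finally {
            pool.shutdown(); // lets the worker threads exit
        }
        return c;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = {{5, 6}, {7, 8}};
        int threads = Runtime.getRuntime().availableProcessors();
        System.out.println(Arrays.deepToString(multiply(a, b, threads)));
        // prints [[19.0, 22.0], [43.0, 50.0]]
    }
}
```

Because `Future.get()` is called on every task before `multiply` returns, the writes made by the worker threads are visible to the caller afterwards.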

Solution

  • You can use one parallel stream over the rows to reduce the computation time (perhaps by a factor of two or more); the inner streams stay sequential. Don't use nested parallel streams, because that has the opposite effect!

    // requires: import java.util.stream.IntStream;

    /**
     * Parallel Matrix multiplication
     *
     * @param m rows of 'a' matrix
     * @param n columns of 'a' matrix
     *          and rows of 'b' matrix
     * @param p columns of 'b' matrix
     * @param a first matrix 'm×n'
     * @param b second matrix 'n×p'
     * @return result matrix 'm×p'
     */
    static double[][] parallelMatrixMultiplication(
            int m, int n, int p, double[][] a, double[][] b) {
        return IntStream.range(0, m)
                .parallel() // comment out this line to run the sequential stream
                .mapToObj(i -> IntStream.range(0, p)
                        .mapToDouble(j -> IntStream.range(0, n)
                                .mapToDouble(k -> a[i][k] * b[k][j])
                                .sum())
                        .toArray())
                .toArray(double[][]::new);
    }
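Along the same lines as the advice above, here is a sketch that keeps exactly one level of parallelism (a parallel `IntStream` over the rows) but computes each row with plain loops instead of nested streams, which avoids building a stream for every element. It also swaps the two inner loops into i-k-j order so the innermost loop walks `b[k]` and `c[i]` row by row; that reordering is an extra, commonly used cache-friendliness tweak, not part of the answer. The class name `RowParallelMatMul` is hypothetical, and no timings are claimed for it.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

public class RowParallelMatMul {

    // One level of parallelism only: the stream splits the row indices among
    // the common ForkJoinPool workers; each row is computed with plain loops.
    static double[][] multiply(double[][] a, double[][] b) {
        int m = a.length;
        int n = b.length;      // columns of a == rows of b
        int p = b[0].length;
        double[][] c = new double[m][p];
        IntStream.range(0, m).parallel().forEach(i -> {
            // Each index i is handled by exactly one task, so no two threads
            // write the same row of c.
            double[] ci = c[i];
            for (int k = 0; k < n; k++) {
                double aik = a[i][k];
                double[] bk = b[k];
                for (int j = 0; j < p; j++) {
                    ci[j] += aik * bk[j]; // i-k-j order: b and c are read row by row
                }
            }
        });
        return c;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(multiply(a, b)));
        // prints [[19.0, 22.0], [43.0, 50.0]]
    }
}
```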
    

    // test
    public static void main(String[] args) {
        // dimensions
        int m = 512;
        int n = 1024;
        int p = 512;
    
        // matrices
        double[][] a = randomMatrix(m, n);
        double[][] b = randomMatrix(n, p);
    
        long time = System.currentTimeMillis();
    
        // multiplication
        double[][] c = parallelMatrixMultiplication(m, n, p, a, b);
    
        System.out.println((System.currentTimeMillis() - time) + " ms");
        // with    .parallel() the time is 1495 ms
        // without .parallel() the time is 5823 ms
    }
    
    static double[][] randomMatrix(int d1, int d2) {
        return IntStream.range(0, d1)
                .mapToObj(i -> IntStream.range(0, d2)
                        .mapToDouble(j -> Math.random() * 10)
                        .toArray())
                .toArray(double[][]::new);
    }
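The `main` above takes a single `currentTimeMillis()` reading, which includes JIT warm-up, and it does not check the result. Below is a sketch of a slightly more careful check, assuming a fixed `Random` seed is acceptable in place of `Math.random()`: it compares the stream version with a plain sequential triple loop and then times a few repeated runs with `System.nanoTime()`. The helpers `naiveMultiply`, `maxAbsDiff` and `seededRandomMatrix` are hypothetical names. The comparison uses a tolerance because `DoubleStream.sum()` may use compensated summation and so need not match the loop bit for bit. For serious measurements a benchmark harness such as JMH is the usual choice.

```java
import java.util.Random;
import java.util.stream.IntStream;

public class MatMulCheck {

    // Copy of the stream-based method from the answer above.
    static double[][] parallelMatrixMultiplication(
            int m, int n, int p, double[][] a, double[][] b) {
        return IntStream.range(0, m)
                .parallel()
                .mapToObj(i -> IntStream.range(0, p)
                        .mapToDouble(j -> IntStream.range(0, n)
                                .mapToDouble(k -> a[i][k] * b[k][j])
                                .sum())
                        .toArray())
                .toArray(double[][]::new);
    }

    // Hypothetical reference implementation: a plain sequential triple loop.
    static double[][] naiveMultiply(double[][] a, double[][] b) {
        int m = a.length, n = b.length, p = b[0].length;
        double[][] c = new double[m][p];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < p; j++) {
                double sum = 0;
                for (int k = 0; k < n; k++) {
                    sum += a[i][k] * b[k][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    // Largest absolute element-wise difference between two equally sized matrices.
    static double maxAbsDiff(double[][] x, double[][] y) {
        double max = 0;
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < x[i].length; j++) {
                max = Math.max(max, Math.abs(x[i][j] - y[i][j]));
            }
        }
        return max;
    }

    // Same [0, 10) value range as randomMatrix above, but reproducible via the seed.
    static double[][] seededRandomMatrix(Random rnd, int rows, int cols) {
        double[][] r = new double[rows][cols];
        for (double[] row : r) {
            for (int j = 0; j < cols; j++) {
                row[j] = rnd.nextDouble() * 10;
            }
        }
        return r;
    }

    public static void main(String[] args) {
        int m = 512, n = 1024, p = 512;
        Random rnd = new Random(42); // fixed seed so runs are repeatable
        double[][] a = seededRandomMatrix(rnd, m, n);
        double[][] b = seededRandomMatrix(rnd, n, p);

        double diff = maxAbsDiff(parallelMatrixMultiplication(m, n, p, a, b), naiveMultiply(a, b));
        System.out.println("max |difference| = " + diff);

        // Several timed runs; the first ones still include JIT warm-up.
        for (int run = 0; run < 5; run++) {
            long start = System.nanoTime();
            parallelMatrixMultiplication(m, n, p, a, b);
            System.out.printf("run %d: %d ms%n", run, (System.nanoTime() - start) / 1_000_000);
        }
    }
}
```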