Tuesday, December 8, 2015

Parallel Matrix Multiplication in Java

In this post I would like to discuss today's experiment on parallelising matrix multiplication. First off, I would like to give credit to the following Stack Overflow thread, and to all the authors of its questions, answers, and code.
http://stackoverflow.com/questions/5484204/parallel-matrix-multiplication-in-java-6

My fascination with parallelised matrix multiplication started when I was implementing an RBM (restricted Boltzmann machine) in Java. In that implementation, the training phase was very slow: some training runs could last for half a day. In comparison, a library called Medal (https://github.com/dustinstansbury/medal), which is written on top of MATLAB, took less than 5 minutes for the same computation! As it turns out, that godly speed was possible because matrix operations in MATLAB are highly optimised and parallelised. That made me wonder how to parallelise matrix operations, and in particular, matrix multiplication.

So today I decided to investigate further and this post will serve as the highlight of the day.



The task is to parallelise an O(N^3) serial matrix multiplication. It is not the one with 3 for-loops, but rather the one utilising divide and conquer. I present the pseudo-code of the algorithm for clarity of discussion.


mult(a,b):

Given: two matrices a, b
To compute: c = axb
Algorithm:

Split a, b, c into quadrants of equal sizes
 a = |a11  a12|    b = |b11  b12|    c = |c11  c12|
     |a21  a22|        |b21  b22|        |c21  c22|

Hence   c11 = mult(a11,b11) + mult(a12,b21)
        c12 = mult(a11,b12) + mult(a12,b22)
        c21 = mult(a21,b11) + mult(a22,b21)
        c22 = mult(a21,b12) + mult(a22,b22)

As we can see the algorithm is recursive, hence the base case must be defined (e.g. when size <= 64). For the base case, we simply multiply the two matrices using the well-known 3 for-loops algorithm.
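For reference, the serial version of this recursion can be sketched in Java as below. This is a minimal sketch, not the code from the repository: the class name RecursiveMultiply, the helper names, and the THRESHOLD value are assumptions made for illustration. It also assumes square matrices whose size halves evenly until it reaches the threshold (for example, a power of two). Sub-matrices are addressed by row and column offsets, so no quadrant is ever copied.

```java
import java.util.Arrays;

final class RecursiveMultiply {
    // Hypothetical cutoff below which we fall back to the 3 for-loops algorithm.
    static final int THRESHOLD = 64;

    // Accumulates (size x size block of a) * (size x size block of b) into a block of c.
    // (ai, aj), (bi, bj), (ci, cj) are the top-left corners of the three blocks.
    static void mult(double[][] a, double[][] b, double[][] c,
                     int ai, int aj, int bi, int bj, int ci, int cj, int size) {
        if (size <= THRESHOLD) {
            // Base case: plain triple loop on the current blocks.
            for (int i = 0; i < size; ++i) {
                for (int j = 0; j < size; ++j) {
                    double sum = 0.0;
                    for (int k = 0; k < size; ++k) {
                        sum += a[ai + i][aj + k] * b[bi + k][bj + j];
                    }
                    c[ci + i][cj + j] += sum;
                }
            }
            return;
        }
        int h = size / 2;
        // c11 = a11*b11 + a12*b21
        mult(a, b, c, ai, aj, bi, bj, ci, cj, h);
        mult(a, b, c, ai, aj + h, bi + h, bj, ci, cj, h);
        // c12 = a11*b12 + a12*b22
        mult(a, b, c, ai, aj, bi, bj + h, ci, cj + h, h);
        mult(a, b, c, ai, aj + h, bi + h, bj + h, ci, cj + h, h);
        // c21 = a21*b11 + a22*b21
        mult(a, b, c, ai + h, aj, bi, bj, ci + h, cj, h);
        mult(a, b, c, ai + h, aj + h, bi + h, bj, ci + h, cj, h);
        // c22 = a21*b12 + a22*b22
        mult(a, b, c, ai + h, aj, bi, bj + h, ci + h, cj + h, h);
        mult(a, b, c, ai + h, aj + h, bi + h, bj + h, ci + h, cj + h, h);
    }

    // Convenience wrapper: allocates a zeroed result and multiplies whole matrices.
    static double[][] multiply(double[][] a, double[][] b) {
        int n = a.length; // assumption: a and b are both n x n
        double[][] c = new double[n][n];
        mult(a, b, c, 0, 0, 0, 0, 0, 0, n);
        return c;
    }

    public static void main(String[] args) {
        int n = 128;
        double[][] a = new double[n][n];
        double[][] b = new double[n][n];
        for (double[] row : a) Arrays.fill(row, 1.0);
        for (double[] row : b) Arrays.fill(row, 1.0);
        double[][] c = multiply(a, b);
        // For all-ones inputs every entry of the product equals n.
        System.out.println(c[0][0]);
    }
}
```

Each of the eight recursive calls adds one block product into its quadrant of c, which is why c must start out as all zeros.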

The recursive definition is useful because later on it will be used as fork and join points of the parallel algorithm. Forks and joins are good because they serve as a natural way of synchronising the algorithm.

Let's parallelise it. The idea is to parallelise the computation of c11, c12, c21, and c22 in mult(a,b). Since each partition of c is not affected by the values of the other partitions, we can perform their computations simultaneously.

In our implementation, we can create a shared variable c with the dimension of axb. All the mult(a,b) tasks can read from a and b, and modify the matrix c directly, hence completely eliminating the overhead of creating temporary placeholder matrices along the way.

However, there is a catch. Consider c11, which is the result of mult(a11, b11) + mult(a12, b21). Since mult(a11, b11) and mult(a12, b21) modify the same region of matrix c, their access to c must be synchronised. An easy way to do this is to execute mult(a11, b11) and mult(a12, b21) serially. That way we do not have to worry about synchronising their accesses to c (which is an added benefit, because any synchronisation mechanism incurs overhead). Therefore, we run serially the mult(aik, bkj) operations in mult(a,b) that modify the same region of c.

To achieve fork and join in Java, we can use a combination of ExecutorService and FutureTask. Following is my implementation (which adapts a lot of ideas from the Stack Overflow thread I provided above). If it is hard to read, I recommend reading it in my GitHub repository for this experiment instead: https://github.com/prajogotio/pmult. The implementation makes several simplifying assumptions (such as the size of the matrices, the number of threads, the minimum threshold size, etc.) which serve to reduce the complexity of the implementation.


package parallel;

import java.util.concurrent.*;


public class ParallelMatrixMultiplication {

    private double[][] a;
    private double[][] b;
    private double[][] c;
    private static final int MATRIX_SIZE = 1024,
                             POOL_SIZE = Runtime.getRuntime().availableProcessors(),
                             MINIMUM_THRESHOLD = 64;

    private final ExecutorService exec = Executors.newFixedThreadPool(POOL_SIZE);

    ParallelMatrixMultiplication(double[][] a, double[][] b) {
        // assumption : a and b are both double[MATRIX_SIZE][MATRIX_SIZE]
        this.a = a;
        this.b = b;
        this.c = new double[MATRIX_SIZE][MATRIX_SIZE];
    }


    // Debugging Code
    ParallelMatrixMultiplication() {
        a = new double[MATRIX_SIZE][MATRIX_SIZE];
        b = new double[MATRIX_SIZE][MATRIX_SIZE];
        c = new double[MATRIX_SIZE][MATRIX_SIZE];
        for (int i = 0; i < a.length; ++i) {
            for (int j = 0; j < a.length; ++j) {
                a[i][j] = 1.0;
                b[i][j] = 1.0;
            }
        }
    }

    public void check() {
        for (int i = 0; i < c.length; ++i) {
            for (int j = 0; j < c.length; ++j) {
                if (Math.abs(c[i][j]-a.length) > 1e-10) {
                    System.out.format("%.3f\n",c[i][j]);
                }
            }
        }
        System.out.println("DONE");
    }



    public void multiply() {
        //multiplyRecursive(0, 0, 0, 0, 0, 0, a.length);
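        Future<?> f = exec.submit(new MultiplyTask(a, b, c, 0, 0, 0, 0, 0, 0, a.length));
        try {
            f.get();
        } catch (InterruptedException e) {
            // Preserve the interrupt so callers can react to it.
            Thread.currentThread().interrupt();
        } catch (ExecutionException e) {
            // Surface failures from the worker threads instead of swallowing them.
            throw new RuntimeException(e.getCause());
        } finally {
            // Always release the pool threads, even if the computation failed.
            exec.shutdown();
        }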
    }

    public double[][] getResult() {
        return c;
    }

    class MultiplyTask implements Runnable{
        private double[][] a;
        private double[][] b;
        private double[][] c;
        private int a_i, a_j, b_i, b_j, c_i, c_j, size;

        MultiplyTask(double[][] a, double[][] b, double[][] c, int a_i, int a_j, int b_i, int b_j, int c_i, int c_j, int size) {
            this.a = a;
            this.b = b;
            this.c = c;
            this.a_i = a_i;
            this.a_j = a_j;
            this.b_i = b_i;
            this.b_j = b_j;
            this.c_i = c_i;
            this.c_j = c_j;
            this.size = size;
        }

        public void run() {
            //System.out.format("[%d,%d]x[%d,%d](%d)\n",a_i,a_j,b_i,b_j,size);
            int h = size/2;
            if (size <= MINIMUM_THRESHOLD) {
                for (int i = 0; i < size; ++i) {
                    for (int j = 0; j < size; ++j) {
                        for (int k = 0; k < size; ++k) {
                            c[c_i+i][c_j+j] += a[a_i+i][a_j+k] * b[b_i+k][b_j+j];
                        }
                    }
                }
            } else {
                MultiplyTask[] tasks = {
                    new MultiplyTask(a, b, c, a_i, a_j, b_i, b_j, c_i, c_j, h),
                    new MultiplyTask(a, b, c, a_i, a_j+h, b_i+h, b_j, c_i, c_j, h),

                    new MultiplyTask(a, b, c, a_i, a_j, b_i, b_j+h, c_i, c_j+h, h),
                    new MultiplyTask(a, b, c, a_i, a_j+h, b_i+h, b_j+h, c_i, c_j+h, h),

                    new MultiplyTask(a, b, c, a_i+h, a_j, b_i, b_j, c_i+h, c_j, h),
                    new MultiplyTask(a, b, c, a_i+h, a_j+h, b_i+h, b_j, c_i+h, c_j, h),

                    new MultiplyTask(a, b, c, a_i+h, a_j, b_i, b_j+h, c_i+h, c_j+h, h),
                    new MultiplyTask(a, b, c, a_i+h, a_j+h, b_i+h, b_j+h, c_i+h, c_j+h, h)
                };

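                // Each FutureTask runs one pair of products serially (both write the same quadrant of c).
                FutureTask<?>[] fs = new FutureTask<?>[tasks.length/2];

                for (int i = 0; i < tasks.length; i+=2) {
                    fs[i/2] = new FutureTask<Void>(new Sequentializer(tasks[i], tasks[i+1]), null);
                    exec.execute(fs[i/2]);
                }
                // Run the tasks on this thread as well. FutureTask.run() does nothing if the
                // task has already been started elsewhere, so this thread picks up unstarted
                // work instead of blocking in get() while occupying a pool thread, which
                // could otherwise exhaust the fixed pool and deadlock.
                for (int i = 0; i < fs.length; ++i) {
                    fs[i].run();
                }
                try {
                    for (int i = 0; i < fs.length; ++i) {
                        fs[i].get();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } catch (ExecutionException e) {
                    throw new RuntimeException(e.getCause());
                }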
            }
        }
    }

    class Sequentializer implements Runnable{
        private MultiplyTask first, second;
        Sequentializer(MultiplyTask first, MultiplyTask second) {
            this.first = first;
            this.second = second;
        }
        public void run() {
            first.run();
            second.run();
        }

    }

}

It is important to keep the number of threads close to the optimum (here, the number of available processors), because otherwise the overhead of creating threads and switching amongst them will result in poorer performance.
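The JDK's fork/join framework (ForkJoinPool and RecursiveAction in java.util.concurrent) also works with a bounded set of worker threads, and a thread that is waiting for its subtasks helps to run other tasks in the meantime. The following is a minimal sketch of the same divide and conquer on top of it. It is not what my implementation above uses, and the names ForkJoinMultiply, BlockTask, sub, and THRESHOLD are made up for illustration. Note that it orders the work slightly differently: instead of pairing the two products of each quadrant, it runs the four first products in parallel, waits, and then runs the four second products, so no two concurrent tasks ever write to the same region of c. The same assumption on matrix sizes applies (square, halving evenly down to the threshold).

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

final class ForkJoinMultiply {
    // Hypothetical cutoff for switching to the plain triple loop.
    static final int THRESHOLD = 64;

    static final class BlockTask extends RecursiveAction {
        private final double[][] a, b, c;
        private final int ai, aj, bi, bj, ci, cj, size;

        BlockTask(double[][] a, double[][] b, double[][] c,
                  int ai, int aj, int bi, int bj, int ci, int cj, int size) {
            this.a = a; this.b = b; this.c = c;
            this.ai = ai; this.aj = aj;
            this.bi = bi; this.bj = bj;
            this.ci = ci; this.cj = cj;
            this.size = size;
        }

        // Builds a sub-task whose block corners are shifted relative to this task's corners.
        private BlockTask sub(int ar, int ac, int br, int bc, int cr, int cc, int h) {
            return new BlockTask(a, b, c, ai + ar, aj + ac, bi + br, bj + bc, ci + cr, cj + cc, h);
        }

        @Override
        protected void compute() {
            if (size <= THRESHOLD) {
                for (int i = 0; i < size; ++i) {
                    for (int j = 0; j < size; ++j) {
                        double sum = 0.0;
                        for (int k = 0; k < size; ++k) {
                            sum += a[ai + i][aj + k] * b[bi + k][bj + j];
                        }
                        c[ci + i][cj + j] += sum;
                    }
                }
                return;
            }
            int h = size / 2;
            // Phase 1: first term of every quadrant; the four tasks write disjoint regions of c.
            invokeAll(sub(0, 0, 0, 0, 0, 0, h),   // c11 += a11*b11
                      sub(0, 0, 0, h, 0, h, h),   // c12 += a11*b12
                      sub(h, 0, 0, 0, h, 0, h),   // c21 += a21*b11
                      sub(h, 0, 0, h, h, h, h));  // c22 += a21*b12
            // Phase 2: second term of every quadrant, started only after phase 1 has finished.
            invokeAll(sub(0, h, h, 0, 0, 0, h),   // c11 += a12*b21
                      sub(0, h, h, h, 0, h, h),   // c12 += a12*b22
                      sub(h, h, h, 0, h, 0, h),   // c21 += a22*b21
                      sub(h, h, h, h, h, h, h));  // c22 += a22*b22
        }
    }

    static double[][] multiply(double[][] a, double[][] b) {
        int n = a.length; // assumption: a and b are both n x n
        double[][] c = new double[n][n];
        // The common pool (Java 8+) sizes itself from the number of available processors.
        ForkJoinPool.commonPool().invoke(new BlockTask(a, b, c, 0, 0, 0, 0, 0, 0, n));
        return c;
    }

    public static void main(String[] args) {
        int n = 1024;
        double[][] a = new double[n][n];
        double[][] b = new double[n][n];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                a[i][j] = 1.0;
                b[i][j] = 1.0;
            }
        }
        double[][] c = multiply(a, b);
        // For all-ones inputs every entry of the product equals n.
        System.out.println(c[0][0] == n ? "DONE" : "MISMATCH");
    }
}
```

I have not benchmarked this variant, so I make no claim about how it compares with the timings reported in this post.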

As for the performance, the parallelised matrix multiplication takes around 1.4s to multiply 1024x1024 matrices, while the serial algorithm takes around 15s, which is very satisfying to know. As you may have noticed, I have not optimised the implementation, for the sake of comparing the plain speed-up from the parallelisation method. However, there are various optimisations that can be performed, such as processing the base computation in larger chunks, using more lightweight primitive types, implementing Strassen's formulation, and so on.
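For anyone who wants to reproduce this kind of comparison, a serial baseline can be timed roughly as follows. This is only a sketch of one way to take such a measurement, not the harness that produced the numbers above; SerialBaseline, ones, and secondsToMultiply are hypothetical names. Wall-clock timings on the JVM are noisy (JIT compilation, garbage collection), so the sketch does one untimed warm-up run first, and the results should be read as rough indications only.

```java
import java.util.Arrays;

final class SerialBaseline {
    // The well-known 3 for-loops algorithm on whole n x n matrices.
    static double[][] multiply(double[][] a, double[][] b) {
        int n = a.length; // assumption: a and b are both n x n
        double[][] c = new double[n][n];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                double sum = 0.0;
                for (int k = 0; k < n; ++k) {
                    sum += a[i][k] * b[k][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    // Builds an n x n matrix filled with 1.0, like the debugging constructor in the listing.
    static double[][] ones(int n) {
        double[][] m = new double[n][n];
        for (double[] row : m) Arrays.fill(row, 1.0);
        return m;
    }

    // Returns the wall-clock seconds taken by one multiplication of two n x n matrices.
    static double secondsToMultiply(int n) {
        double[][] a = ones(n);
        double[][] b = ones(n);
        long start = System.nanoTime();
        double[][] c = multiply(a, b);
        long elapsed = System.nanoTime() - start;
        // Sanity check so the result is actually used: every entry must equal n.
        if (c[0][0] != n) throw new IllegalStateException("unexpected result " + c[0][0]);
        return elapsed / 1e9;
    }

    public static void main(String[] args) {
        secondsToMultiply(256); // warm-up run, not reported
        System.out.format("serial 1024x1024: %.2fs%n", secondsToMultiply(1024));
    }
}
```

The same start/stop pattern around the parallel multiply() gives the other half of the comparison.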

As a side note, even with parallelisation, matrix multiplication in MATLAB is still several orders of magnitude faster. Also, while some speed-up was achieved, the amount of additional complexity introduced by fork, join, serialisation, and synchronisation is not to be underestimated. Furthermore, increasing the size of the matrices from 1024x1024 to 2048x2048 (only 4 times as many elements, although an O(N^3) algorithm then does 8 times as much work) already results in a massive drop in performance. Therefore I think one should think about parallel programming as an optimisation tool rather than a cure, and as with any optimisation, it comes with a huge price tag.

In conclusion, I personally think it was a great experiment and as a result I certainly understand a little bit more about parallel programming.