Search code examples
c++sparse-matrix

How to make Sparse Matrix operations faster?


Sparse matrices are matrices in which the vast majority of elements are 0, with only a few non-zero elements. I have to complete my SparseMatrix class so that matrices can be added, subtracted, and multiplied. I use the COO (coordinate) format to store my matrix.

// Growable, array-backed list of T with position-based access.
//
// Positions are 0-based. Every member that takes a position throws the string
// literal "Illegal position" (catch as `const char*`) when it is out of range.
// The list owns its buffer and can be copied safely (deep copy).
template <class T>
class VecList{
    private:
        int capacity;   // number of allocated slots in arr
        int length;     // number of slots currently in use
        T* arr;         // owned buffer of `capacity` elements

        // Doubles the buffer, preserving the first `length` elements.
        void doubleListSize(){
            T* bigger = new T[2*capacity];
            for(int i=0;i<length;i++){
                bigger[i] = arr[i];
            }
            delete [] arr;
            arr = bigger;
            capacity = 2 * capacity;
        }
    public:
        // Creates an empty list with room for 100 elements.
        VecList(){
            length = 0;
            capacity = 100;
            arr = new T[capacity];
        }
        // Creates a list holding a copy of the first n elements of a.
        // (No longer echoes the contents to stdout: that was debug output.)
        VecList(const T* a, int n){
            length = n;
            capacity = 100 + 2*n;
            arr = new T[capacity];
            for(int i=0;i<n;i++){
                arr[i] = a[i];
            }
        }
        // Deep copy: without this the implicit copy would share `arr` and
        // both destructors would delete[] the same buffer.
        VecList(const VecList& other){
            length = other.length;
            capacity = other.capacity;
            arr = new T[capacity];
            for(int i=0;i<length;i++){
                arr[i] = other.arr[i];
            }
        }
        // Deep-copy assignment; allocates first so *this is untouched if
        // the allocation throws.
        VecList& operator=(const VecList& other){
            if(this != &other){
                T* fresh = new T[other.capacity];
                for(int i=0;i<other.length;i++){
                    fresh[i] = other.arr[i];
                }
                delete [] arr;
                arr = fresh;
                capacity = other.capacity;
                length = other.length;
            }
            return *this;
        }
        ~VecList(){
            delete [] arr;
        }
        // Number of stored elements.
        int getLength() const{
            return length;
        }
        bool isEmpty() const{
            return length==0;
        }
        // Inserts x so that it ends up at position i (0 <= i <= length),
        // shifting later elements one slot to the right.
        void insertEleAtPos(int i, T x){
            // Validate before growing so a rejected call has no side effect.
            if(i > length || i < 0)
                throw "Illegal position";
            if(length==capacity)
                doubleListSize();
            for(int j=length;j>i;j--)
                arr[j] = arr[j-1];
            arr[i] = x;
            length++;
        }
        // Removes and returns the element at position i, shifting later
        // elements one slot to the left.
        T deleteEleAtPos(int i){
            if(i >= length || i < 0)
                throw "Illegal position";
            T tmp = arr[i];
            for(int j=i;j<length-1;j++)
                arr[j] = arr[j+1];
            length--;
            return tmp;
        }
        // Overwrites the element at position i with x.
        void setEleAtPos(int i, T x){
            if(i >= length || i < 0)
                throw "Illegal position";
            arr[i] = x;
        }
        // Returns a copy of the element at position i.
        T getEleAtPos(int i) const{
            if(i >= length || i < 0)
                throw "Illegal position";
            return arr[i];
        }
        // Returns the first position whose element equals x, or -1.
        int locateEle(T x) const{
            for(int i=0;i<length;i++){
                if(arr[i]==x)
                    return i;
            }
            return -1;
        }
        // Writes the elements to stdout, each followed by a space.
        void printList() const{
            for(int i=0;i<length;i++)
                std::cout << arr[i] << " ";
        }
};

COO (coordinate format) uses three VecLists to store a matrix.

  1. rowIndex: the row index of each non-zero element.
  2. colIndex: the column index of each non-zero element.
  3. values: the value of each non-zero element. Below is my SparseMatrix class:
// Sparse matrix in coordinate (COO) form: entry k is the triple
// (rowIndex[k], colIndex[k], values[k]); the three lists always have the
// same length.
template <class T>
class SparseMatrix{
    private:
        int rows;                 // logical number of rows
        int cols;                 // logical number of columns
        VecList<int>* rowIndex;   // owned; row of each stored entry
        VecList<int>* colIndex;   // owned; column of each stored entry
        VecList<T>* values;       // owned; value of each stored entry
    public:
        // Creates an empty 10x10 sparse matrix.
        SparseMatrix(){
            rows = 10;
            cols = 10;
            rowIndex = new VecList<int>();
            colIndex = new VecList<int>();
            values = new VecList<T>();
        }
        // Creates an empty r x c sparse matrix.
        SparseMatrix(int r, int c){
            rows = r;
            cols = c;
            rowIndex = new VecList<int>();
            colIndex = new VecList<int>();
            values = new VecList<T>();
        }
        // The matrix owns its three lists through raw pointers, so an
        // implicit copy would make two objects delete the same lists.
        // Copying is therefore disabled; pass matrices by pointer/reference.
        SparseMatrix(const SparseMatrix&) = delete;
        SparseMatrix& operator=(const SparseMatrix&) = delete;
        ~SparseMatrix(){
            delete rowIndex;
            delete colIndex;
            delete values;
        }
};

As you can see, I only need to focus on the non-zero elements. For example, take this sparse matrix:

0 2 0
0 0 1
3 0 0
rows = Exact number of rows, 3
cols = Exact number of columns, 3
rowIndex = 0 1 2
colIndex = 1 2 0
values   = 2 1 3

Read the three lists column by column: the first column means that row 0, column 1 holds the non-zero element "2". The second column means that row 1, column 2 holds the non-zero element "1". I then wrote several functions to implement the operations between matrices.

// Returns the position (shared by rowIndex, colIndex and values) of the
// stored entry at row a, column b, or -1 when no entry is stored there.
//
// The lists are scanned from both ends at once. The scan stops as soon as
// the two cursors meet: past that point every position has already been
// compared, so continuing (as the earlier version did) only repeated work.
int findPos(int a, int b){
    const int count = rowIndex->getLength();
    for (int front = 0, back = count - 1; front <= back; front++, back--)
    {
        if(rowIndex->getEleAtPos(front) == a && colIndex->getEleAtPos(front) == b)return front;
        if(rowIndex->getEleAtPos(back) == a && colIndex->getEleAtPos(back) == b)return back;
    }
    return -1;
}
        void setEntry(int rPos, int cPos, T x){ // Set (rPos, cPos) = x
            int pos = findPos(rPos,cPos);
            //Find if there is a non-zero element at (rPos, cPos).
            if(x != 0){
            //If the origin matrix does not have an element at(rPos, cPos),insert x to the matrix.
            if (pos == -1)
            {
                rowIndex->insertEleAtPos(rowIndex->getLength(),rPos);
                colIndex->insertEleAtPos(colIndex->getLength(),cPos);
                values->insertEleAtPos(values->getLength(),x);
            }
            else{
                //If the origin matrix has an element at(rPos, cPos),replace it with x.
                rowIndex->setEleAtPos(pos,rPos);
                colIndex->setEleAtPos(pos,cPos);
                values->setEleAtPos(pos,x);
            }
           }
           else{
            //If x == 0 and the origin matrix has an element at(rPos, cPos), delete the element.
                if(pos != -1){
                    rowIndex->deleteEleAtPos(pos);
                    colIndex->deleteEleAtPos(pos);
                    values->deleteEleAtPos(pos);
                }
            }
        //If x == 0, and the origin matrix does not have an element at(rPos, cPos), nothing changed.
        }
// Returns the element at (rPos, cPos); cells with no stored entry read as 0.
T getEntry(int rPos, int cPos){
    // findPos is a linear scan, so run it once and reuse the result.
    const int pos = findPos(rPos,cPos);
    if (pos == -1)
        return 0;
    return values->getEleAtPos(pos);
}
        // Returns a newly allocated matrix C = A + B, where A is *this.
        // The caller owns the result and must delete it.
        // Throws "Matrices have incompatible sizes" when the shapes differ.
        SparseMatrix<T> * add(SparseMatrix<T> * B){
            if(rows != B->rows || cols != B->cols)throw "Matrices have incompatible sizes";
            SparseMatrix<T> *C = new SparseMatrix<T>(rows,cols);
            // A and B generally hold different numbers of entries, so each
            // is walked with its own bound. (Sharing A's bound dropped B's
            // extra entries, or threw when B held fewer entries than A.)
            const int countA = rowIndex->getLength();
            for (int i = 0; i < countA; i++)
            {
                const int r = rowIndex->getEleAtPos(i);
                const int c = colIndex->getEleAtPos(i);
                C->setEntry(r,c,C->getEntry(r,c)+values->getEleAtPos(i));
            }
            const int countB = B->rowIndex->getLength();
            for (int j = 0; j < countB; j++)
            {
                const int r = B->rowIndex->getEleAtPos(j);
                const int c = B->colIndex->getEleAtPos(j);
                // Accumulate on top of whatever A contributed to this cell;
                // setEntry removes the entry again if the sum is 0.
                C->setEntry(r,c,C->getEntry(r,c)+B->values->getEleAtPos(j));
            }
            return C;
        }
        // Returns a newly allocated matrix C = A - B, where A is *this.
        // The caller owns the result and must delete it.
        // Throws "Matrices have incompatible sizes" when the shapes differ.
        SparseMatrix<T> * subtract(SparseMatrix<T> * B){
            if(rows != B->rows || cols != B->cols)throw "Matrices have incompatible sizes";
            SparseMatrix<T> *C = new SparseMatrix<T>(rows,cols);
            // A's entries are ADDED into C (subtracting them produced -A-B),
            // and each matrix is walked with its own entry count.
            const int countA = rowIndex->getLength();
            for (int i = 0; i < countA; i++)
            {
                const int r = rowIndex->getEleAtPos(i);
                const int c = colIndex->getEleAtPos(i);
                C->setEntry(r,c,C->getEntry(r,c)+values->getEleAtPos(i));
            }
            const int countB = B->rowIndex->getLength();
            for (int j = 0; j < countB; j++)
            {
                const int r = B->rowIndex->getEleAtPos(j);
                const int c = B->colIndex->getEleAtPos(j);
                // setEntry removes the entry again if the difference is 0.
                C->setEntry(r,c,C->getEntry(r,c)-B->values->getEleAtPos(j));
            }
            return C;
        }

        // Returns a newly allocated matrix C = A * B, where A is *this
        // (rows x cols) and B is (B->rows x B->cols). The caller owns the
        // result and must delete it.
        // Throws "Matrices have incompatible sizes" unless cols == B->rows.
        SparseMatrix<T> * multiply(SparseMatrix<T> * B){
            // Only the inner dimensions have to agree; the result is
            // rows x B->cols.
            if(cols != B->rows)throw "Matrices have incompatible sizes";
            SparseMatrix<T> *C = new SparseMatrix<T>(rows,B->cols);
            const int countA = rowIndex->getLength();
            const int countB = B->rowIndex->getLength();
            for (int i = 0; i < countA; i++)
            {
                const int r = rowIndex->getEleAtPos(i);
                const int k = colIndex->getEleAtPos(i);
                const T aValue = values->getEleAtPos(i);
                for (int j = 0; j < countB; j++)
                {
                    // A(r,k) pairs only with entries of B that sit in row k.
                    // (Testing whether *some* entry exists at (k, col_j) and
                    // then multiplying by entry j mixed in values from other
                    // rows of B.)
                    if (B->rowIndex->getEleAtPos(j) != k)
                        continue;
                    const int c = B->colIndex->getEleAtPos(j);
                    C->setEntry(r,c,C->getEntry(r,c)+(aValue*B->values->getEleAtPos(j)));
                }
            }
            return C;
        }

        // Writes the matrix to stdout in dense form: one line per row, each
        // element (stored or implicit zero) followed by a single space.
        void printMatrix(){
            int r = 0;
            while (r < rows)
            {
                int c = 0;
                while (c < cols)
                {
                    cout << getEntry(r,c) << " ";
                    ++c;
                }
                cout << endl;
                ++r;
            }
        }

I've tested several cases, and in all of them add, subtract and multiply appear to work. But there is a test with two 10000x10000 matrices (called "X" and "Y") that I can't pass. X and Y do not have many non-zero elements, and the test only adds, subtracts and multiplies them. The time limit is 1 second (not including printMatrix(), but including setEntry()), and I exceed it. How do I reduce the running time of my program? (I also wonder whether the COO storage is the wrong choice, and whether the findPos() function is unnecessary.) Thank you. My tool is VSCode2022, with C++11, on Windows 11. Here is a sample of the test code.

#include <iostream>
#include <algorithm>
#include <chrono>
using namespace std;
int main(){
    auto start = std::chrono::high_resolution_clock::now();
    SparseMatrix<int> X,Y;
    X.setEntry(1,3,4);
    X.setEntry(7,8,2);
    Y.setEntry(1,6,4);
    Y.setEntry(1,3,4);
    Y.setEntry(7,7,2);
    X.printMatrix();
    cout << endl;
    Y.printMatrix();
    cout << endl;
    X.add(&Y)->printMatrix();
    cout << endl;
    X.subtract(&Y)->printMatrix();
    cout << endl;
    Y.multiply(&X)->printMatrix();
    cout << "Done" << endl;
    auto stop = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
    cout << "Running Time:" << duration << "ms\n";
    return 0;
}

Solution

    • After two days of hard work, I got ACCEPTED in the test. Two 10000x10000 Sparse Matrices' operations are completed in 0.01s~0.02s.

    • And I have to say that using three VecLists, or any two-dimensional array, to store a sparse matrix is not a good idea, because the matrix contains thousands of 0s. This time I use only one VecList to store a matrix, and its elements are struct rHead objects.

    // One stored non-zero element; nodes of the same row are chained
    // through `right`.
    struct OrthNode{
        int rol;           // row index ("rol" is a misspelling of "row" kept from the original code)
        int col;           // column index
        int value;         // the element's value
        OrthNode *right;   // next stored element in the same row
    };
    // Header of one matrix row: the entry point into that row's node chain.
    struct rHead{
        // Counter kept alongside the row. The original author notes that
        // nothing in the program reads it.
        int nums;
        // First stored element of the row.
        struct OrthNode* right;
    };
    
    • The rHead objects are stored in the VecList: a matrix has as many rHead entries as it has rows. Each rHead leads a whole row, so when I want to visit a row, I visit its rHead first.
    • The OrthNode is like an ID card. It records all the information about a non-zero element: its col, its rol, and its value. It is worth noting that it also has a *right pointer, which points to the next element in the same row.
    • Now, I use the rHead as the beginning of each row. And rHead->right points to the first non-zero element, an OrthNode, in this row. And this OrthNode's pointer OrthNode->right points to the next element in this row.No matter rHead or OrthNode,if there's no element on the right side, then it points to NULL.
    • In this way, when I want to visit A[3][4], I first visit the rHead of row 3, that is VecList[3]. Now I am at A[3][?]. Then I follow rHead->right; suppose it is A[3][1]. Then I follow A[3][1]->right, and so on, until I reach A[3][4].
    • So I use one VecList to store a sparse matrix row by row, and I keep the elements of each row in column order. (For example, in row 3: rHead->A[3][1]->A[3][2]->A[3][4]->A[3][6]. That is, for A[3][k], the values of k are increasing.)
    • Here is my code:
    template <class T>
    class VecList{
    // Excerpt: getEleAtPos is the only member that differs from the VecList
    // shown in the question. It now hands back the address of the stored
    // element instead of a copy.
    public:
            // Returns a pointer to the element at position i.
            // Throws "Illegal position" when i is outside [0, length).
            T* getEleAtPos(int i){
                const bool outOfRange = (i < 0) || (i >= length);
                if(outOfRange)
                    throw "Illegal position";
                return arr + i;
            }
    // All remaining members are the same as in the question.
    };
    template <class T>
    class SparseMatrix{
        private:
            VecList<rHead> M;//The VecList is the same as in the question.
            int totalrows;
            int totalcols;
        public:
            SparseMatrix(){
                totalrows = 10;
                totalcols = 10;
                for(int i=0; i<10; i++){
                    M.insertEleAtPos(M.getLength(), {0, nullptr});
                }
            }
            SparseMatrix(int r, int c){
                totalrows = r;
                totalcols = c;
                for(int i=0; i<r; i++){
                    M.insertEleAtPos(M.getLength(), {0, nullptr});
                }
            }
            // Frees every node of every row; the row headers themselves
            // live inside M.
            ~SparseMatrix(){
                for (int row = 0; row < totalrows; ++row)
                {
                    rHead* head = M.getEleAtPos(row);
                    // Pop nodes off the front of the chain until it is empty.
                    while (head->right != NULL)
                    {
                        OrthNode* doomed = head->right;
                        head->right = doomed->right;
                        delete doomed;
                    }
                }
            }
            void setEntry(int rPos, int cPos, T x){
                OrthNode* newNode = new OrthNode;
                newNode->rol = rPos;
                newNode->col = cPos;
                newNode->value = x;
                if (x == 0)
                {
                    if (M.getEleAtPos(rPos)->right == NULL)
                    {
                        delete newNode;
                        return;
                    }
                    OrthNode*temp = M.getEleAtPos(rPos)->right;
                    if (temp->col == cPos)
                    {
                        M.getEleAtPos(rPos)->right = temp->right;
                        delete newNode;
                        delete temp;
                        M.getEleAtPos(rPos)->nums --;
                        return;
                    }
                    
                    while(temp->col < cPos && temp != NULL){
                        if (temp->right->col == cPos)
                        {
                            OrthNode*delNode = temp->right;
                            temp = temp->right->right;
                            delete delNode;
                            M.getEleAtPos(rPos)->nums --;
                            return;
                        }
                        temp = temp->right;
                    }
                    if (temp->right == NULL)
                    {
                        delete temp;
                        delete newNode;
                        return;
                    }
                }
                else{
                    if (M.getEleAtPos(rPos)->right == NULL)
                    {
                        newNode->right = NULL;
                        M.getEleAtPos(rPos)->right = newNode;
                        M.getEleAtPos(rPos)->nums ++;
                        return;
                    }
                    else{
                        OrthNode* temp = M.getEleAtPos(rPos)->right;
                        if (cPos < temp->col)
                        {
                            newNode->right = temp;
                            M.getEleAtPos(rPos)->right = newNode;
                            M.getEleAtPos(rPos)->nums ++;
                            return;
                        }
                        while (temp->col <= cPos && temp != NULL)
                        {
                            if (temp->col == cPos)
                            {
                                temp->value = x;
                                M.getEleAtPos(rPos)->nums ++;
                                delete newNode;
                                return;
                            }
                            if (temp->right == NULL)
                            {
                                newNode->right =NULL;
                                temp->right = newNode;
                                M.getEleAtPos(rPos)->nums ++;
                                return;
                            } 
                            if (temp->col < cPos && temp->right->col > cPos)
                            {
                                newNode->right = temp->right;
                                temp->right = newNode;
                                M.getEleAtPos(rPos)->nums ++;
                                return;
                            }
                            temp = temp->right;
                        }    
                }
            }
        }
            // Returns the element at (rPos, cPos); cells without a node read
            // as 0. Throws "Illegal position" for an out-of-range row.
            T getEntry(int rPos, int cPos){
                // Walk the existing chain directly. (Allocating a scratch
                // node first, as before, leaked one node per call.)
                // Rows are kept in ascending column order, so the walk can
                // stop at the first node past cPos.
                for (const OrthNode* read = M.getEleAtPos(rPos)->right;
                     read != NULL && read->col <= cPos;
                     read = read->right)
                {
                    if (read->col == cPos)
                    {
                        return read->value;
                    }
                }
                return 0;
            }
            // Returns a newly allocated matrix C = A + B, where A is *this.
            // The caller owns the result and must delete it.
            // Throws "Matrices have incompatible sizes" when the shapes differ.
            //
            // Each row of A and B is already sorted by column, so the two
            // chains are merged in a single pass and the result is appended
            // to C's row at its tail. Cells whose sum is 0 are skipped, so C
            // never stores a zero.
            SparseMatrix<T> * add(SparseMatrix<T> * B){
                if(totalrows != B->totalrows || totalcols != B->totalcols)throw "Matrices have incompatible sizes";
                SparseMatrix<T>* C = new SparseMatrix<T>(totalrows,totalcols);
                for (int i = 0; i < totalrows; i++)
                {
                    const OrthNode* a = M.getEleAtPos(i)->right;
                    const OrthNode* b = B->M.getEleAtPos(i)->right;
                    rHead* out = C->M.getEleAtPos(i);
                    OrthNode** tail = &out->right;   // where the next node gets linked
                    while (a != NULL || b != NULL)
                    {
                        // Consume whichever side has the smaller column; on a
                        // tie consume both so their values combine.
                        const bool takeA = (b == NULL) || (a != NULL && a->col <= b->col);
                        const bool takeB = (a == NULL) || (b != NULL && b->col <= a->col);
                        const int col = takeA ? a->col : b->col;
                        T sum = 0;
                        if (takeA)
                        {
                            sum += a->value;
                            a = a->right;
                        }
                        if (takeB)
                        {
                            sum += b->value;
                            b = b->right;
                        }
                        if (sum == 0)
                        {
                            continue;
                        }
                        OrthNode* node = new OrthNode;
                        node->rol = i;
                        node->col = col;
                        node->value = sum;
                        node->right = NULL;
                        *tail = node;
                        tail = &node->right;
                        out->nums++;
                    }
                }
                return C;
            }
            // Returns a newly allocated matrix C = A - B, where A is *this.
            // The caller owns the result and must delete it.
            // Throws "Matrices have incompatible sizes" when the shapes differ.
            //
            // Each row of A and B is already sorted by column, so the two
            // chains are merged in a single pass and the result is appended
            // to C's row at its tail. Cells whose difference is 0 are
            // skipped, so C never stores a zero.
            SparseMatrix<T> * subtract(SparseMatrix<T> * B){
                if(totalrows != B->totalrows || totalcols != B->totalcols)throw "Matrices have incompatible sizes";
                SparseMatrix<T>* C = new SparseMatrix<T>(totalrows,totalcols);
                for (int i = 0; i < totalrows; i++)
                {
                    const OrthNode* a = M.getEleAtPos(i)->right;
                    const OrthNode* b = B->M.getEleAtPos(i)->right;
                    rHead* out = C->M.getEleAtPos(i);
                    OrthNode** tail = &out->right;   // where the next node gets linked
                    while (a != NULL || b != NULL)
                    {
                        // Consume whichever side has the smaller column; on a
                        // tie consume both so their values combine.
                        const bool takeA = (b == NULL) || (a != NULL && a->col <= b->col);
                        const bool takeB = (a == NULL) || (b != NULL && b->col <= a->col);
                        const int col = takeA ? a->col : b->col;
                        T diff = 0;
                        if (takeA)
                        {
                            diff += a->value;
                            a = a->right;
                        }
                        if (takeB)
                        {
                            diff -= b->value;
                            b = b->right;
                        }
                        if (diff == 0)
                        {
                            continue;
                        }
                        OrthNode* node = new OrthNode;
                        node->rol = i;
                        node->col = col;
                        node->value = diff;
                        node->right = NULL;
                        *tail = node;
                        tail = &node->right;
                        out->nums++;
                    }
                }
                return C;
            }
    
            // Returns a newly allocated matrix C = A * B, where A is *this.
            // The caller owns the result and must delete it.
            // Throws "Matrices have incompatible sizes" unless
            // totalcols == B->totalrows.
            //
            // Row-by-row product: for every stored A(i,k), row k of B is
            // scaled by A(i,k) and merged into row i of C. Both chains are
            // sorted by column, so each merge is one forward pass.
            SparseMatrix<T> * multiply(SparseMatrix<T> * B){
                // Only the inner dimensions have to agree; the result is
                // totalrows x B->totalcols.
                if (totalcols != B->totalrows)throw "Matrices have incompatible sizes";
                SparseMatrix<T>* C = new SparseMatrix<T>(totalrows,B->totalcols);
                try
                {
                    for (int i = 0; i < totalrows; i++)
                    {
                        rHead* out = C->M.getEleAtPos(i);
                        for (const OrthNode* a = M.getEleAtPos(i)->right; a != NULL; a = a->right)
                        {
                            // Restart at the front of C's row for each A
                            // entry; within one B row the cursor only moves
                            // forward.
                            OrthNode** link = &out->right;
                            for (const OrthNode* b = B->M.getEleAtPos(a->col)->right; b != NULL; b = b->right)
                            {
                                while (*link != NULL && (*link)->col < b->col)
                                    link = &(*link)->right;
                                const T term = a->value * b->value;
                                if (*link != NULL && (*link)->col == b->col)
                                {
                                    (*link)->value += term;
                                }
                                else
                                {
                                    OrthNode* node = new OrthNode;
                                    node->rol = i;
                                    node->col = b->col;
                                    node->value = term;
                                    node->right = *link;
                                    *link = node;
                                    out->nums++;
                                }
                            }
                        }
                        // Partial products may cancel; drop nodes that ended
                        // up at 0 so C never stores a zero.
                        OrthNode** link = &out->right;
                        while (*link != NULL)
                        {
                            if ((*link)->value == 0)
                            {
                                OrthNode* dead = *link;
                                *link = dead->right;
                                delete dead;
                                out->nums--;
                            }
                            else
                            {
                                link = &(*link)->right;
                            }
                        }
                    }
                }
                catch (...)
                {
                    // A row lookup can throw "Illegal position"; do not leak
                    // the partially built result.
                    delete C;
                    throw;
                }
                return C;
            }
    
            // Dense printing is deliberately switched off: it would visit
            // all totalrows * totalcols cells, which is impractical for a
            // large matrix. The disabled loop was:
            //     for i in [0, totalrows), for j in [0, totalcols):
            //         cout << getEntry(i,j) << " ";   then cout << endl per row
            // The function now only prints a warning line.
            void printMatrix(){
                const char* const warning = "Be careful, when the matrix is too big, do not use Print!";
                cout << warning << endl;
            }
    
    • And here is the test code:
    #include <iostream>
    #include <algorithm>
    #include <chrono>
    using namespace std;
    int main(){
        SparseMatrix<int> X(10,10);
        SparseMatrix<int> Y(10,10);
    //It is a Sparse Matrix so do not give so much elements.
        for (int i = 0; i < 10; i++)
        {
            X.setEntry(i,i,2);
            Y.setEntry(3,i,i+2);
        }
    //If you want to test the time cost, do not use printMatrix();
        auto start = std::chrono::high_resolution_clock::now();
        X.printMatrix();
        cout << endl;
        Y.printMatrix();
        cout << endl;
        X.add(&Y)->printMatrix();
        cout << endl;
        X.subtract(&Y)->printMatrix();
        cout << endl;
        X.multiply(&Y)->printMatrix();
        cout << endl;
        cout << "Done" << endl;
        auto stop = std::chrono::high_resolution_clock::now();
        auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
        cout << "Running Time:" << duration << "ms\n";
        return 0;
    }