Search code examples
c++matrix

Sparse matrix multiplication using linked lists


I have an assignment that requires multiplying sparse matrices. The input gives the dimensions m n, followed by one line per row listing that row's nonzero elements as column value pairs (columns are 1-based), each line terminated by 0. For example, the input

4 5
3 3 5 4 0
3 5 4 7 0
0
2 2 3 6 0

Represents the matrix

0 0 3 0 4
0 0 5 7 0
0 0 0 0 0
0 2 6 0 0

So far I have written this implementation of linked lists:

//LinkedList.h
#pragma once
#include<iostream>
#include<stdexcept>
using namespace std;

// A single node of LinkedList: one int payload plus a link to the next node.
struct ListNode
{
    int val;         // stored value
    ListNode* next;  // following node, or NULL when this is the last node

    // Default node holds 0; both constructors start unlinked.
    ListNode() : ListNode(0) {}
    ListNode(int x) : val(x), next(NULL) {}

    friend class LinkedList;
};

// Singly linked list of int values. The list owns its nodes and the
// destructor frees them.
class LinkedList
{

public:
    LinkedList();

    // The destructor deletes every node. A member-wise copy would make two
    // lists share, and both delete, the same nodes (double free), so copying
    // is disabled.
    LinkedList(const LinkedList&) = delete;
    LinkedList& operator=(const LinkedList&) = delete;

    void Push_back(int x);
    void Push_front(int x);
    void Insert(int index, int x);
    void Delete(int index);
    void Reverse();
    void Swap(int index_1, int index_2);
    void Print();
    int Find(int);
    ~LinkedList();

private:
    ListNode* Head = NULL; // first node, or NULL when the list is empty
    ListNode* Tail = NULL; // NOTE(review): not read or written by any member function defined in this file
};
//LinkedList.cpp
#include "LinkedList.h"

LinkedList::LinkedList() : Head(NULL), Tail(NULL) {
    // Constructor: start with an empty list. Tail is initialized as well so
    // the object never holds an indeterminate pointer, even though no member
    // function currently uses it.
}

void LinkedList::Push_back(int x) {
    // TODO : Insert a node to the end of the linked list, the node¡¦s value is x.
    ListNode* temp = new ListNode(x);

    if (Head ==0) {
        Head = temp;
        return;
    }

    ListNode* current = Head;
    while (current->next != 0) {
        current = current->next;
    }
    current->next = temp;
}

void LinkedList::Push_front(int x) {
    if (!Head) {
        Head = new ListNode(x);
        return;
    }
    ListNode* temp = new ListNode(x);

    temp->next = Head;
    Head = temp;
}

void LinkedList::Insert(int index, int x) {
    // TODO : Insert a node to the linked list at position ¡§index¡¨, the node's
    // value is x. The index of the first node in the linked list is 0.
    
    ListNode* temp = new ListNode(x);
    temp->val = x;
    temp->next = NULL;
    
    if (Head == NULL) {
        Head = temp;
    }
    else if (index == 0) {
        temp->next = Head;
        Head = temp;
    }
    else {
        ListNode* cur = Head;
        int d = 1;
        while (cur != NULL) {
            if (d == index) {
                temp->next = cur->next;
                cur->next = temp;
                break;
            }
            cur = cur->next;
            d++;
        }
    }
}

void LinkedList::Delete(int index) {
    // TODO : Remove the node with index ¡§index¡¨ in the linked list.
    if (Head == NULL)
        return;

    ListNode* temp = Head;

    // If head needs to be removed
    if (index == 0) {
        Head = temp->next;
        temp = 0;
        return;
    }

    // previous node of the node to be deleted
    for (int i = 0; temp != NULL && i < index - 1; i++) {
        temp = temp->next;
    }
        
    // If position is more than number of nodes
    if (temp == NULL || temp->next == NULL)
        return;

    // Node temp->next is the node to be deleted
    // Store pointer to the next of node to be deleted
    ListNode* next = temp->next->next;

    // Unlink the node from linked list
    free(temp->next); // Free memory

    // Unlink the deleted node from list
    temp->next = next;

}

void LinkedList::Reverse() {
    // TODO : Reverse the linked list.
    // Example :
    //
    // Original List : 1 -> 2 -> 3 -> 4 -> NULL
    // Updated List  : 4 -> 3 -> 2 -> 1 -> NULL

    if (Head == 0 || Head->next == 0) {
        return;
    }

    ListNode *previous = 0;
    ListNode *current = Head;
    ListNode *preceding = Head->next;

    while (preceding != 0) {
        current->next = previous;
        previous = current;
        current = preceding;
        preceding = preceding->next;
    }

    current->next = previous;
    Head = current;
}

void LinkedList::Swap(int index_1, int index_2) 
{
    // TODO : Swap two nodes in the linked list
    // Example : 
    // index_1 : 1   ,  index_2 : 3
    // 
    // Original List : 1 -> 2 -> 3 -> 4 -> NULL
    // Updated List  : 1 -> 4 -> 3 -> 2 -> NULL
    
    if (index_1 == index_2) return;
    int value_1 = Find(index_1);
    int value_2 = Find(index_2);

    Delete(index_1);
    Insert(index_1, value_2);
    Delete(index_2);
    Insert(index_2, value_1);
    
}

void LinkedList::Print() {
    // Print all elements in order on one line, prefixed by "List: ".
    // An empty list prints "List: List empty." followed by exactly one
    // newline. Returning early here avoids the stray blank line the
    // fall-through used to produce.
    cout << "List: ";
    if (Head == NULL) {
        cout << "List empty." << endl;
        return;
    }

    for (ListNode* current = Head; current != NULL; current = current->next)
        cout << current->val << " ";
    cout << endl;
}

int LinkedList::Find(int index)
{
    // Return the value stored at position `index` (0-based).
    //
    // Throws std::out_of_range when index is negative or >= the list length.
    // Previously, control fell off the end of this non-void function in that
    // case, which is undefined behaviour.
    int count = 0;  // position of `current`
    for (ListNode* current = Head; current != NULL; current = current->next, ++count) {
        if (count == index)
            return current->val;
    }
    throw std::out_of_range("LinkedList::Find: index out of range");
}

LinkedList::~LinkedList()
{
    // Release every node, front to back. The successor is saved before each
    // delete because the node's `next` field is gone afterwards.
    ListNode* node = Head;
    while (node != NULL) {
        ListNode* successor = node->next;
        delete node;
        node = successor;
    }
    Head = NULL;
}

I am trying to build the sparse matrices following the algorithm in E. Horowitz's *Fundamentals of Data Structures* (Revised Edition), but I am having a hard time understanding how matrix operations can work without using vectors or arrays. Where should I start?


Solution

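  • The core operations are Insert and Multiply. The remaining helpers read, print, and free matrices stored as a single linked list of (row, column, value) nodes: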

    #include <iostream>
    
    // One stored entry of a sparse matrix. A matrix is a singly linked list
    // of these, chained through `next`.
    struct Node {
      int value = 0;            // the entry's value
      int row_position = 0;     // 0-based row index
      int column_position = 0;  // 0-based column index
      Node* next = nullptr;     // following entry, or nullptr at the end
    };
    
    // Prepends a (row, col, value) entry to the list at *head in O(1).
    // Entries therefore end up in reverse insertion order. No check is made
    // for an existing entry at the same position.
    void Insert(Node** head, int row, int col, int value) {
      Node* const entry = new Node();
      entry->row_position = row;
      entry->column_position = col;
      entry->value = value;
      entry->next = *head;
      *head = entry;
    }
    
    // Accumulates the product a * b into the list at *result.
    //
    // For every pair of entries a[i][k] and b[k][j] that share the inner
    // index k, the product is added to result[i][j]. The first existing
    // result entry at (i, j) is updated in place; if there is none, a new
    // entry is prepended with Insert. Existing contents of *result are added
    // to, not cleared. Dimension compatibility is the caller's
    // responsibility.
    void Multiply(Node* a, Node* b, Node** result) {
      for (Node* lhs = a; lhs != nullptr; lhs = lhs->next) {
        for (Node* rhs = b; rhs != nullptr; rhs = rhs->next) {
          if (lhs->column_position != rhs->row_position) {
            continue;  // inner indices differ: this pair contributes nothing
          }
          const int row = lhs->row_position;
          const int col = rhs->column_position;
          const int product = lhs->value * rhs->value;

          Node* slot = *result;
          while (slot != nullptr &&
                 !(slot->row_position == row && slot->column_position == col)) {
            slot = slot->next;
          }

          if (slot != nullptr) {
            slot->value += product;
          } else {
            Insert(result, row, col, product);
          }
        }
      }
    }
    
    // Prints the rows x cols matrix densely, each value followed by a space
    // and one line per row. For each cell the list is scanned, and the first
    // matching entry's value is printed, or 0 when no entry exists.
    void PrintMatrix(Node* head, int rows, int cols) {
      for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
          const Node* hit = head;
          while (hit != nullptr &&
                 (hit->row_position != r || hit->column_position != c)) {
            hit = hit->next;
          }
          if (hit != nullptr) {
            std::cout << hit->value << " ";
          } else {
            std::cout << "0 ";
          }
        }
        std::cout << '\n';
      }
    }
    
    void ReadMatrix(Node** head, int& rows, int& cols) {
      std::cin >> rows >> cols;
      for (int i = 0; i < rows; ++i) {
        while (true) {
          int col, value;
          std::cin >> col;
          if (col == 0) break; // A col value of 0 indicates the end of input for the current row
          std::cin >> value;
          Insert(head, i, col - 1, value); // Adjust column index to be 0-based
        }
      }
    }
    
    // Frees every node in the list. `head` is taken by value, so the caller's
    // pointer is not reset and must not be used afterwards.
    void DeleteList(Node* head) {
      for (Node* cur = head; cur != nullptr;) {
        Node* const doomed = cur;
        cur = cur->next;
        delete doomed;
      }
    }
    

    Example Usage:

    int main() {
      Node* a = nullptr;
      Node* b = nullptr;
      int rows_a, cols_a, rows_b, cols_b;
      std::cout << "Enter matrix A:" << '\n';
      ReadMatrix(&a, rows_a, cols_a);
      std::cout << "Matrix A (" << rows_a << "x" << cols_a << "):" << '\n';
      PrintMatrix(a, rows_a, cols_a);
      std::cout << "Enter matrix B:" << '\n';
      ReadMatrix(&b, rows_b, cols_b);
      std::cout << "Matrix B (" << rows_b << "x" << cols_b << "):" << '\n';
      PrintMatrix(b, rows_b, cols_b);
      if (cols_a != rows_b) {
        std::cout << "Matrices cannot be multiplied: incompatible dimensions." << '\n';
        DeleteList(a);
        DeleteList(b);
        return 1;
      }
      Node* result = nullptr;
      Multiply(a, b, &result);
      std::cout << "Resultant Matrix (" << rows_a << "x" << cols_b << "):" << '\n';
      PrintMatrix(result, rows_a, cols_b);
      DeleteList(a);
      DeleteList(b);
      DeleteList(result);
      return 0;
    }
    

    Sample session (the dimension and row lines after each "Enter matrix" prompt are user input):

    Enter matrix A:
    4 5
    3 3 5 4 0
    3 5 4 7 0
    0
    2 2 3 6 0
    Matrix A (4x5):
    0 0 3 0 4 
    0 0 5 7 0 
    0 0 0 0 0 
    0 2 6 0 0 
    Enter matrix B:
    5 2
    1 8 2 7 0
    1 6 0
    2 5 0
    0
    1 3 2 4 0
    Matrix B (5x2):
    8 7 
    6 0 
    0 5 
    0 0 
    3 4 
    Resultant Matrix (4x2):
    12 31 
    0 25 
    0 0 
    12 30