Search code examples
c++vectorkaratsuba

c++ Karatsuba Multiplication using Vectors


So i've been trying to write out an algorithm for the Karatsuba Multiplication algorithm, and i've been attempting to use vectors as my data structure to handle the really long numbers which will be input...

My program can do smaller numbers fine, however it really struggles with larger numbers, and i get a core dump (Seg Fault). It also outputs strange results when the left hand side number is smaller than the right hand side.

Got any ideas? Heres the code.

#include <iostream>
#include <string>
#include <vector>

#define max(a,b) ((a) > (b) ? (a) : (b))

using namespace std;

vector<int> add(vector<int> lhs, vector<int> rhs) {
    int length = max(lhs.size(), rhs.size());
    int carry = 0;
    int sum_col;
    vector<int> result;

    while(lhs.size() < length) {
        lhs.insert(lhs.begin(), 0);
    }
    while(rhs.size() < length) {
        rhs.insert(rhs.begin(), 0);
    }

    for(int i = length-1; i >= 0; i--) {
        sum_col = lhs[i] + rhs[i] + carry;
        carry = sum_col/10;
        result.insert(result.begin(), (sum_col%10));
    }
    if(carry) {
        result.insert(result.begin(), carry);
    }
    int x = 0;
    while(result[x] == 0) {
        x++;
    }
    result.erase(result.begin(), result.begin()+x);
    return result;
}

vector<int> subtract(vector<int> lhs, vector<int> rhs) {
    int length = max(lhs.size(), rhs.size());
    int diff;
    vector<int> result;

    while(lhs.size() < length) {
        lhs.insert(lhs.begin(), 0);
    }
    while(rhs.size() < length) {
        rhs.insert(rhs.begin(), 0);
    }

    for(int i = length-1; i >= 0; i--) {
        diff = lhs[i] - rhs[i];
        if(diff >= 0) {
            result.insert(result.begin(), diff);
        } else {
            int j = i - 1;
            while(j >= 0) {
                lhs[j] = (lhs[j] - 1) % 10;
                if(lhs[j] != 9) {
                    break;
                } else {
                    j--;
                }
            }
            result.insert(result.begin(), diff+10);
        }
    }
    int x = 0;
    while(result[x] == 0) {
        x++;
    }
    result.erase(result.begin(), result.begin()+x);
    return result;
}

vector<int> multiply(vector<int> lhs, vector<int> rhs) {
    int length = max(lhs.size(), rhs.size());
    vector<int> result;

    while(lhs.size() < length) {
        lhs.insert(lhs.begin(), 0);
    }
    while(rhs.size() < length) {
        rhs.insert(rhs.begin(), 0);
    }

    if(length == 1) {
        int res = lhs[0]*rhs[0];
        if(res >= 10) {
            result.push_back(res/10);
            result.push_back(res%10);
            return result;
        } else {
            result.push_back(res);
            return result;
        }
    }

    vector<int>::const_iterator first0 = lhs.begin();
    vector<int>::const_iterator last0 = lhs.begin() + (length/2);
    vector<int> lhs0(first0, last0);
    vector<int>::const_iterator first1 = lhs.begin() + (length/2);
    vector<int>::const_iterator last1 = lhs.begin() + ((length/2) + (length-length/2));
    vector<int> lhs1(first1, last1);
    vector<int>::const_iterator first2 = rhs.begin();
    vector<int>::const_iterator last2 = rhs.begin() + (length/2);
    vector<int> rhs0(first2, last2);
    vector<int>::const_iterator first3 = rhs.begin() + (length/2);
    vector<int>::const_iterator last3 = rhs.begin() + ((length/2) + (length-length/2));
    vector<int> rhs1(first3, last3);

    vector<int> p0 = multiply(lhs0, rhs0);
    vector<int> p1 = multiply(lhs1,rhs1);
    vector<int> p2 = multiply(add(lhs0,lhs1),add(rhs0,rhs1));
    vector<int> p3 = subtract(p2,add(p0,p1));

    for(int i = 0; i < 2*(length-length/2); i++) {
        p0.push_back(0);
    }
    for(int i = 0; i < (length-length/2); i++) {
        p3.push_back(0);
    }

    result = add(add(p0,p1), p3);

    int x = 0;
    while(result[x] == 0) {
        x++;
    }
    result.erase(result.begin(), result.begin()+x);
    return result;
}

int main() {
    vector<int> lhs;
    vector<int> rhs;
    vector<int> v;

    lhs.push_back(2);
    lhs.push_back(5);
    lhs.push_back(2);
    lhs.push_back(5);
    lhs.push_back(2);
    lhs.push_back(5);
    lhs.push_back(2);
    lhs.push_back(5);

    rhs.push_back(1);
    rhs.push_back(5);
    rhs.push_back(1);
    rhs.push_back(5);
    rhs.push_back(1);
    rhs.push_back(5);
    rhs.push_back(1);


    v = multiply(lhs, rhs);

    for(size_t i = 0; i < v.size(); i++) {
        cout << v[i];
    }
    cout << endl;

    return 0;
    }

Solution

  • There are several issues with subtract. Since you don't have any way to represent a negative number, if rhs is greater than lhs your borrow logic will access before the beginning of of the data for lhs.

    You can also march past the end of result when removing leading zeros if the result is 0.

    Your borrow calculation is wrong, since -1 % 10 will return -1, and not 9, if lhs[j] is 0. A better way to calculate that is add 9 (one less than the value you're dividing by), lhs[j] = (lhs[j] + 9) % 10;.

    In an unrelated note, you can simplify your range iteration calculations. Since last0 and first1 have the same value, you can use last0 for both, and last1 is lhs.end(). This simpifies lhs1 to

    vector<int> lhs1(last0, lhs.end());
    

    and you can get rid of first1 and last1. Same goes for the rhs iterators.