c++ algorithm dynamic-programming memoization

What's wrong with my dynamic programming algorithm with memoization?


Sorry about my poor English. If there is anything you don't understand, please tell me so that I can give you more information.

This is my first time asking a question on Stack Overflow. I've read some of the rules for asking questions correctly here, but I may have missed something. I welcome all feedback.

I'm currently solving algorithm problems to improve my skills, and I've been struggling with one problem for three days. It is from https://algospot.com/judge/problem/read/RESTORE , but since that page is in Korean, I have translated it into English.

Question
Given 'k' partial strings, find the shortest string that contains all of them. All strings consist only of lowercase letters. If more than one string of the same shortest length satisfies all conditions, print any of them.

Input
The first line of input gives the number of test cases 'C' (C <= 50). For each test case, the first line gives the number of partial strings 'k' (1 <= k <= 15), and the next k lines give the partial strings. Each partial string is between 1 and 40 characters long.

Output
For each test case, print the shortest string that contains all partial strings.

Sample Input
3
3
geo
oji
jing
2
world
hello
3
abrac
cadabra
dabr

Sample Output
geojing
helloworld
cadabrac
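The merging in the samples comes from finding the longest suffix of one string that is also a prefix of the next: geo and oji share "o", oji and jing share "ji", and cadabra and abrac share "abra" (dabr is already inside cadabra). The sketch below reproduces those merges; overlap and mergeTwo are hypothetical names used only for this illustration, and it assumes C++11 or later.

```cpp
#include <algorithm>
#include <cassert>
#include <string>

// Length of the longest suffix of a that is also a prefix of b (0 if none).
int overlap(const std::string& a, const std::string& b) {
    for (int i = static_cast<int>(std::min(a.size(), b.size())); i > 0; --i) {
        // Compare the last i characters of a with the first i characters of b.
        if (a.compare(a.size() - i, i, b, 0, i) == 0) return i;
    }
    return 0;
}

// Append b to a, sharing the longest possible overlap.
std::string mergeTwo(const std::string& a, const std::string& b) {
    return a + b.substr(overlap(a, b));
}
```

For example, mergeTwo(mergeTwo("geo", "oji"), "jing") builds "geojing", and mergeTwo("hello", "world") gives "helloworld" because the two words share no overlap.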

Here is my code. It seems to work perfectly with the sample inputs, and every test input I made myself also worked fine. But when I submit it, the judge says my answer is 'wrong'.

Please tell me what is wrong with my code. You don't need to give me the whole fixed code; I only need a sample input that makes my code fail. I've added a description of the code to make it easier to understand.
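One common way to find such an input yourself is to compare your program against a slow but obviously correct reference on many small random inputs. The sketch below is such a reference, written for this illustration (bruteForce and overlapLen are hypothetical names): it removes strings contained in others, then tries all k! orders with std::next_permutation, merging each consecutive pair with the longest overlap. It is only practical for small k and assumes the input vector is non-empty.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <string>
#include <vector>

// Longest suffix of a that is also a prefix of b.
int overlapLen(const std::string& a, const std::string& b) {
    for (int i = static_cast<int>(std::min(a.size(), b.size())); i > 0; --i)
        if (a.compare(a.size() - i, i, b, 0, i) == 0) return i;
    return 0;
}

// Brute-force reference: try every order of the non-contained strings.
std::string bruteForce(std::vector<std::string> parts) {
    // Longest first, so each string is checked against longer kept strings.
    std::sort(parts.begin(), parts.end(),
              [](const std::string& x, const std::string& y) { return x.size() > y.size(); });
    std::vector<std::string> kept;
    for (const std::string& s : parts) {
        bool contained = false;
        for (const std::string& t : kept)
            if (t.find(s) != std::string::npos) { contained = true; break; }
        if (!contained) kept.push_back(s);
    }

    std::vector<int> order(kept.size());
    std::iota(order.begin(), order.end(), 0);  // 0, 1, ..., k-1 (sorted)

    std::string best;
    do {  // starting from sorted order, next_permutation visits all k! orders
        std::string cur;
        for (int idx : order) cur += kept[idx].substr(overlapLen(cur, kept[idx]));
        if (best.empty() || cur.size() < best.size()) best = cur;
    } while (std::next_permutation(order.begin(), order.end()));
    return best;
}
```

When testing, generate a few short strings over a small alphabet such as {a, b, c}, run both programs, and compare only the lengths of the answers, since equally short answers may differ.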

Code Description

All input partial strings are saved in the vector 'stringParts'.
The current shortest result string is saved in the global variable 'answer'.
The 'cache' array is used for memoization, to skip repeated function calls.

The algorithm I designed to solve this problem is divided into two functions: restore() and eraseOverlapped().

restore() calculates the shortest string that includes all partial strings in 'stringParts'.
The result of restore() is saved in 'answer'.

restore() has three parameters: 'curString', 'selected' and 'last'.
'curString' is the string built so far by overlapping the currently selected strings.
'selected' is the set of elements of 'stringParts' selected so far, stored as a bitmask to keep the algorithm concise.
'last' is the element of 'stringParts' that was selected last when building 'curString' (stored 1-based, with 0 meaning nothing has been selected yet).
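As a small illustration of the bitmask encoding described above, these hypothetical helpers (not part of the original code) show how the set of selected strings is built and tested: bit i is 1 when stringParts[i] has been used. Note that the code's `selected + (1 << next)` equals `selected | (1 << next)` only because the loop already skips indices whose bit is set.

```cpp
#include <cassert>

// Set containing all k strings: k low bits set, e.g. k = 3 -> 0b111 = 7.
inline int fullSet(int k) { return (1 << k) - 1; }

// True if string i is already in the set.
inline bool contains(int set, int i) { return (set & (1 << i)) != 0; }

// Add string i to the set (safe even if it is already there).
inline int addToSet(int set, int i) { return set | (1 << i); }
```

For example, addToSet(addToSet(0, 0), 2) is 0b101 = 5, which contains strings 0 and 2 but not 1, and is not yet equal to fullSet(3).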

eraseOverlapped() does preprocessing: before restore() runs, it deletes every element of 'stringParts' that is completely contained in another element.
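The same preprocessing can be sketched more compactly with the standard library, assuming C++11 or later. eraseContained is a hypothetical name; it sorts by length (longest first) and keeps a string only if no already-kept string contains it, which also removes exact duplicates.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Keep only the strings that are not substrings of another string.
std::vector<std::string> eraseContained(std::vector<std::string> parts) {
    // Longest first: a string can only be inside one at least as long.
    std::sort(parts.begin(), parts.end(),
              [](const std::string& a, const std::string& b) { return a.size() > b.size(); });
    std::vector<std::string> kept;
    for (const std::string& s : parts) {
        // If s is inside a dropped string t, it is also inside whatever kept
        // string contains t, so checking kept strings is enough.
        bool contained = std::any_of(kept.begin(), kept.end(),
            [&s](const std::string& t) { return t.find(s) != std::string::npos; });
        if (!contained) kept.push_back(s);
    }
    return kept;
}
```

With the third sample, eraseContained({"abrac", "cadabra", "dabr"}) keeps "cadabra" and "abrac" and drops "dabr".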

#include <algorithm>
#include <iostream>
#include <vector>
#include <cstring>
#include <string>
#define MAX 15
using namespace std;

int k;
string answer; // save shortest result string

vector<string> stringParts;
bool cache[MAX + 1][(1 << MAX) + 1]; //[last selected string][set of selected strings in Bitmask]

void restore(string curString, int selected=0, int last=0) {
    //base case 1
    if (selected == (1 << k) - 1) {
        if (answer.empty() || curString.length() < answer.length()) 
            answer = curString;
        return;
    }
    //base case 2 - memoization
    bool& ret = cache[last][selected];
    if (ret != false) return;

    for (int next = 0; next < k; next++) {
        string checkStr = stringParts[next];
        if (selected & (1 << next)) continue;

        if (curString.empty())
            restore(checkStr, selected + (1 << next), next + 1);
        else {
            int check = false;
            //count max overlapping area of two strings and overlap two strings.
            for (int i = (checkStr.length() > curString.length() ? curString.length() : checkStr.length())
                ; i > 0; i--) {
                if (curString.substr(curString.size()-i, i) == checkStr.substr(0, i)) {
                    restore(curString + checkStr.substr(i, checkStr.length()-i), selected + (1 << next), next + 1);
                    check = true;
                    break;
                }
            }
            if (!check) { // if there aren't any overlapping area
                restore(curString + checkStr, selected + (1 << next), next + 1);
            }
        }
    }
    ret = true;
}
//check if there are strings that can be completely included by other strings, and delete that string.
void eraseOverlapped() {
    //arranging string vector in ascending order of string length
    int vectorLen = stringParts.size();
    for (int i = 0; i < vectorLen - 1; i++) {
        for (int j = i + 1; j < vectorLen; j++) {
            if (stringParts[i].length() < stringParts[j].length()) {
                string temp = stringParts[i];
                stringParts[i] = stringParts[j];
                stringParts[j] = temp;
            }
        }
    }

    //deleting included strings
    vector<string>::iterator iter;
    for (int i = 0; i < vectorLen-1; i++) {
        for (int j = i + 1; j < vectorLen; j++) {
            if (stringParts[i].find(stringParts[j]) != string::npos) {
                iter = stringParts.begin() + j;
                stringParts.erase(iter);
                j--;
                vectorLen--;
            }
        }
    }
}

int main(void) {
    int C;
    cin >> C; // testcase
    for (int testCase = 0; testCase < C; testCase++) {
        cin >> k; // number of partial strings
        memset(cache, false, sizeof(cache)); // initializing cache to false
        string inputStr;
        for (int i = 0; i < k; i++) {
            cin >> inputStr;
            stringParts.push_back(inputStr);
        }
        eraseOverlapped();
        k = stringParts.size();

        restore("");
        cout << answer << endl;
        answer.clear();
        stringParts.clear();
    }
}

Solution

  • Thanks to everyone who tried to help me solve this problem. I eventually solved it with a few changes to my previous algorithm. These are the main changes:

    1. In my previous algorithm, I saved the result of restore() in the global variable 'answer' because restore() didn't return anything. In the new algorithm, restore() returns the best partial result string, so I no longer need 'answer'.
    2. I used a string cache instead of a bool cache. The bool cache did not work for this algorithm: it only records that the state (last, selected) was visited, but in the old code the same state can be reached with different 'curString' values of different lengths, so a later, shorter 'curString' was pruned. The string cache stores the best result for the state itself.
    3. I deleted the 'curString' parameter from restore(). Since the only thing needed during a recursive call is the previously selected partial string, 'last' can take over the role of 'curString'.
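To make point 2 concrete, here is an input that, by tracing the original code by hand, we believe makes the first version fail: C = 1, k = 4, partial strings bc, ab, cd, de (in that order). The only shortest answer is abcde. In the old restore(), the state "cd selected last, {bc, ab, cd} selected" is first reached through the order bc, ab, cd, which builds a 6-character curString, and is marked true in the cache. When the search later reaches the same state through ab, bc, cd, whose curString has only 4 characters, it returns at once, so abcde is never produced. The sketch below uses a hypothetical helper mergeInOrder, written for this illustration, that merges strings in a fixed order with the same longest-overlap rule, and shows that the two orders give different curString values for the same (last, selected) state.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Longest suffix of a that is also a prefix of b (same rule as restore()).
int overlapLen(const std::string& a, const std::string& b) {
    for (int i = static_cast<int>(std::min(a.size(), b.size())); i > 0; --i)
        if (a.compare(a.size() - i, i, b, 0, i) == 0) return i;
    return 0;
}

// Merge parts in the given order, as one path of the old DFS would.
std::string mergeInOrder(const std::vector<std::string>& parts,
                         const std::vector<int>& order) {
    std::string cur;  // empty at first, so the first string is copied whole
    for (int idx : order) {
        const std::string& next = parts[idx];
        cur += next.substr(overlapLen(cur, next));
    }
    return cur;
}
```

With parts = {"bc", "ab", "cd", "de"}, the order {0, 1, 2} yields "bcabcd" while {1, 0, 2} yields "abcd": same last string and same selected set, different lengths.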

    CODE

    #include <algorithm>
    #include <iostream>
    #include <vector>
    #include <cstring>
    #include <string>
    #define MAX 15
    using namespace std;
    
    int k;
    
    vector<string> stringParts;
    string cache[MAX + 1][(1 << MAX) + 1];
    
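    // Returns the shortest string that starts with stringParts[last] and contains
    // every string not yet in 'selected'. last == -1 means nothing is chosen yet.
    string restore(int selected = 0, int last = -1) {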
        if (selected == (1 << k) - 1) {
            return stringParts[last];
        }
    
        if (last == -1) {
            string ret = "";
            for (int next = 0; next < k; next++) {
                string resultStr = restore(selected + (1 << next), next);
                if (ret.empty() || ret.length() > resultStr.length())
                    ret = resultStr;
            }
            return ret;
        }
    
    string& ret = cache[last][selected];
    if (!ret.empty()) {
        return ret; // already computed for this (last, selected) state
    }
    
        string curString = stringParts[last];
        for (int next = 0; next < k; next++) {
            if (selected & (1 << next)) continue;
    
            string checkStr = restore(selected + (1 << next), next);
            int check = false;
            string resultStr;
            for (int i = (checkStr.length() > curString.length() ? curString.length() : checkStr.length())
                ; i > 0; i--) {
                if (curString.substr(curString.size() - i, i) == checkStr.substr(0, i)) {
                    resultStr = curString + checkStr.substr(i, checkStr.length() - i);
                    check = true;
                    break;
                }
            }
            if (!check)
                resultStr = curString + checkStr;
    
            if (ret.empty() || ret.length() > resultStr.length())
                ret = resultStr;
        }
        return ret;
    }
    
    void EraseOverlapped() {
        int vectorLen = stringParts.size();
        for (int i = 0; i < vectorLen - 1; i++) {
            for (int j = i + 1; j < vectorLen; j++) {
                if (stringParts[i].length() < stringParts[j].length()) {
                    string temp = stringParts[i];
                    stringParts[i] = stringParts[j];
                    stringParts[j] = temp;
                }
            }
        }
    
        vector<string>::iterator iter;
        for (int i = 0; i < vectorLen - 1; i++) {
            for (int j = i + 1; j < vectorLen; j++) {
                if (stringParts[i].find(stringParts[j]) != string::npos) {
                    iter = stringParts.begin() + j;
                    stringParts.erase(iter);
                    j--;
                    vectorLen--;
                }
            }
        }
    }
    
    int main(void) {
        int C;
        cin >> C;
        for (int testCase = 0; testCase < C; testCase++) {
            cin >> k;
            for (int i = 0; i < MAX + 1; i++) {
                for (int j = 0; j < (1 << MAX) + 1; j++)
                    cache[i][j] = "";
            }
            string inputStr;
            for (int i = 0; i < k; i++) {
                cin >> inputStr;
                stringParts.push_back(inputStr);
            }
            EraseOverlapped();
            k = stringParts.size();
            string resultStr = restore();
            cout << resultStr << endl;
            stringParts.clear();
        }
    }
    

    This algorithm is much slower than the 'ideal' algorithm suggested by the book I'm studying, but it was fast enough to pass this problem's time limit.
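For comparison, a common formulation of this problem (a sketch, not necessarily the algorithm from the book mentioned above) caches an integer instead of a string: best(last, used) is the maximum total overlap obtainable from the strings not yet in used, given that stringParts[last] was placed most recently. Because the result length equals the sum of the string lengths minus the total overlap, maximizing overlap minimizes length; the answer string is rebuilt afterwards by following the choices that achieve the maximum. The class Restorer and its members are hypothetical names; the code assumes contained strings were already removed (for example by EraseOverlapped()) and that 1 <= k <= 15.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <utility>
#include <vector>

class Restorer {
public:
    explicit Restorer(std::vector<std::string> parts)
        : s_(std::move(parts)), k_(static_cast<int>(s_.size())),
          ov_(k_, std::vector<int>(k_, 0)), memo_(k_, std::vector<int>(1 << k_, -1)) {
        // Precompute pairwise overlaps once: ov_[i][j] = overlap of s_[i] then s_[j].
        for (int i = 0; i < k_; ++i)
            for (int j = 0; j < k_; ++j)
                if (i != j) ov_[i][j] = overlap(s_[i], s_[j]);
    }

    std::string solve() {
        // Start with the string whose best chain has the largest total overlap.
        int first = 0;
        for (int i = 1; i < k_; ++i)
            if (best(i, 1 << i) > best(first, 1 << first)) first = i;
        std::string result = s_[first];
        int last = first, used = 1 << first;
        while (used != (1 << k_) - 1) {
            // Follow a choice that achieves best(last, used).
            for (int next = 0; next < k_; ++next) {
                if (used & (1 << next)) continue;
                if (ov_[last][next] + best(next, used | (1 << next)) == best(last, used)) {
                    result += s_[next].substr(ov_[last][next]);
                    last = next;
                    used |= 1 << next;
                    break;
                }
            }
        }
        return result;
    }

private:
    static int overlap(const std::string& a, const std::string& b) {
        for (int i = static_cast<int>(std::min(a.size(), b.size())); i > 0; --i)
            if (a.compare(a.size() - i, i, b, 0, i) == 0) return i;
        return 0;
    }

    // Maximum total overlap from appending every string not in 'used'.
    int best(int last, int used) {
        if (used == (1 << k_) - 1) return 0;
        int& ret = memo_[last][used];
        if (ret != -1) return ret;
        ret = 0;
        for (int next = 0; next < k_; ++next)
            if (!(used & (1 << next)))
                ret = std::max(ret, ov_[last][next] + best(next, used | (1 << next)));
        return ret;
    }

    std::vector<std::string> s_;
    int k_;
    std::vector<std::vector<int>> ov_;
    std::vector<std::vector<int>> memo_;
};
```

Usage: Restorer({"cadabra", "abrac"}).solve() returns "cadabrac". Storing ints and computing each pairwise overlap only once avoids building and copying strings in every call, which is likely where much of the string-cache version's time goes.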