Search code examples
Tags: c++, algorithm, data-structures, constraint-satisfaction

How to implement Cryptarithmetic using Constraint Satisfaction in C++


I'll start by explaining what a cryptarithmetic problem is, through an example:

  T W O
+ T W O
F O U R

We have to assign a digit [0-9] to each letter such that no two letters share the same digit and it satisfies the above equation.

One solution to the above problem is:

   7 6 5   
+  7 6 5       
 1 5 3 0  

There are two ways to solve this problem. One is brute force, which works but is not efficient. The other is constraint satisfaction.

Solution using Constraint Satisfaction
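We know that R will always be even, because R is the last digit of O + O = 2 × O.
This narrows down R's domain to {0, 2, 4, 6, 8}.
We also know that F must be 1. F is not the sum of two letters, so its value must come from the carry generated by the column T + T = O. Adding two digits (plus an incoming carry of at most 1) can carry at most 1, and F cannot be 0 because it is a leading digit.
This also implies that T + T (plus any carry from the W column) must be at least 10; only then can it generate a carry for F.
Since that incoming carry is at most 1, 2 × T ≥ 9, which tells us that T > 4, i.e. T ∈ {5, 6, 7, 8, 9}.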
As we continue this way, we keep narrowing down each variable's domain, which reduces the search time considerably.

The concept seems easy, but I'm having trouble implementing it in C++, especially the part where we generate the constraints/domain for each variable. Keep in mind that carries are involved too.

EDIT: I'm looking for a way to generate a domain for each variable using the concept I stated.


Solution

  • Here is how I solved it using backtracking

    My approach was a pruned brute force: I recursively assign every possible value [0-9] to each letter and check whether there is any contradiction.

    Contradictions can be one of the following:

    • Two or more letters end up having the same value.
    • The sum of a column's letters (plus the incoming carry) doesn't match the value of the result letter.
    • The digit produced by a column's sum is already assigned to some other letter.

    As soon as a contradiction occurs, the recursion for that particular combination ends.

    #include <bits/stdc++.h>
    
    using namespace std;
    
    // words:  the addend words; main() reverses each one in place, so words[k][i]
    //         is the i-th letter from the right (the column solve() works on).
    // wordOg: copy of the words in their original, unreversed order (used for printing).
    vector<string> words, wordOg;
    // result / resultOg: the sum word, reversed by main() and in original order respectively.
    string result, resultOg;
    // Set to true by solve() whenever it reaches a complete solution (just before printing it).
    bool solExists = false;
    
    // Reverses `str` in place.
    // Thin convenience wrapper so callers can write `reverse(s)`; the work is
    // delegated explicitly to the iterator-pair std::reverse rather than relying
    // on overload resolution through `using namespace std`.
    void reverse(string &str){
        std::reverse(std::begin(str), std::end(str));
    }
    
    void printProblem(){
        cout<<"\n";
        for(int i=0;i<words.size();i++){
            for(int j=0;j<words[i].size();j++){
                cout<<words[i][j];
            }
            cout<<"\n";
        }
        cout<<"---------\n";
        for(int i=0;i<result.size();i++){
            cout<<result[i];
        }
        cout<<"\n";
    }
    
    void printSolution(unordered_map<char, int> charValue){
        cout<<"\n";
        for(int i=0;i<words.size();i++){
            for(int j=0;j<words[i].size();j++){
                cout<<charValue[wordOg[i][j]];
            }
            cout<<"\n";
        }
        cout<<"---------\n";
        for(int i=0;i<result.size();i++){
            cout<<charValue[resultOg[i]];
        }
        cout<<"\n";
    }
    
    // Backtracking solver that fills in one column of the addition at a time.
    //
    // `words` and `result` are expected to be reversed (main() does this), so
    // index `idx` is the idx-th column counting from the right. Within a column
    // the solver visits the words one by one (colIdx = 0 .. words.size()-1):
    // a letter that already has a digit contributes that digit to `sum`; an
    // unassigned letter is tried with every still-free digit. When colIdx reaches
    // words.size(), the column's result letter is checked against (or assigned)
    // (sum + carry) % 10 and the solver moves on to column idx + 1 with the new
    // carry. Every complete, consistent assignment is printed and sets solExists.
    //
    // Parameters:
    //   colIdx    - index of the word being processed in the current column
    //   idx       - current column (0 = least-significant)
    //   carry     - carry coming into this column
    //   sum       - digits of this column accumulated so far
    //   charValue - letter -> digit assignment so far
    //   domain    - domain[d] == -1 while digit d is free, 0 once it is taken
    // charValue and domain are taken by value, so each branch works on its own
    // copy and nothing needs to be undone when a branch returns.
    void solve(int colIdx, int idx, int carry, int sum,unordered_map<char, int> charValue, vector<int> domain){
        // An empty result word cannot encode any sum.
        if(result.empty()) return;

        const size_t col = static_cast<size_t>(idx);

        if(colIdx < static_cast<int>(words.size())){
            const string &word = words[colIdx];
            if(col >= word.size()){
                // This word is too short to have a letter in this column.
                solve(colIdx + 1, idx, carry, sum, charValue, domain);
                return;
            }

            const char ch = word[col];
            // Words are reversed, so the last index holds the leading letter.
            const bool leading = (col == word.size() - 1);

            auto known = charValue.find(ch);
            if(known != charValue.end()){
                // The letter may have been fixed earlier in a non-leading position
                // (or as a result letter), so the no-leading-zero rule must be
                // re-checked here, not only when the letter is first assigned.
                if(leading && known->second == 0) return;
                solve(colIdx + 1, idx, carry, sum + known->second, charValue, domain);
                return;
            }

            for(int digit = 0; digit < 10; digit++){
                if(digit == 0 && leading) continue;   // no leading zeros
                if(domain[digit] != -1) continue;     // digit already used
                domain[digit] = 0;
                charValue[ch] = digit;
                solve(colIdx + 1, idx, carry, sum + digit, charValue, domain);
                domain[digit] = -1;
            }
            return;
        }

        // Every word of this column has been summed: settle the result letter.
        const int total = sum + carry;
        const int digit = total % 10;
        const char resCh = result[col];
        auto known = charValue.find(resCh);
        if(known != charValue.end()){
            if(known->second != digit) return;
        }
        else{
            if(domain[digit] != -1) return;   // digit taken by another letter
            domain[digit] = 0;
            charValue[resCh] = digit;
        }
        carry = total / 10;

        if(col + 1 < result.size()){
            solve(0, idx + 1, carry, 0, charValue, domain);
            return;
        }

        // Last (most-significant) column of the result. Its digit must be
        // non-zero, and nothing may be carried out. With three or more words the
        // carry can be 2 or more, so any non-zero carry is a failure (testing
        // only carry == 1 accepted e.g. 9 + 9 + 9 = 7).
        if(digit == 0 || carry != 0) return;

        // A word longer than the result has letters in columns that were never
        // visited. Its leading digit is non-zero, so it alone exceeds any value
        // the result can spell: no solution is possible.
        for(const string &w : words){
            if(w.size() > result.size()) return;
        }

        solExists = true;
        printSolution(charValue);
    }
    
    // Reads a cryptarithm from standard input: the number of addend words, the
    // words themselves, then the result word. It echoes the puzzle, then prints
    // every solution found by solve(), or a "No Solution" message.
    // Returns 1 (after a message on stderr) if the input cannot be read or the
    // word count is not positive.
    int main() {
        unordered_map<char, int> charValue;
        vector<int> domain(10,-1);  // -1 marks a digit that is still free

        int n = 0;
        cout<<"\nEnter number of input words: ";
        if(!(cin>>n) || n <= 0){
            cerr<<"\nInvalid number of words.\n";
            return 1;
        }
        cout<<"\nEnter the words: ";
        words.reserve(static_cast<size_t>(n));
        for(int i=0;i<n;i++){
            string inp;
            if(!(cin>>inp)){
                cerr<<"\nFailed to read word "<<(i + 1)<<".\n";
                return 1;
            }
            words.push_back(std::move(inp));
        }
        cout<<"\nEnter the resultant word: ";
        if(!(cin>>result)){
            cerr<<"\nFailed to read the resultant word.\n";
            return 1;
        }

        printProblem();

        // Keep the original spelling for printing the solution; the solver
        // works on reversed copies so index 0 is the rightmost column.
        wordOg = words;
        resultOg = result;

        reverse(result);
        for(auto &itr: words) reverse(itr);

        solve(0, 0, 0, 0, charValue, domain);

        if(!solExists) cout<<"\nNo Solution Exists!"<<'\n';
        return 0;
    }