Tags: c++, algorithm, performance, time-complexity

How can I apply the meet-in-the-middle algorithm to search a whole 2D matrix?


I am trying to solve problem F, Xor-Paths, from Codeforces.

Problem statement

There is a rectangular grid of size n×m. Each cell has a number written on it; the number on the cell (i, j) is a_{i,j}. Your task is to calculate the number of paths from the upper-left cell (1, 1) to the bottom-right cell (n, m) meeting the following constraints:

  1. You can move to the right or to the bottom only. Formally, from the cell (i, j) you may move to the cell (i, j+1) or to the cell (i+1, j). The target cell can't be outside of the grid.
  2. The xor of all the numbers on the path from the cell (1, 1) to the cell (n, m) must be equal to k (xor operation is the bitwise exclusive OR, it is represented as '^' in Java or C++ and "xor" in Pascal).

Find the number of such paths in the given grid.


I know that brute force won't work when n = 20 and m = 20, but I don't have a better idea. I will explain why, but first let me explain my current approach.

Here is the code:

#include <iostream>
#include <vector>
#include <unordered_set>

using namespace std;

bool bfs(int i, int j, long long k, long long curr_xor, const vector<vector<long long>>& a, long long& count, vector<vector<unordered_set<long long>>>& bad_cells) {
    if (j == a[0].size() || i == a.size())
        return false;

    if (bad_cells[i][j].find(curr_xor) != bad_cells[i][j].end())
        return false;

    long long new_xor = curr_xor ^ a[i][j];

    if (i == a.size() - 1 && j == a[0].size() - 1) {
        if (new_xor == k) {
            ++count;
            return true;
        }

        return false;
    }

    bool right_valid = bfs(i, j + 1, k, new_xor, a, count, bad_cells);
    bool down_valid = bfs(i + 1, j, k, new_xor, a, count, bad_cells);

    if (!right_valid) bad_cells[i][j + 1].insert(new_xor);
    if (!down_valid) bad_cells[i + 1][j].insert(new_xor);

    return right_valid || down_valid;
}

long long solve(int n, int m, long long k, const vector<vector<long long>>& a) {
    long long count = 0; // the number of paths can exceed the range of int on a 20x20 grid
    vector<vector<unordered_set<long long>>> bad_cells(n + 1, vector<unordered_set<long long>>(m + 1));
    bfs(0, 0, k, 0, a, count, bad_cells);

    return count;
}

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int n, m;
    long long k;
    cin >> n >> m >> k;

    vector<vector<long long>> a(n, vector<long long>(m));

    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            cin >> a[i][j];
        }
    }

    cout << solve(n, m, k, a);
}

So, as you can see, my approach is pure brute force. I have a recursive method, bfs (despite its name, it performs a depth-first search), that increments the final count when it reaches the destination (n, m); from every cell (i, j) it goes down and right. While moving over the grid, it calculates the XOR.

It also maintains the bad_cells 2D matrix, which stores at (i, j) the XOR values that won't lead to a good result. In other words, when we reach the destination (n, m) and the recursion starts to go back, we check whether the child met the problem's condition (final XOR == k). If not, we record that if you are at (i, j) with XOR equal to new_xor (the parent's XOR) and go right or down (depending on which child failed), you won't get a valid result: I was here before and I know. Something like that.

Finally, I return true if either path (right or down) is correct, because if one of them is, other correct paths may also be found after that correct move.


Of course, this is too slow: I get TLE on the 4th test case on Codeforces. The problem on Codeforces has the meet-in-the-middle tag, so I probably need to use this technique. But how?

I imagine that I could solve it for, e.g., an (n/2)×m grid, then for the other half, and combine the results. But what would the destination point be for the first half (there is no (n, m) in the first half's range)? I think I could solve the first half in a loop, moving the destination point along (n/2, i), then solve the second half, also in a loop, using (n/2, i) as the starting point and (n, m) as the destination, and then combine the results.

But it will probably be even slower. So how can I solve it more efficiently? How can I apply the meet-in-the-middle technique to this problem?


Solution

  • Sorry in advance for the length of this answer...

    I have structured my answer as the complete reasoning that leads to the implementation of the meet-in-the-middle (MITM) algorithm for this question.
    Instead of just implementing it and seeing if it works, I detailed some of the underlying math, the overall algorithm, and the subtle traps to avoid.

    In this answer:

    • First, I will address the elephant in the room: exactly why the brute-force approach did not work (and likely won't work even with optimized versions of it).
    • I will propose an alternative implementation, which I think is slightly easier to understand than yours and is easier to take as a starting point for optimization.
    • I will list facts and properties the algorithm may (or may not) use, as exhaustively as possible. We will see what makes the algorithm valid here (but not yet how it saves time).
    • Then, I will give an overview of MITM.
    • Next, I will detail where the magic happens in a separate section.
      The entire algorithm relies on that trick, so it deserves its own section.
    • Finally, I will include an implementation.

    As you are only trying to solve a problem on Codeforces, I strongly recommend you give it a try yourself before reading the final section.

    Why it failed / Why MITM is promising for this problem

    The question specifies that the input matrix may be as big as 20x20 cells and that the program should execute within 3 seconds and under 256MB of memory.

    On a 20x20 matrix, it takes 38 moves to go from the top-left cell to the bottom-right cell: 19 of these moves go right and 19 go down. The only choice we have is which of the moves, numbered from 1 to 38, go right (the others go down).
    Therefore, the total number of paths in this situation is binomial(38, 19) = 35,345,263,800, far too many to enumerate within the limits set by the site.
    Remember this figure, it will be important later.
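    These figures can be double-checked with a small helper. This is a sketch: binomial and pathCount are hypothetical names, and it assumes 64-bit unsigned arithmetic is enough, which holds for grids up to 20x20 (every intermediate product stays below 10^12).

```cpp
#include <cassert>
#include <cstdint>

// Binomial coefficient C(n, k). After step i, result == C(n - k + i, i),
// so the division by i is always exact.
std::uint64_t binomial(unsigned n, unsigned k) {
    assert(k <= n);
    if (k > n - k) k = n - k;  // C(n, k) == C(n, n - k): loop over the smaller one
    std::uint64_t result = 1;
    for (unsigned i = 1; i <= k; ++i)
        result = result * (n - k + i) / i;
    return result;
}

// Number of right/down paths in a rows x columns grid: choose which of the
// (rows - 1) + (columns - 1) moves go right.
std::uint64_t pathCount(unsigned rows, unsigned columns) {
    assert(rows >= 1 && columns >= 1);
    return binomial(rows + columns - 2, columns - 1);
}
```

    For instance, pathCount(20, 20) is binomial(38, 19) = 35,345,263,800, and binomial(19, 9) = 92,378 is the half-path figure used below.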

    The MITM algorithm should be able to cut path lengths in half.
    What we hope is to make combinatorics work in our favor, i.e. get a number of half-paths in the range of binomial(19, 9) = 92,378 (in fact twice that figure, since each complete path is made of 2 half-paths).
    Nothing is confirmed yet (the actual figure is bigger, and this saving will not materialize without heavy lifting anyway), but for now, let us simply consider this a strong clue that something could happen.

    Alternative brute-force approach

    The issue (at least IMHO) with the code in the question is that I find it difficult to optimize. I prefer working with the brute-force code below instead:

    • The algorithm is not recursive and does not use backtracking or anything complicated. It only uses nested loops over the rows and columns.
    • It uses a Matrix helper class to avoid seeing std::vector<std::vector<...>>.
      It encapsulates a single vector and maps row/column coordinates to vector indices.
    • Unlike your code, I have translated the constraint given in the problem, 0 ≤ k ≤ 10^18, into uint64_t (unsigned long long would have worked too). I acknowledge that signed long long works as well.
    • For illustration, I have included a measure of the processing time in std::cerr.
    #include <chrono>
    #include <cstdint>
    #include <iostream>
    #include <vector>
    
    //The matrix is not the core of the problem, so this simple class will do.
    class Matrix {
    public:
        Matrix(unsigned int rows, unsigned int columns) : rowCount(rows), colCount(columns), v(rows* columns) {}
        uint64_t* operator[](unsigned int r) { return &(v[r * colCount]); }
    
    private:
        unsigned int rowCount, colCount;
        std::vector<uint64_t> v;
    };
    
    //Read the matrix from std::cin
    void read(Matrix& inputMatrix, unsigned int rows, unsigned int columns)
    {
        for (unsigned int r = 0; r < rows; ++r) {
            for (unsigned int c = 0; c < columns; ++c) {
                uint64_t value = 0; std::cin >> value;
                inputMatrix[r][c] = value;
            }
        }
    }
    
    int main(int argc, char* argv[])
    {
        using std::chrono::steady_clock;
        using std::chrono::duration_cast;
        using std::chrono::duration;
        using std::chrono::milliseconds;
    
        using XorValueList = std::vector<uint64_t>;
    
        unsigned int rows, columns;
        uint64_t expected;
        std::cin >> rows >> columns >> expected;
    
        // Construct a matrix of the right size
        Matrix inputMatrix(rows, columns);
        // Fill the matrix with data from std::cin
        read(inputMatrix, rows, columns);
    
        // xorRow represents the current row being processed. Since we can never go up, we do not need to remember what is above it.
        // It stores the list of XOR results for all the paths leading to the corresponding cell
        std::vector<XorValueList> xorRow;
    
        auto start = steady_clock::now();
    
        // On the first row, only 1 path (horizontal) can reach each of the cells and there are no cells above.
        // The top left cell is a special case.
        xorRow.push_back({ inputMatrix[0][0] });
        // First row
        for (unsigned int c = 1; c < columns; ++c) {
            auto v = xorRow[c-1][0];
            xorRow.push_back({ v ^ inputMatrix[0][c] });
        }
    
    
        //Next rows
        for (unsigned int r = 1; r < rows; ++r) {
            // xorRow2 is the next row after xorRow.
            // It will be filled, then swapped with xorRow for the next loop / for after the loop
            std::vector<XorValueList> xorRow2;
            { // The first column is always a special case
                auto v = xorRow[0][0];
                xorRow2.push_back({ v ^ inputMatrix[r][0] });
            }
            //The other cells consist in taking the matrix value at that cell and combining it using xor with what we calculated on its left and above it.
            for (unsigned int c = 1; c < columns; ++c) {
                xorRow2.push_back({ });
                for (auto v : xorRow2[c - 1])
                    xorRow2[c].push_back(v ^ inputMatrix[r][c]);
                for (auto v : xorRow[c])
                    xorRow2[c].push_back(v ^ inputMatrix[r][c]);
            }
            std::swap(xorRow, xorRow2);
        }
    
        //The bottom right cell of the matrix contains all the desired path values
        uint64_t result = 0; // the number of paths can exceed 32 bits
        for (auto v : xorRow[columns - 1]) {
            if (v == expected)
                result += 1;
        }
        std::cout << result;
    
        auto end = steady_clock::now();
        auto ms_int = duration_cast<milliseconds>(end - start);
        std::cerr << "\nTime spent: " << ms_int.count() << ' ' << " ms\n" << std::endl;
    
        return 0;
    }
    

    Useful properties of XOR (notation: ⨁, C++ (bitwise): ^) and the matrix

    Time to list some properties; we do not know yet whether they are going to be useful, but it is always good to list as many of them as possible:

    The simple XOR and bitwise XOR operators have the following properties:

    • Associativity: (A ⨁ B) ⨁ C = A ⨁ (B ⨁ C) = A ⨁ B ⨁ C
    • Commutativity: A ⨁ B = B ⨁ A
    • Neutral operand: A ⨁ 0 = A
    • Self-inverse: A ⨁ A = 0 (every value is its own inverse)
    • No overflow: in A ⨁ B = C, C simply needs to have the same number of bits as A and B.

    What makes the MITM algorithm possible is the associativity property, best expressed as:
    A ⨁ B ⨁ C ⨁ D = (A ⨁ B) ⨁ (C ⨁ D) = L ⨁ R.

    The above equality is the first of 2 things we have to find: it is what makes the MITM approach valid in the first place.
    While it is still unclear where we are going to save time (I will come back to it below), this should at least be established before attempting to code a meet-in-the-middle algorithm.
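    The associativity argument can be illustrated on a single path: wherever we cut it, the XOR of the part before the cut combined with the XOR of the part after it gives the XOR of the whole path. This is a minimal sketch; xorRange and splitMatches are hypothetical helper names, and the path is given directly as the list of its cell values.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// XOR of values[first, last): the XOR of the cells a (partial) path visits, in order.
std::uint64_t xorRange(const std::vector<std::uint64_t>& values,
                       std::size_t first, std::size_t last) {
    assert(first <= last && last <= values.size());
    std::uint64_t acc = 0;  // 0 is the neutral operand
    for (std::size_t i = first; i < last; ++i)
        acc ^= values[i];
    return acc;
}

// Associativity: cutting the path at any position gives L ^ R == whole path.
bool splitMatches(const std::vector<std::uint64_t>& values, std::size_t cut) {
    const std::uint64_t left = xorRange(values, 0, cut);
    const std::uint64_t right = xorRange(values, cut, values.size());
    return (left ^ right) == xorRange(values, 0, values.size());
}
```

    With the path values {2, 1, 10, 6, 4}, the whole path XORs to 11, and splitMatches returns true for every cut from 0 to 5.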


    I will loosely write A ⨁ B whether A and B are:

    • both single operands; then A ⨁ B is the result.
    • both lists of operands; then A ⨁ B is the list of all pairwise XOR combinations of elements from A and B.
      The size of the result is the product of the sizes of A and B.
    • one list and one single operand; then A ⨁ B is the list where every element of A is XOR'ed with B.
      The size of the result is the size of A.
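    The loose notation above can be written as two small helpers. This is only a sketch to make the result sizes concrete; xorWith and xorLists are hypothetical names, and the XorValueList alias mirrors the one used in the code of this answer.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using XorValueList = std::vector<std::uint64_t>;

// List ⨁ single operand: every element of a XOR'ed with b.
XorValueList xorWith(const XorValueList& a, std::uint64_t b) {
    XorValueList out;
    out.reserve(a.size());
    for (std::uint64_t v : a)
        out.push_back(v ^ b);
    assert(out.size() == a.size());  // the size of the result is the size of A
    return out;
}

// List ⨁ list: all pairwise XOR combinations of elements from a and b.
XorValueList xorLists(const XorValueList& a, const XorValueList& b) {
    XorValueList out;
    out.reserve(a.size() * b.size());
    for (std::uint64_t x : a)
        for (std::uint64_t y : b)
            out.push_back(x ^ y);
    assert(out.size() == a.size() * b.size());  // product of the sizes
    return out;
}
```

    For example, xorLists({1, 2}, {4, 8}) yields {5, 9, 6, 10}, and xorWith({1, 2}, 3) yields {2, 1}.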

    MITM - Overview

    The principle behind the MITM algorithm is to turn long chains of XOR operations into expressions involving only 2 operands, L ⨁ R, where L and R are themselves the results of XOR operations.
    As is often the case, the best place to split the left and right operands is indeed the middle (L and R are calculated from roughly the same number of matrix values), hence the name of the algorithm.

    I see 3 ways to split a matrix into halves.
    In the grids below, the cells on the dividing line are marked with an x; they are the last cells belonging to the left operand L, i.e. they are excluded from the right operand R:

    1. Left and right:

      R\C 0 1 2 3 4 5
      0 L L x R R R
      1 L L x R R R
      2 L L x R R R
      3 L L x R R R
      4 L L x R R R
      5 L L x R R R
    2. Up and bottom:

      R\C 0 1 2 3 4 5
      0 L L L L L L
      1 L L L L L L
      2 x x x x x x
      3 R R R R R R
      4 R R R R R R
      5 R R R R R R
    3. Diagonal

      R\C 0 1 2 3 4 5
      0 L L L L L x
      1 L L L L x R
      2 L L L x R R
      3 L L x R R R
      4 L x R R R R
      5 x R R R R R

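    Let us count the number of half-paths on a 20x20 matrix. The left-and-right and the up-and-bottom splits are equivalent by symmetry; the table below uses the left-and-right split with the x cells in column 9 (counting from 0). A path leaves x by moving right into column 10, and that move is counted in neither half: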

    Right to x | Down to x | Paths to x | Right after x | Down after x | Paths after x
    9          | 0         | 1          | 9             | 19           | 6906900
    9          | 1         | 10         | 9             | 18           | 4686825
    9          | 2         | 55         | 9             | 17           | 3124550
    9          | 3         | 220        | 9             | 16           | 2042975
    9          | 4         | 715        | 9             | 15           | 1307504
    9          | 5         | 2002       | 9             | 14           | 817190
    9          | 6         | 5005       | 9             | 13           | 497420
    9          | 7         | 11440      | 9             | 12           | 293930
    9          | 8         | 24310      | 9             | 11           | 167960
    9          | 9         | 48620      | 9             | 10           | 92378
    9          | 10        | 92378      | 9             | 9            | 48620
    9          | 11        | 167960     | 9             | 8            | 24310
    9          | 12        | 293930     | 9             | 7            | 11440
    9          | 13        | 497420     | 9             | 6            | 5005
    9          | 14        | 817190     | 9             | 5            | 2002
    9          | 15        | 1307504    | 9             | 4            | 715
    9          | 16        | 2042975    | 9             | 3            | 220
    9          | 17        | 3124550    | 9             | 2            | 55
    9          | 18        | 4686825    | 9             | 1            | 10
    9          | 19        | 6906900    | 9             | 0            | 1

    Combining half-paths means: for each cell x, each of the half-paths leading to it can be concatenated to each of the half-paths leaving it.

    In terms of calculations, the total number of paths is the sum-product of the numbers of half-paths (for each x cell, paths leading to it times paths leaving it). If we compute it, we get 35,345,263,800: the figure we already calculated, which confirms that the counts are consistent.

    Now, the figure we wanted to minimize was the number of half-paths (i.e. the sum), and we get 40,060,020.
    It is a big improvement, but at 8 bytes per 64-bit integer, we would need about 320MB to store all the XOR'ed values, more than the 256MB allowed.

    We could argue that the left-and-right/up-and-bottom splits are not truly done in the middle. Take the top cell in the middle column, for instance: the distance from the top-left cell to it is shorter than the distance from it to the bottom-right cell.
    Let us see what the diagonal split does for the same 20x20 matrix:

    Right to x | Down to x | Paths to x | Right after x | Down after x | Paths after x
    19         | 0         | 1          | 0             | 19           | 1
    18         | 1         | 19         | 1             | 18           | 19
    17         | 2         | 171        | 2             | 17           | 171
    16         | 3         | 969        | 3             | 16           | 969
    15         | 4         | 3876       | 4             | 15           | 3876
    14         | 5         | 11628      | 5             | 14           | 11628
    13         | 6         | 27132      | 6             | 13           | 27132
    12         | 7         | 50388      | 7             | 12           | 50388
    11         | 8         | 75582      | 8             | 11           | 75582
    10         | 9         | 92378      | 9             | 10           | 92378
    9          | 10        | 92378      | 10            | 9            | 92378
    8          | 11        | 75582      | 11            | 8            | 75582
    7          | 12        | 50388      | 12            | 7            | 50388
    6          | 13        | 27132      | 13            | 6            | 27132
    5          | 14        | 11628      | 14            | 5            | 11628
    4          | 15        | 3876       | 15            | 4            | 3876
    3          | 16        | 969        | 16            | 3            | 969
    2          | 17        | 171        | 17            | 2            | 171
    1          | 18        | 19         | 18            | 1            | 19
    0          | 19        | 1          | 19            | 0            | 1

    Again, the sum-product is equal to 35,345,263,800. What is interesting for us here is that the total count of half-paths is only 2^20 = 1,048,576, i.e. 8MB worth of 64-bit values.
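    Both tables can be reproduced with a short program. This sketch assumes an n x n grid; for the column split it places the x cells in column n / 2 - 1 (column 9 when n = 20, as in the first table) and counts the move from x into the next column in neither half. binomial, SplitStats, columnSplit and diagonalSplit are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Exact binomial coefficient C(n, k); each intermediate value is itself a binomial.
std::uint64_t binomial(unsigned n, unsigned k) {
    assert(k <= n);
    if (k > n - k) k = n - k;
    std::uint64_t result = 1;
    for (unsigned i = 1; i <= k; ++i)
        result = result * (n - k + i) / i;
    return result;
}

struct SplitStats {
    std::uint64_t halfPaths;  // sum of the "to x" and "after x" path counts
    std::uint64_t fullPaths;  // sum-product: every pair forms one complete path
};

// Left-and-right split of an n x n grid, x cells in column n / 2 - 1.
SplitStats columnSplit(unsigned n) {
    assert(n >= 2);
    const unsigned xColumn = n / 2 - 1;
    SplitStats s{0, 0};
    for (unsigned r = 0; r < n; ++r) {  // x cell (r, xColumn)
        const std::uint64_t leading = binomial(xColumn + r, r);
        // The move from x into column xColumn + 1 is counted in neither half.
        const std::uint64_t leaving = binomial((n - 2 - xColumn) + (n - 1 - r), n - 1 - r);
        s.halfPaths += leading + leaving;
        s.fullPaths += leading * leaving;
    }
    return s;
}

// Diagonal split of an n x n grid, x cells on r + c == n - 1.
SplitStats diagonalSplit(unsigned n) {
    assert(n >= 1);
    SplitStats s{0, 0};
    for (unsigned r = 0; r < n; ++r) {  // x cell (r, n - 1 - r)
        const std::uint64_t leading = binomial(n - 1, r);  // r down, n - 1 - r right
        const std::uint64_t leaving = binomial(n - 1, r);  // r right, n - 1 - r down
        s.halfPaths += leading + leaving;
        s.fullPaths += leading * leaving;
    }
    return s;
}
```

    For n = 20, columnSplit reports 40,060,020 half-paths and diagonalSplit reports 1,048,576, both with 35,345,263,800 complete paths.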

    For completeness, let us see what a diagonal split on a non-square matrix should look like:

    R\C 0 1 2 3 4 5
    0 L L L L x R
    1 L L L x R R
    2 L L x R R R
    3 L x R R R R
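    The split rule used by the implementation in this answer can be checked by classifying every cell. This sketch assumes the same formula as that code, splitPosition = (rows + columns + 1) / 2, with the x cells being those where r + c == splitPosition - 1; diagonalSplitMap is a hypothetical name.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Cell map for the diagonal split: 'L' for the left operand, 'x' for the border
// cells (which belong to L), 'R' for the right operand.
std::vector<std::string> diagonalSplitMap(unsigned rows, unsigned columns) {
    assert(rows >= 1 && columns >= 1);
    const unsigned splitPosition = (rows + columns + 1) / 2;
    std::vector<std::string> cells(rows, std::string(columns, 'R'));
    for (unsigned r = 0; r < rows; ++r)
        for (unsigned c = 0; c < columns; ++c) {
            if (r + c + 1 < splitPosition)
                cells[r][c] = 'L';
            else if (r + c + 1 == splitPosition)
                cells[r][c] = 'x';
        }
    return cells;
}
```

    For a 4x6 matrix it returns the rows LLLLxR, LLLxRR, LLxRRR and LxRRRR, i.e. the grid above; for a 6x6 matrix it reproduces the diagonal split shown earlier.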

    The magic finally happens

    So far, we have split paths into 2 halves and seen how the XOR value of each complete path can be reconstructed from both halves; however, it still feels like we need to compare the values pairwise.
    As previously said, recombining the half-paths without a strategy leads us to calculate 35,345,263,800 XOR'ed values. In other words, no time saved...

    This strategy is the second key you need to find when implementing MITM: unless you can find a trick that avoids recombining every element of the first half with every element of the second half, no time will be saved.

    In this case, the trick is to use the properties listed above to find the equivalence:
    A ⨁ B = C <=> A ⨁ C = B (I leave the proof as an exercise).
    Applied to our 2 lists of operands, L and R, for each x cell separately and with k being the value we seek:

    • We want the number of times k appears in L ⨁ R (loosely written L ⨁ R = k).
    • This is equivalent to computing L ⨁ k and counting, with multiplicity, the elements it has in common with R (loosely written L ⨁ k = R).

    As it happens, finding the elements that appear in 2 lists is a very simple task: sort both lists, then walk through them together as in the merge step of a merge sort.
    The diagonally split matrix greatly limits the length of L and R, and the merge visits each element of L ⨁ k and of R only once instead of comparing every pair (note that if a values of L ⨁ k are equal to b values of R, the count must increase by a * b).
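    The merge-based count can be sketched as a standalone function, together with a naive pairwise count to cross-check it on small inputs. countMatches and countMatchesNaive are hypothetical names; left and right stand for the L and R lists of a single x cell.

```cpp
#include <algorithm>
#include <cassert>  // used by the example checks
#include <cstdint>
#include <vector>

// Number of pairs (i, j) with (left[i] ^ k) == right[j], i.e. how many times k
// appears in left ⨁ right, computed by merging the sorted lists (left ⨁ k) and right.
std::uint64_t countMatches(std::vector<std::uint64_t> left,
                           std::vector<std::uint64_t> right, std::uint64_t k) {
    for (auto& v : left) v ^= k;  // left ⨁ k
    std::sort(left.begin(), left.end());
    std::sort(right.begin(), right.end());
    std::uint64_t count = 0;
    auto l = left.cbegin(), r = right.cbegin();
    while (l != left.cend() && r != right.cend()) {
        if (*l < *r) {
            ++l;
        } else if (*r < *l) {
            ++r;
        } else {
            // a equal values on the left and b on the right form a * b pairs
            const std::uint64_t value = *l;
            std::uint64_t a = 0, b = 0;
            for (; l != left.cend() && *l == value; ++l) ++a;
            for (; r != right.cend() && *r == value; ++r) ++b;
            count += a * b;
        }
    }
    return count;
}

// Reference O(|left| * |right|) count, only meant for cross-checking.
std::uint64_t countMatchesNaive(const std::vector<std::uint64_t>& left,
                                const std::vector<std::uint64_t>& right, std::uint64_t k) {
    std::uint64_t count = 0;
    for (auto x : left)
        for (auto y : right)
            if ((x ^ y) == k) ++count;
    return count;
}
```

    With left = {1, 1, 2}, right = {3, 3, 0} and k = 2, both functions return 5: the two 1s match the two 3s (2 * 2 pairs) and the 2 matches the 0.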

    Implementation

    It is finally time for some code.

    I would like to reiterate that you really should give it a try yourself before continuing to read. The best way for you to improve is by trying to do it yourself after all.

    Before diving into the solution, a few final considerations:

    • As per the problem statement, the result k (and the values inside the matrix) will be between 0 and 10^18.
      This translates into an unsigned long long/uint64_t variable (unsigned long may work with some compilers, but the standard only guarantees that unsigned long is at least 32 bits).

    • The matrix size is dynamic, with 1 to 20 rows and 1 to 20 columns. The Matrix class, already presented in the alternative brute-force algorithm, takes care of it.

    • Unlike real production code, I will assume the input is always as described in the problem statement. In theory, we should check that the data contains only non-negative integers within the expected bounds, that each row contains the expected number of integers, that we receive the expected number of rows, etc.

    • I will aim for minimal effort, i.e. changing as few things as possible from my brute-force algorithm above.
      Any performance improvement will be entirely due to the MITM algorithm, but it may not pass the Codeforces tests without a little additional tuning. Some candidates for improvement:

      • An easy way to speed things up is to call std::vector<T>::reserve before writing the XOR combinations, so that no reallocation takes place during the push_back calls.
      • An easy way to further save memory is to process one x cell at a time (the code below computes all the half-paths before it tries to combine them) and, of course, to free the memory before moving on to the next x cell.
      • There are probably additional improvements; e.g. I rotate the matrix 180° so that I can run the same loop twice (start at the top left, make the bottom right become the top left, start at the top left again), but you would get better performance without that cheap trick.
    #include <algorithm>
    #include <chrono>
    #include <cstdint>
    #include <iostream>
    #include <vector>
    
    //The matrix is not the core of the problem, so this simple class will do.
    class Matrix {
    public:
       Matrix(unsigned int rows, unsigned int columns) : rowCount(rows), colCount(columns), v(rows* columns) {}
       uint64_t* operator[](unsigned int r) { return &(v[r * colCount]); }
       //flip causes the matrix to do a 180° rotation
       void flip() { std::reverse(v.begin(), v.end()); }
    
    private:
       unsigned int rowCount, colCount;
       std::vector<uint64_t> v;
    };
    
    //Read the matrix from std::cin
    void read(Matrix& inputMatrix, unsigned int rows, unsigned int columns)
    {
       for (unsigned int r = 0; r < rows; ++r) {
           for (unsigned int c = 0; c < columns; ++c) {
               uint64_t value = 0; std::cin >> value;
               inputMatrix[r][c] = value;
           }
       }
    }
    
    int main(int argc, char* argv[])
    {
       using std::chrono::steady_clock;
       using std::chrono::duration_cast;
       using std::chrono::duration;
       using std::chrono::milliseconds;
    
       using XorValueList = std::vector<uint64_t>;
    
       unsigned int rows, columns;
       uint64_t expected;
       std::cin >> rows >> columns >> expected;
    
       // Construct a matrix of the right size
       Matrix inputMatrix(rows, columns);
       // Fill the matrix with data from std::cin
       read(inputMatrix, rows, columns);
    
    
       // The split position determines where the left and right operands meet:
       // the x cells are those with r + c == splitPosition - 1.
       // If the matrix is not square, part of that diagonal may lie outside the matrix.
       auto splitPosition = (rows + columns + 1) / 2;
    
       auto start = steady_clock::now();
    
       //This lambda is in charge of filling the 2 MITM operands
       //It is built for the left operand but as the matrix can be rotated 180°, works for the right operand too.
       auto fillXorOperand = [rows, columns, &splitPosition, &inputMatrix](std::vector<XorValueList>& xorOperand)
       {
           // xorRow represents the current row being processed. Since we can never go up, we do not need to remember what is above it.
           // It stores the list of XOR results for all the paths leading to the corresponding cell.
           std::vector<XorValueList> xorRow;
    
           // On the first row, only 1 path (horizontal) can reach each of the cells and there are no cells above.
           // The top left cell is a special case.
           xorRow.push_back({ inputMatrix[0][0] });
           // First row
           for (unsigned int c = 1; c < columns && c < splitPosition; ++c) {
               auto v = xorRow[c - 1][0];
               xorRow.push_back({ v ^ inputMatrix[0][c] });
           }
           if (columns >= splitPosition)
               xorOperand.push_back(xorRow[std::min(columns, splitPosition) - 1]);
    
           //Next rows
           for (unsigned int r = 1; r < rows && r < splitPosition; ++r) {
               // xorRow2 is the next row after xorRow.
               // It will be filled, then swapped with xorRow for the next loop / for after the loop
               std::vector<XorValueList> xorRow2;
               { // The first column is always a special case
                   auto v = xorRow[0][0];
                   xorRow2.push_back({ v ^ inputMatrix[r][0] });
               }
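               //The other cells consist in taking the matrix value at that cell and combining it using xor with what we calculated on its left and above it.
               //Stop at the split diagonal (r + c < splitPosition): computing cells beyond it would turn this back into brute force.
               for (unsigned int c = 1; c < columns && c < splitPosition - r; ++c) {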
                   xorRow2.push_back({ });
                   for (auto v : xorRow2[c - 1])
                       xorRow2[c].push_back(v ^ inputMatrix[r][c]);
                   for (auto v : xorRow[c])
                       xorRow2[c].push_back(v ^ inputMatrix[r][c]);
               }
               std::swap(xorRow, xorRow2);
               if (columns >= splitPosition - r)
                   xorOperand.push_back(xorRow[splitPosition - r - 1]);
           }
       };
    
       // XOR values for both MITM operands, representing where both areas meet in the matrix.
       // Both variables are filled by the fillXorOperand lambda, with one XorValueList per border (x) cell.
       std::vector<XorValueList> xorLeft, xorRight;
    
       // We fill the left operand. The first item in the vector will be the upmost, rightmost border cell of the matrix
       fillXorOperand(xorLeft);
    
       // We rotate the matrix 180° so that the bottom right cell comes to the top left
       // This is allowed because XOR is commutative (the operand order can be inverted).
       inputMatrix.flip();
       // Track where the diagonal moved during the flip
       splitPosition = rows + columns - splitPosition;
    
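       // Zero the x cells of the flipped matrix: their values are already included in the left operand and must be XOR'ed only once.
       for (unsigned int r = (splitPosition < columns ? 0 : splitPosition - columns); r < rows && r < splitPosition; ++r)
           inputMatrix[r][splitPosition - r - 1] = 0;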
    
       // We fill the right operand. The first item in the vector will be the upmost, rightmost border cell of the flipped matrix, i.e. the downmost, leftmost border cell of the original matrix.
       // We need to remember that when we match the cells of both operands.
       fillXorOperand(xorRight);
    
       // The left operand is combined with the expected result, then all the vectors are sorted.
       for (auto& l : xorLeft) {
           for (auto& v : l)
               v = v ^ expected;
           std::sort(l.begin(), l.end());
       }
       // All the vectors in the right operand are sorted
       for (auto& r : xorRight) {
           std::sort(r.begin(), r.end());
       }
    
       auto lIter = xorLeft.cbegin();
       auto rIter = xorRight.crbegin();
    
       uint64_t result = 0;
       // We now need to "merge count" the cells shared on both operands, for the corresponding cells.
       // Here, merge count means count the identical cells in a loop similar to a merge-sort merging.
       for (; lIter != xorLeft.cend() && rIter != xorRight.crend(); ++lIter, ++rIter) {
           auto l = lIter->cbegin(), r = rIter->cbegin();
           while (l != lIter->cend() && r != rIter->cend()) {
               if (*l < *r)
                   ++l;
               else if (*r < *l)
                   ++r;
               else {
                   //If *l == *r, we need to count all consecutive equal values.
                   uint64_t value = *l;
                   uint64_t a = 0, b = 0;
                   for (; l != lIter->cend() && *l == value; ++l)
                       a += 1;
                   for (; r != rIter->cend() && *r == value; ++r)
                       b += 1;
                   result += a * b;
               }
           }
       }
       std::cout << result;
    
       auto end = steady_clock::now();
       auto ms_int = duration_cast<milliseconds>(end - start);
       std::cerr << "\nTime spent: " << ms_int.count() << ' ' << " ms\n" << std::endl;
    
       return 0;
    }
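    The memory-saving improvement mentioned earlier (handling one x cell at a time and freeing its lists before moving on) could look like the sketch below. It is not the code above: it swaps the row-by-row fill and the 180° rotation for a plain recursive enumeration of half-paths, and the merge for std::equal_range on the sorted R list. Grid, collect and countXorPaths are hypothetical names, and reading the input from std::cin is left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

using Grid = std::vector<std::vector<std::uint64_t>>;

// Appends to out the XOR of every monotone path from (r, c) to (tr, tc), both
// ends included. step == +1 walks right/down, step == -1 walks left/up.
void collect(const Grid& a, int r, int c, int tr, int tc, int step,
             std::uint64_t acc, std::vector<std::uint64_t>& out) {
    acc ^= a[r][c];
    if (r == tr && c == tc) {
        out.push_back(acc);
        return;
    }
    if (r != tr) collect(a, r + step, c, tr, tc, step, acc, out);
    if (c != tc) collect(a, r, c + step, tr, tc, step, acc, out);
}

// Meet-in-the-middle on the diagonal r + c == d, one x cell at a time.
std::uint64_t countXorPaths(const Grid& a, std::uint64_t k) {
    assert(!a.empty() && !a[0].empty());
    const int rows = static_cast<int>(a.size());
    const int columns = static_cast<int>(a[0].size());
    const int d = (rows + columns + 1) / 2 - 1;  // same diagonal as splitPosition - 1
    std::uint64_t total = 0;
    for (int r = std::max(0, d - (columns - 1)); r <= std::min(d, rows - 1); ++r) {
        const int c = d - r;
        std::vector<std::uint64_t> left, right;
        collect(a, 0, 0, r, c, +1, 0, left);                    // L: includes x
        collect(a, rows - 1, columns - 1, r, c, -1, 0, right);  // also includes x...
        for (auto& v : right) v ^= a[r][c];                    // ...so remove it once
        std::sort(right.begin(), right.end());
        for (std::uint64_t v : left) {
            // A complete path through x has XOR k exactly when (v ^ k) is in R.
            auto range = std::equal_range(right.begin(), right.end(), v ^ k);
            total += static_cast<std::uint64_t>(range.second - range.first);
        }
        // left and right are destroyed here, before the next x cell is processed.
    }
    return total;
}
```

    On the 3x3 grid {{2, 1, 5}, {7, 10, 0}, {12, 6, 4}} with k = 11, countXorPaths returns 3; on a 20x20 grid of zeros with k = 0 it returns 35,345,263,800, the total number of paths.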