
XOR Magic Rectangle


The Problem:

In a magic rectangle of dimensions m x n (m columns, n rows), every entry is the XOR of its zero-indexed row and column. Example (8 x 5):

0 1 2 3 4 5 6 7
1 0 3 2 5 4 7 6
2 3 0 1 6 7 4 5
3 2 1 0 7 6 5 4
4 5 6 7 0 1 2 3

I need to find the sum of all entries. However, brute force will not work, because the inputs go up to tens of millions.

My work so far:

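I found that for an m x n matrix where m is a power of 2 and n <= m, you can calculate the sum as sum_range(0, m-1) * n, where sum_range simply adds every integer from its first input to its second, inclusive: each row is then a permutation of 0 .. m-1. (By symmetry, if n is the power of 2 and m <= n, the sum is sum_range(0, n-1) * m.)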
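A minimal Java sketch of this observation, with the validity condition spelled out. The class and method names (PowerOfTwoSum, sumRange, powerOfTwoSum, bruteForce) are illustrative, and the brute-force method is only there to check small cases. The n <= m condition matters: in a 2 x 4 rectangle (2 columns, 4 rows), rows 2 and 3 hold 2 and 3, so the total is 12 rather than sum_range(0, 1) * 4 = 4.

```java
class PowerOfTwoSum {
    // Sum of the integers lo..hi inclusive (the question's sum_range).
    static long sumRange(long lo, long hi) {
        return (lo + hi) * (hi - lo + 1) / 2;
    }

    // Sum of an m x n XOR rectangle (m columns, n rows), assuming m is a
    // power of 2 and n <= m: every row is then a permutation of 0..m-1.
    static long powerOfTwoSum(long m, long n) {
        if (Long.bitCount(m) != 1 || n > m)
            throw new IllegalArgumentException("needs m = 2^k and n <= m");
        return sumRange(0, m - 1) * n;
    }

    // Brute force, only for checking small cases.
    static long bruteForce(long m, long n) {
        long total = 0;
        for (long row = 0; row < n; row++)
            for (long col = 0; col < m; col++)
                total += row ^ col;
        return total;
    }

    public static void main(String[] args) {
        System.out.println(powerOfTwoSum(8, 5)); // 140, the 8 x 5 example
        System.out.println(bruteForce(8, 5));    // 140
        System.out.println(bruteForce(2, 4));    // 12, while sumRange(0, 1) * 4 = 4
    }
}
```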

Things get interesting when neither m nor n is a power of 2.

You can split the m x n rectangle into sub-rectangles whose sides are powers of two, like this (15 x 15, with 15 = 8 + 4 + 2 + 1):

0, 1, 2, 3, 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13 | 14 | 
1, 0, 3, 2, 5, 4, 7, 6 | 9, 8, 11, 10 | 13, 12 | 15 | 
2, 3, 0, 1, 6, 7, 4, 5 | 10, 11, 8, 9 | 14, 15 | 12 | 
3, 2, 1, 0, 7, 6, 5, 4 | 11, 10, 9, 8 | 15, 14 | 13 | 
4, 5, 6, 7, 0, 1, 2, 3 | 12, 13, 14, 15 | 8, 9 | 10 | 
5, 4, 7, 6, 1, 0, 3, 2 | 13, 12, 15, 14 | 9, 8 | 11 | 
6, 7, 4, 5, 2, 3, 0, 1 | 14, 15, 12, 13 | 10, 11 | 8 | 
7, 6, 5, 4, 3, 2, 1, 0 | 15, 14, 13, 12 | 11, 10 | 9 | 
----------------------------------------------------
8, 9, 10, 11, 12, 13, 14, 15 | 0, 1, 2, 3 | 4, 5 | 6 | 
9, 8, 11, 10, 13, 12, 15, 14 | 1, 0, 3, 2 | 5, 4 | 7 | 
10, 11, 8, 9, 14, 15, 12, 13 | 2, 3, 0, 1 | 6, 7 | 4 | 
11, 10, 9, 8, 15, 14, 13, 12 | 3, 2, 1, 0 | 7, 6 | 5 | 
----------------------------------------------------
12, 13, 14, 15, 8, 9, 10, 11 | 4, 5, 6, 7 | 0, 1 | 2 | 
13, 12, 15, 14, 9, 8, 11, 10 | 5, 4, 7, 6 | 1, 0 | 3 | 
----------------------------------------------------
14, 15, 12, 13, 10, 11, 8, 9 | 6, 7, 4, 5 | 2, 3 | 0 | 
----------------------------------------------------

Then, using the formula described above, you can get the sum of each of the squares along the diagonal (the 8 x 8, 4 x 4, 2 x 2 and 1 x 1 blocks in the example above).
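For a square n x n rectangle such as the 15 x 15 example, the diagonal part can be sketched as below (the class and method names are illustrative). The set bits of n give the square sizes, and each square starts at a multiple of its own size, so the power-of-2 formula applies to it directly. This covers only the diagonal squares, not the other parts.

```java
class DiagonalBlocks {
    // Sum of the power-of-2 squares on the diagonal of an n x n XOR square,
    // one square per set bit of n, largest first (15 -> 8, 4, 2, 1).
    // Each square starts at a multiple of its own size s, so its entries are
    // (offset + i) ^ (offset + j) == i ^ j and every row holds 0..s-1.
    static long diagonalBlocksSum(long n) {
        long total = 0;
        for (long s = Long.highestOneBit(n); s > 0; s >>= 1) {
            if ((n & s) != 0) {
                total += s * (s - 1) / 2 * s;   // sum_range(0, s-1) * s
            }
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(diagonalBlocksSum(15)); // 224 + 24 + 2 + 0 = 250
    }
}
```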

That's where I am.

I don't understand how to get the sum of the other parts in terms of m and n without brute force. Remember, m and n are arbitrary integers, and they can be very large.

I see some patterns: the rectangle is symmetric (entry (r, c) equals entry (c, r)), and its first row and first column are the numbers 0 .. m-1 and 0 .. n-1 in order. However, I can't come up with logic to translate that into code.

Pointers in the right direction would be appreciated. I'm not looking for code, though, as this is a Codewars problem and that would be cheating.

Reference Problem: https://www.codewars.com/kata/59568be9cc15b57637000054/train/java


Solution

  • We can divide the matrix into four sub-matrices and apply sum_range(start, start + width - 1) * rows over and over. For simplicity I'll explain it using recursion; you can easily convert it into a non-recursive solution using a queue.

    Let's consider a small worst-case scenario, a 7x7 matrix, where m = 7 (columns) and n = 7 (rows). Its sum is 168.

    [0, 1, 2, 3, 4, 5, 6]
    [1, 0, 3, 2, 5, 4, 7]
    [2, 3, 0, 1, 6, 7, 4]
    [3, 2, 1, 0, 7, 6, 5]
    [4, 5, 6, 7, 0, 1, 2]
    [5, 4, 7, 6, 1, 0, 3]
    [6, 7, 4, 5, 2, 3, 0]
    

    We can divide this matrix into the following 4 matrices (top-left, top-right, bottom-left and bottom-right, shown side by side):

    [0, 1, 2, 3] [4, 5, 6] [4, 5, 6, 7] [0, 1, 2]
    [1, 0, 3, 2] [5, 4, 7] [5, 4, 7, 6] [1, 0, 3]
    [2, 3, 0, 1] [6, 7, 4] [6, 7, 4, 5] [2, 3, 0]
    [3, 2, 1, 0] [7, 6, 5] 
    

    Here the first matrix has sides equal to the largest power of 2 that fits in the width, which is 4. The rest are just the remaining 3 parts of the parent matrix.
    We can use the formula to get the sum of the first matrix: rowSum*rows -> (m*(m-1)/2)*n -> (4*3/2)*4 -> 24, with m and n now the width and height of this sub-matrix.
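    To check the split numerically, here is a small Java sketch that cuts a rectangle the same way and adds up each piece by brute force. A rectangle is written as {m0, n0, m, n}, meaning columns m0 .. m-1 and rows n0 .. n-1, the same coordinates the trace at the end of this answer uses; Long.highestOneBit gives the largest power of 2 not exceeding the width. The names SplitCheck, split, bruteSum and pieceSums are illustrative. The second call in main applies the same split to the second sub-matrix, which is the next step below.

```java
import java.util.Arrays;

class SplitCheck {
    // Brute-force sum of col ^ row over columns [m0, m) and rows [n0, n).
    static long bruteSum(long m0, long n0, long m, long n) {
        long total = 0;
        for (long row = n0; row < n; row++)
            for (long col = m0; col < m; col++)
                total += col ^ row;
        return total;
    }

    // The four pieces {m0, n0, m, n}: the power-of-2 block at the corner,
    // the part to its right, the part below it, and the bottom-right rest.
    static long[][] split(long m0, long n0, long m, long n) {
        long pow = Long.highestOneBit(m - m0);      // largest power of 2 <= width
        long firstEnd = Math.min(n0 + pow, n);      // first block may have fewer rows
        return new long[][] {
            {m0, n0, m0 + pow, firstEnd},
            {m0 + pow, n0, m, firstEnd},
            {m0, n0 + pow, m0 + pow, n},
            {m0 + pow, n0 + pow, m, n},
        };
    }

    static long[] pieceSums(long m0, long n0, long m, long n) {
        long[][] pieces = split(m0, n0, m, n);
        long[] sums = new long[pieces.length];
        for (int i = 0; i < pieces.length; i++) {
            long[] p = pieces[i];
            sums[i] = bruteSum(p[0], p[1], p[2], p[3]);
        }
        return sums;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(pieceSums(0, 0, 7, 7))); // [24, 66, 66, 12]
        System.out.println(Arrays.toString(pieceSums(4, 0, 7, 4))); // [18, 13, 26, 9]
    }
}
```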

    Next we consider the second sub-matrix and divide it into 4 parts using the same strategy:

    [4, 5] [6] [6, 7] [4]
    [5, 4] [7] [7, 6] [5]
    

    Here, the first sub-matrix can be calculated using the formula (rowSum - startSum)*rows -> ((m*(m-1)/2) - (start*(start-1)/2)) * n -> (6*5/2 - 4*3/2)*2 -> 18. Here, start = 4 xor 0 = 4 (the XOR of the block's first column and first row), and m = start + width = 6.
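    The same formula as a Java method, as a sketch under the assumption the splits above maintain: the block's width w is a power of 2, it has at most w rows, and both coordinates of its top-left corner (m0, n0) are multiples of w. Then every row holds start .. start + w - 1 with start = m0 ^ n0. Without the alignment this breaks: row 1 against columns 2 and 3 gives 3 and 2, not 3 and 4. AlignedBlock and blockSum are illustrative names.

```java
class AlignedBlock {
    // Sum of an aligned block: width w (a power of 2), `rows` rows (rows <= w),
    // top-left corner (m0, n0) with m0 and n0 both multiples of w.
    // Every row is a permutation of start .. start + w - 1, start = m0 ^ n0.
    static long blockSum(long m0, long n0, long w, long rows) {
        long start = m0 ^ n0;
        long end = start + w;                                   // exclusive
        long rowSum = end * (end - 1) / 2 - start * (start - 1) / 2;
        return rowSum * rows;
    }

    public static void main(String[] args) {
        System.out.println(blockSum(0, 0, 4, 4)); // 24: first block of the 7x7 example
        System.out.println(blockSum(4, 0, 2, 2)); // 18: start = 4 ^ 0 = 4
        System.out.println(blockSum(0, 4, 4, 3)); // 66: start = 0 ^ 4 = 4, only 3 rows
    }
}
```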
    From the above discussion we can create a recursive function like this:

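    public class XorRectangle {
        public static void main(String[] args) {
            long m = 7, n = 7;                 // m = columns, n = rows; swap first if n > m
            long result = calculate(0, 0, m, n);
            System.out.println("Total: " + result);
        }

        // Sum of col ^ row over columns [m0, m) and rows [n0, n).
        static long calculate(long m0, long n0, long m, long n) {
            if (m <= m0 || n <= n0) return 0;                 // empty sub matrix
            System.out.print("calculating: (" + m0 + "," + n0 + ") (" + m + "," + n + ") ");

            if (m - m0 == 1 || n - n0 == 1) {                 // single row or column:
                long sum = 0;                                 // XOR and add directly
                for (long row = n0; row < n; row++)
                    for (long col = m0; col < m; col++)
                        sum += col ^ row;
                System.out.println("sum=" + sum);
                return sum;
            }

            // largest power of 2 that fits in the width (avoids Math.log rounding)
            long pow = Long.highestOneBit(m - m0);
            long rows = Math.min(pow, n - n0);                // rows in the first sub matrix

            // first sub matrix: each row holds (m0 ^ n0) .. (m0 ^ n0) + pow - 1
            long sum = sum(m0 ^ n0, pow, rows);

            sum += calculate(m0 + pow, n0, m, n0 + rows);     // second: right of the first
            sum += calculate(m0, n0 + pow, m0 + pow, n);      // third: below the first
            sum += calculate(m0 + pow, n0 + pow, m, n);       // fourth: the remainder
            return sum;
        }

        // start + (start+1) + ... + (start+m-1), times n rows; long avoids int overflow
        static long sum(long start, long m, long n) {
            m = m + start;
            long sum = m * (m - 1) / 2;
            sum -= start * (start - 1) / 2;
            sum *= n;
            System.out.println("sum=" + sum);
            return sum;
        }
    }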
    

    The code above should produce output like this:

    calculating: (0,0) (7,7) sum=24
    calculating: (4,0) (7,4) sum=18
    calculating: (6,0) (7,2) sum=13
    calculating: (4,2) (6,4) sum=26
    calculating: (6,2) (7,4) sum=9
    calculating: (0,4) (4,7) sum=66
    calculating: (4,4) (7,7) sum=2
    calculating: (6,4) (7,6) sum=5
    calculating: (4,6) (6,7) sum=5
    calculating: (6,6) (7,7) sum=0
    Total: 168
    

    To keep the recursion small, make sure m >= n before the first call. If it is not, swap the values.
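    The swap above happens only once, at the top. Deeper down, a piece that is only 2 or 3 columns wide but very tall is still peeled off two rows per call, so the depth grows with the height. A sketch of one way around that, assuming m and n fit in an int: swap every piece so its width is at least its height (r ^ c is symmetric, so a rectangle and its transpose have the same sum), and use a queue instead of recursion, as suggested at the start of this answer. XorRectangleSum and xorSum are illustrative names, and the total is a BigInteger because rectangles with tens of millions of rows and columns exceed the range of long.

```java
import java.math.BigInteger;
import java.util.ArrayDeque;

class XorRectangleSum {
    // Sum of row ^ col over an m x n rectangle (m columns, n rows).
    // Assumption: m and n fit in an int, so each per-row sum fits in a long.
    static BigInteger xorSum(long m, long n) {
        BigInteger total = BigInteger.ZERO;
        ArrayDeque<long[]> queue = new ArrayDeque<>();
        queue.add(new long[] {0, 0, m, n});          // {m0, n0, m1, n1}: cols [m0, m1), rows [n0, n1)
        while (!queue.isEmpty()) {
            long[] p = queue.poll();
            long m0 = p[0], n0 = p[1], m1 = p[2], n1 = p[3];
            if (m1 <= m0 || n1 <= n0) continue;       // empty piece
            if (n1 - n0 > m1 - m0) {                  // transpose so width >= height;
                long t = m0; m0 = n0; n0 = t;         // r ^ c == c ^ r, so the sum is unchanged
                t = m1; m1 = n1; n1 = t;
            }
            long pow = Long.highestOneBit(m1 - m0);   // largest power of 2 <= width
            long rows = Math.min(pow, n1 - n0);
            long start = m0 ^ n0;                     // m0 and n0 are multiples of pow
            long end = start + pow;                   // each row holds start .. end-1
            long rowSum = end * (end - 1) / 2 - start * (start - 1) / 2;
            total = total.add(BigInteger.valueOf(rowSum).multiply(BigInteger.valueOf(rows)));
            queue.add(new long[] {m0 + pow, n0, m1, n0 + rows});  // right of the block
            queue.add(new long[] {m0, n0 + pow, m0 + pow, n1});   // below the block
            queue.add(new long[] {m0 + pow, n0 + pow, m1, n1});   // bottom-right rest
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(xorSum(7, 7));  // 168
        System.out.println(xorSum(8, 5));  // 140
    }
}
```

    With the per-piece swap, each step leaves at most one piece that still needs splitting, and its larger side is no bigger than the current power of 2, so by my reading the queue only sees a handful of pieces per bit of max(m, n).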
    Note: As per the asker's request, this is meant as a pointer to the approach, not a finished kata solution.