Tags: bit-manipulation, bit, bit-shift, xor, bitwise-xor

XOR of all elements of the form (start + 2 * i), where i ranges from 0 to n - 1


Background:

This question is on LeetCode, where it is categorized as an easy problem with an O(N) solution.

However, I am having a hard time understanding other users' O(1) solutions, which rely on bit manipulation and a few observations.

Problem statement:

You are given an integer n and an integer start.

Define an array nums where nums[i] = start + 2*i (0-indexed) and n == nums.length.

Return the bitwise XOR of all elements of nums.
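As a quick worked example (values chosen here only for illustration), take start = 3 and n = 4:

```latex
\begin{aligned}
\text{nums} &= [3,\ 5,\ 7,\ 9] \\
3 \oplus 5 \oplus 7 \oplus 9 &= 0011_2 \oplus 0101_2 \oplus 0111_2 \oplus 1001_2 \\
&= 1000_2 = 8
\end{aligned}
```

Every element is odd here, yet the lowest bit of the result is 0 because it is XOR-ed an even number of times.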

O(1) solution by a user:

class Solution {
public:
    int xorOperation(int n, int start) {
        int first = start & 1;
        start = start >> 1;
        if(start % 2 == 0){
            switch(n % 4){
                case 0: return 0;
                case 1: return ((start + n - 1) << 1) + first;
                case 2: return 2;
                case 3: return ((1 ^ (start + n - 1)) << 1) + first;
            } 
        } else {
            switch(n % 4){
                case 0: return (start ^ 1 ^ (start + n - 1)) << 1;
                case 1: return (start << 1) + first;
                case 2: return (start ^ (start + n - 1)) << 1;
                case 3: return ((start ^ 1) << 1) + first;
            } 
        }
        return 0; //unreachable
    }
};

The author said he used the observation that x ^ (x + 1) = 1 when x is even, but I cannot see how this information is useful.
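A minimal sketch of how that observation can be used, written in Java to match the answer below. It is not the posted code, and the names `PairXor` and `xorConsecutive` are hypothetical. The idea: the lowest bit of every nums[i] equals the lowest bit of start, and the remaining bits form the consecutive run (start >> 1) + i. A consecutive run that starts on an even number splits into pairs (x, x + 1), each of which XORs to 1, so only the parity of the pair count and a possible leftover element matter.

```java
public class PairXor {

    // XOR of the consecutive integers s, s+1, ..., s+n-1 (assumes s >= 0, n >= 0),
    // built on the observation x ^ (x + 1) == 1 for even x.
    static int xorConsecutive(int s, int n) {
        int result = 0;
        if (n > 0 && (s & 1) == 1) { // odd start: peel off the first element
            result = s;
            s++;
            n--;
        }
        // s is now even: (s, s+1), (s+2, s+3), ... each pair XORs to 1
        int pairs = n / 2;
        if ((pairs & 1) == 1) {      // an odd number of 1s XORs to 1
            result ^= 1;
        }
        if ((n & 1) == 1) {          // one unpaired element left at the end
            result ^= s + n - 1;
        }
        return result;
    }

    // Same signature order as the LeetCode problem: (n, start).
    static int xorOperation(int n, int start) {
        int lowBit = start & n & 1;  // lowest bit survives only for odd n
        return (xorConsecutive(start >> 1, n) << 1) | lowBit;
    }

    public static void main(String[] args) {
        System.out.println(xorOperation(5, 0)); // 0 ^ 2 ^ 4 ^ 6 ^ 8, prints 8
        System.out.println(xorOperation(4, 3)); // 3 ^ 5 ^ 7 ^ 9, prints 8
    }
}
```

The posted switch on n % 4 enumerates the same cases in unrolled form: whether the run starts on an odd number, whether the pair count is odd, and whether one element is left over.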

Any information about concepts I might not be aware of would be very useful.


Solution

  • Not sure why the code you posted is so complicated, but I think I have a solution that is easier to explain (and simpler to boot).


    First, let's make a few observations:

    1. You are always adding 2 * i, which is even, so every element has the same lowest bit as start. XOR-ing that bit with itself n times is the same as AND-ing it with n & 1 (if n is odd it stays as it is; if n is even it is always 0). So let's compute it separately, pass it straight through to the output, and deal only with the rest: with startHi = start >> 1, the remaining bits of nums[i] are just startHi + i.

      int firstBit = start & n & 1;
      int startHi = start >> 1;
      // the higher bits are now just nums[i] >> 1 == startHi + i
      // ... here be computation ...
      //
      return firstBit | ( resultHi << 1 );
      
    2. XOR is associative and commutative, and every value is its own inverse (x ^ x = 0), so

      A^B^C^D^E^F^G^H = (A^B^C^D)^(E^F^G^H)

      (E^F^G^H) = (A^B^C^D)^(A^B^C^D^E^F^G^H)

      so we can express xor(m...n) in terms of xor(0...m-1) and xor(0...n), or, in terms of start and n:

      (start+0) ^ (start+1) ^ ... ^ (start+n-1) = (0 ^ 1 ^ 2 ^ ... ^ (start-1)) ^ (0 ^ 1 ^ 2 ^ ... ^ (start+n-1))

      Let's therefore define a simpler one-parameter function xor0(k) = 0 ^ 1 ^ ... ^ (k-1), the XOR of the numbers 0...k-1 (with xor0(0) = 0). The result is then

          int firstBit = start & n & 1;
          int startHi = start >> 1;
          int resultHi = xor0(startHi)^xor0(startHi+n);
          return firstBit | ( resultHi << 1 );
      
    3. If we XOR the numbers 0...k-1, look at how the bits behave:

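      • the lowest bit goes 0101 0101 0101 0101 ...,

        so the running XOR (0, 0^1, 0^1^2, ...) goes 0110 0110 0110 0110. The running XOR up to and including j is xor0(j+1), so for xor0(k) the same pattern starts one step later (0011 0011 ...), which is just bit 1 of k: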

        int firstBit = (k&2)>>1;

      • for all the other bits, take bit 2 (value 4) as an example:

        it goes 0000 1111 0000 1111,

        so the XOR-ed value is 0000 1010 0000 1010;

        or bit 3 (value 8): 0000 0000 1111 1111 0000 0000 1111 1111,

        the XOR-ed value being 0000 0000 1010 1010 0000 0000 1010 1010.

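      Here too the pattern is easy to see: if the bit is 0 it stays 0; if the bit is 1, the output alternates between 1 for odd k and 0 for even k (each run of ones has even length, so completed runs cancel out). In other words, for odd k the higher bits of xor0(k) are those of k-1, and for even k they are all 0, so we can just write

        int higherBits = (k - 1) & ((k & 1) == 0 ? 0 : -2);

      where -2 (binary ...11110) keeps every bit except the lowest one.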
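The bit patterns listed in observation 3 can be reproduced with a small helper. This is only a sketch; `BitPatterns` and `pattern` are hypothetical names. Position j in a XOR-ed row shows the running XOR 0 ^ 1 ^ ... ^ j, which is xor0(j + 1) in the notation above.

```java
public class BitPatterns {

    // Returns bit `bit` of each j = 0..count-1 (running == false), or of the
    // running XOR 0 ^ 1 ^ ... ^ j (running == true), in groups of four.
    static String pattern(int bit, int count, boolean running) {
        StringBuilder sb = new StringBuilder();
        int acc = 0;
        for (int j = 0; j < count; j++) {
            acc ^= j;                      // acc == 0 ^ 1 ^ ... ^ j
            if (j > 0 && j % 4 == 0) {
                sb.append(' ');
            }
            int value = running ? acc : j;
            sb.append((value >> bit) & 1);
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        for (int bit = 0; bit <= 3; bit++) {
            System.out.println("bit " + bit + " values: " + pattern(bit, 32, false));
            System.out.println("bit " + bit + " xor-ed: " + pattern(bit, 32, true));
        }
    }
}
```

The rows it prints for bits 0, 2 and 3 should match the sequences quoted above; bit 1 (not discussed there) follows the same rule, going 0011 0011 with XOR-ed value 0010 0010.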

    And so we have the final version (in Java, with a self-check comparing it to the reference O(N) implementation):

    package xortest;
    
    public class XorTest {
    
        public static void main (String[] args) {
            for (int s = 0; s < 16; s++) {
                for (int n = 0; n < 16; n++) {
                    int r = xorReference(s, n);
                    int x = xorOperation(s, n);
                    System.out.println(String.format("s=%02d n=%02d ref=%08x val=%08x%s", s, n, r, x, x == r ? "" : " ERROR"));
                }
            }
        }
    
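        // xor0(k) = 0 ^ 1 ^ ... ^ (k - 1), with xor0(0) = 0
        static int xor0(int k) {
            // lowest bit follows 0,0,1,1,... which is bit 1 of k
            int firstBit = (k & 2) >> 1;
            // higher bits: those of k - 1 when k is odd, all 0 when k is even
            // (-2 is ...11110 in two's complement, so it clears the lowest bit)
            int higherBits = (k - 1) & ((k & 1) == 0 ? 0 : -2);
            return higherBits | firstBit;
        }
    
        // XOR of start, start + 2, ..., start + 2 * (n - 1)
        static int xorOperation(int start, int n) {
            // the lowest bit survives only if n is odd
            int firstBit = start & n & 1;
            // the remaining bits form the run startHi .. startHi + n - 1
            int startHi = start >> 1;
            int resultHi = xor0(startHi) ^ xor0(startHi + n);
            return firstBit | (resultHi << 1);
        }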
    
        static int xorReference(int start, int n) {
            int xor = 0;
            for (int i = 0; i < n; i++) {
                xor ^= start + 2 * i;
            }
            return xor;
        }
        
    }
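The xor0 helper can also be cross-checked against the commonly cited closed form for 0 ^ 1 ^ ... ^ m, which depends only on m % 4 and is closely related to the four-way switch in the question's code. A sketch under that assumption, with the hypothetical names `XorPrefixCheck` and `xorUpTo`; xor0 is copied unchanged from the answer.

```java
public class XorPrefixCheck {

    // The answer's xor0(k) = 0 ^ 1 ^ ... ^ (k - 1), copied unchanged.
    static int xor0(int k) {
        int firstBit = (k & 2) >> 1;
        int higherBits = (k - 1) & ((k & 1) == 0 ? 0 : -2);
        return higherBits | firstBit;
    }

    // Closed form for 0 ^ 1 ^ ... ^ m (assumes m >= 0), chosen by m % 4.
    static int xorUpTo(int m) {
        switch (m % 4) {
            case 0: return m;
            case 1: return 1;
            case 2: return m + 1;
            default: return 0;
        }
    }

    public static void main(String[] args) {
        for (int k = 1; k <= 1 << 16; k++) {
            if (xor0(k) != xorUpTo(k - 1)) {
                throw new AssertionError("mismatch at k=" + k);
            }
        }
        System.out.println("xor0(k) == xorUpTo(k - 1) for k = 1..65536");
    }
}
```

Both are O(1): the answer's version reads the result off bit by bit, while the m % 4 table states the same result as four cases.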