
Find the number of triplets i,j,k in an array such that the xor of elements indexed i to j-1 is equal to the xor of elements indexed j to k


For a given sequence of positive integers A1, A2, …, AN, I need to find the number of triplets (i, j, k) with 1 ≤ i < j ≤ k ≤ N such that A[i] ^ A[i+1] ^ … ^ A[j-1] = A[j] ^ A[j+1] ^ … ^ A[k], where ^ denotes bitwise XOR. The problem is here: https://www.codechef.com/AUG19B/problems/KS1

My approach is to find every subarray whose XOR is 0: such a subarray A[i..k] gives k - i valid triplets, one for each j from i+1 to k. This works, but it runs in O(N^2) time and is too slow. This is the solution I have so far:

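static long countTriplets(int[] arr) {
    long ans = 0; // the count can exceed the int range for large N
    for (int i = 0; i < arr.length; i++) {
        int xor = arr[i];
        for (int j = i + 1; j < arr.length; j++) {
            xor ^= arr[j];          // xor == arr[i] ^ ... ^ arr[j]
            if (xor == 0) {
                ans += (j - i);     // every middle index in (i, j] works
            }
        }
    }
    return ans;
}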
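A quick Python 3 check of the reasoning behind `ans += (j - i)` (in the loop above, the inner `j` plays the role of k): when a subarray's total XOR is 0, every split point gives two halves with equal XOR, because x == y exactly when x ^ y == 0; when the total is non-zero, no split works. `split_points` is a hypothetical helper written only for this illustration, and it uses 0-based inclusive indices.

```python
from functools import reduce
from operator import xor


def split_points(arr, i, k):
    """Return every j with i < j <= k such that XOR(arr[i..j-1]) == XOR(arr[j..k]).

    Indices are 0-based and inclusive; both halves are non-empty for every j tried.
    """
    return [j for j in range(i + 1, k + 1)
            if reduce(xor, arr[i:j]) == reduce(xor, arr[j:k + 1])]


arr = [2, 3, 1]                   # 2 ^ 3 ^ 1 == 0
print(split_points(arr, 0, 2))    # [1, 2]: k - i == 2 choices of j

arr2 = [2, 3, 4]                  # 2 ^ 3 ^ 4 == 5, not 0
print(split_points(arr2, 0, 2))   # []
```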

Solution

  • Here's an O(n) solution based on CiaPan's comment under the question description:

    If the XOR of the items at indices I through J-1 equals the XOR of those from J to K, then the XOR from I to K equals zero. Conversely, for any subarray [I .. K] with zero XOR, every J from I+1 to K (inclusive) makes a triplet satisfying the requirements, so the subarray contributes K - I triplets. And the XOR from I to K equals (XOR from 0 to K) xor (XOR from 0 to I-1). So I suppose you might compute the XORs of all prefixes of the sequence and look for equal pairs among them.

    The function f computes the answer in a single pass over the prefix XORs; brute_force checks every triplet directly in O(n^3) and is used only for validation.

    Python 2.7 code:

    import random
    
    def brute_force(A):
      res = 0
    
      for i in xrange(len(A) - 1):
        left = A[i]
        for j in xrange(i + 1, len(A)):
          if j > i + 1:
            left ^= A[j - 1]
          right = A[j]
          for k in xrange(j, len(A)):
            if k > j:
              right ^= A[k]
            if left == right:
              res += 1
    
      return res
    
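    def f(A):
      # ps[i] = A[0] ^ A[1] ^ ... ^ A[i] (prefix XORs)
      ps = [A[0]] + [0] * (len(A) - 1)
      for i in xrange(1, len(A)):
        ps[i] = ps[i - 1] ^ A[i]
    
      res = 0
      # seen[v] = (last index with prefix XOR v, number of such indices so far,
      #            sum of (last - p - 1) over the earlier indices p with value v).
      # Index -1 stands for the empty prefix, whose XOR is 0.
      seen = {0: (-1, 1, 0)}
    
      for i in xrange(len(A)):
        if ps[i] in seen:
          prev_i, i_count, count = seen[ps[i]]
          # Each earlier index p pairs with i to give the zero-XOR subarray
          # A[p+1..i], worth (i - p - 1) triplets. Moving the right end from
          # prev_i to i adds (i - prev_i) for each of the i_count indices; the -1
          # is because the new pair (prev_i, i) contributes (i - prev_i) - 1.
          new_count = count + i_count * (i - prev_i) - 1
          res += new_count
          seen[ps[i]] = (i, i_count + 1, new_count)
        else:
          seen[ps[i]] = (i, 1, 0)
    
      return res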
    
    for i in xrange(100):
      A = [random.randint(1, 10) for x in xrange(200)]
      f_A, brute_force_A = f(A), brute_force(A)
      assert f_A == brute_force_A
    print "Done"
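CiaPan's observation can also be turned into a counter that avoids the tuple recurrence in f. With P[-1] = 0 and P[q] = A[0] ^ ... ^ A[q], every earlier prefix index p with P[p] == P[q] gives a zero-XOR subarray A[p+1..q] worth q - p - 1 triplets, so the total contribution of q is cnt * (q - 1) - idx_sum, where cnt and idx_sum are the number and the index sum of the earlier matching prefixes. The sketch below is written for Python 3 (not the 2.7 used above); the names count_triplets and brute are hypothetical, and it runs in O(n) expected time because of the hash maps.

```python
import random
from collections import defaultdict
from itertools import accumulate
from operator import xor


def count_triplets(a):
    """Count (i, j, k), i < j <= k, with a[i]^...^a[j-1] == a[j]^...^a[k]."""
    cnt = defaultdict(int)      # prefix-XOR value -> how many prefixes seen so far
    idx_sum = defaultdict(int)  # prefix-XOR value -> sum of those prefixes' end indices
    cnt[0], idx_sum[0] = 1, -1  # empty prefix, treated as ending at index -1
    res = 0
    for q, p_val in enumerate(accumulate(a, xor)):
        # Each earlier prefix p with the same value yields the zero-XOR subarray
        # a[p+1..q], which admits q - p - 1 choices of the middle index j.
        res += cnt[p_val] * (q - 1) - idx_sum[p_val]
        cnt[p_val] += 1
        idx_sum[p_val] += q
    return res


def brute(a):
    """O(n^3) reference that checks every triplet directly (0-based indices)."""
    n, res = len(a), 0
    for i in range(n):
        for j in range(i + 1, n):
            left = 0
            for t in range(i, j):
                left ^= a[t]
            right = 0
            for k in range(j, n):
                right ^= a[k]
                if left == right:
                    res += 1
    return res


# Randomized cross-check, mirroring the validation loop in the answer.
random.seed(0)
for _ in range(100):
    arr = [random.randint(1, 10) for _ in range(random.randint(0, 25))]
    assert count_triplets(arr) == brute(arr)

print(count_triplets([2, 3, 1]))  # 2
```

Keeping a count and an index sum per prefix value makes each step a constant-time update, which is the same information f packs into its (prev_i, i_count, count) tuple.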