
Finding a number to xor with sequence elements to obtain given sum


I recently came across the following problem: we're given a sequence of n integers x_i (n < 10^5, x_i < 2^60) and an integer S (S < 2^60). Find the smallest integer a such that the following holds:

$$\sum_{i=1}^{n} (x_i \oplus a) = S,$$

where $\oplus$ denotes bitwise XOR (written ^ in the example below).

For example:

x = [1, 2, 5, 10, 50, 100]
S = 242

Possible solutions for a are 21, 23, 37, 39, but the smallest is 21.

(1^21) + (2^21) + (5^21) + (10^21) + (50^21) + (100^21)
= 20 + 23 + 16 + 31 + 39 + 113 
= 242
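A brute-force search is a simple way to confirm the example, and it is also useful for testing a faster solution on small inputs. This is a minimal sketch; the helper name brute_force_solutions and its limit parameter are hypothetical, and the approach is only practical when the range of candidate values for a is tiny.

```python
def brute_force_solutions(xs, S, limit):
    """Return every a in range(limit) with sum(x ^ a for x in xs) == S."""
    return [a for a in range(limit) if sum(x ^ a for x in xs) == S]


xs = [1, 2, 5, 10, 50, 100]
S = 242
solutions = brute_force_solutions(xs, S, 1 << 10)  # try every a < 1024
print(solutions)       # [21, 23, 37, 39]
print(min(solutions))  # 21
```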

Solution

  • One can build the result up bit by bit, starting from the lowest bit. Try 0 and 1 as the lowest bit of a, and keep each choice for which the lowest bit of the sum of the xors matches the corresponding bit of S. Then move on to the next bit, propagating any carry from the previous step.

    With this algorithm there may be 0, 1 or 2 valid choices for each bit of a, so in the worst case we have to explore several branches and pick the one that gives the smallest result. To avoid exponential behavior, we memoize the best result for each (bit, carry) pair. The carry out of a bit is (count + carry) >> 1 with count <= n, so by induction the carry never exceeds n - 1. That leaves at most kn states, each handled in O(1) time, where k is the number of bits considered and n is the length of the input list, giving a worst-case complexity of O(kn).

    Here's some Python code that implements this:

    max_shift = 80  # enough bits: n < 10**5 and x_i < 2**60 keep every sum below 2**77
    
    def xor_sum0(xs, S, shift, carry, cache, sums):
        """Smallest value of bits shift.. of a, given the carry into bit `shift`."""
        if shift >= max_shift:
            # All bits processed: the choice is valid only if no carry is left over.
            return 1e100 if carry else 0
        key = shift, carry
        if key in cache:
            return cache[key]
        best = 1e100
        for i in range(2):  # try 0, then 1, as bit `shift` of a
            # Ones in bit `shift` of all (x ^ a), plus the incoming carry.
            ss = sums[i][shift] + carry
            if (ss & 1) == ((S >> shift) & 1):
                best = min(best, i + 2 * xor_sum0(xs, S, shift + 1, ss >> 1, cache, sums))
        cache[key] = best
        return best
    
    def xor_sum(xs, S):
        # sums[i][sh] = number of x whose bit sh differs from i, i.e. the
        # ones in bit sh of the xors when bit sh of a equals i.
        sums = [
            [sum(((x >> sh) ^ i) & 1 for x in xs) for sh in range(max_shift)]
            for i in range(2)]
        return xor_sum0(xs, S, 0, 0, dict(), sums)
    

    If there is no solution, the code returns a huge floating-point number (>= 1e100) instead of an integer.

    And here's a test that picks random values in the ranges you gave, picks a random a, computes S, and then solves for a. Note that the code sometimes finds a smaller a than the one used to compute S, since the solution is not always unique.

    import random
    xs = [random.randrange(0, 1 << 60) for _ in range(random.randrange(10 ** 5))]
    a_original = random.randrange(1 << 60)
    S = sum(x ^ a_original for x in xs)
    print(S)
    print(xs)
    
    a = xor_sum(xs, S)
    assert a < 1e100
    print('a:', a)
    print('original a:', a_original)
    
    assert a <= a_original
    
    print('S', S)
    print('SUM', sum(x ^ a for x in xs))
    
    assert sum(x ^ a for x in xs) == S
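The same (bit, carry) dynamic program can also be run iteratively, which avoids deep recursion and the 1e100 sentinel. This is a minimal sketch assuming Python 3; the name smallest_xor_addend and the bits parameter are hypothetical, and it returns None when no a exists. It assumes every x_i and S fit in `bits` bits; 80 bits covers the stated limits.

```python
def smallest_xor_addend(xs, S, bits=80):
    """Return the smallest a with sum(x ^ a for x in xs) == S, or None.

    Hypothetical helper: an iterative version of the (bit, carry) DP.
    Assumes every x and S are below 2**bits; 80 bits is enough when
    x_i < 2**60 and n < 10**5, because every sum then stays below 2**77.
    """
    if S >> bits or any(x >> bits for x in xs):
        raise ValueError("inputs do not fit in the searched number of bits")
    n = len(xs)
    # ones[k] = how many x have bit k set; if bit k of a is 1, the
    # column of (x ^ a) holds n - ones[k] ones instead.
    ones = [sum((x >> k) & 1 for x in xs) for k in range(bits)]

    def column(k, bit, carry):
        """Ones in column k of the xors when a_k == bit, plus the carry in."""
        return (n - ones[k] if bit else ones[k]) + carry

    # Forward pass: which carries can arrive at each bit position.
    reach = [set() for _ in range(bits + 1)]
    reach[0].add(0)
    for k in range(bits):
        for c in reach[k]:
            for bit in (0, 1):
                ss = column(k, bit, c)
                if (ss & 1) == ((S >> k) & 1):
                    reach[k + 1].add(ss >> 1)

    # Backward pass: best[c] = smallest value of the bits of a from
    # position k upward, given carry c into bit k (None = impossible).
    # Past the top bit, only a zero carry is acceptable.
    best = {c: (0 if c == 0 else None) for c in reach[bits]}
    for k in range(bits - 1, -1, -1):
        new_best = {}
        for c in reach[k]:
            candidate = None
            for bit in (0, 1):
                ss = column(k, bit, c)
                if (ss & 1) != ((S >> k) & 1):
                    continue
                rest = best[ss >> 1]  # ss >> 1 was recorded in reach[k + 1]
                if rest is None:
                    continue
                value = bit + 2 * rest
                if candidate is None or value < candidate:
                    candidate = value
            new_best[c] = candidate
        best = new_best
    return best[0]


print(smallest_xor_addend([1, 2, 5, 10, 50, 100], 242))  # 21
print(smallest_xor_addend([0, 0], 3))  # None: 2 * a is never odd
```

The forward pass keeps only the carries that can actually occur at each bit, and the backward pass needs just one dictionary per level, so recursion depth is no longer a concern. The number of states is still bounded by kn, as in the recursive version.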