The problem is to decide if three nonempty sets of integers, represented as three arrays A, B and C (Python lists in the code below), have empty intersection.
I have written the following Python code implementing a decision algorithm for this problem:
    def disjoint(A, B, C):
        """Solves (?) three-way disjointness problem in n*log(n) time.
        """
        A = sorted(A[:])
        B = sorted(B[:])
        C = sorted(C[:])
        i = j = k = 0
        while i < len(A) and j < len(B) and k < len(C) and (A[i] != B[j] or B[j] != C[k]):
            if A[i] <= B[j] and A[i] <= C[k]:
                i += 1
            elif B[j] <= A[i] and B[j] <= C[k]:
                j += 1
            else:
                k += 1
        if i < len(A) and j < len(B) and k < len(C):
            return False
        else:
            return True
The algorithm seems to be correct (the intuitive idea seems sound, and it passes my concrete tests), but I can't prove it formally.
So the question is: how can the above algorithm be proved correct? What is a good, working loop invariant for the proof? (Or, if it is not correct, what is a counterexample? Note that Python is irrelevant here; I could have written it in pseudo-code. I'm interested in a quasi-formal argument for this algorithm's correctness.)
(I tried to produce a formal proof myself; I have tested this idea by a practical implementation and it seems to work.)
The algorithm is correct. A loop invariant that works is: any common element must be `>= max(A[i], B[j], C[k])`.
From that, it's easy to show that `i`, `j`, or `k` is only incremented when it points to a value that is too low, and that the invariant still holds after the increment, because the sort guarantees that there are no possible common elements between the old and new values.
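One way to gain confidence in the invariant before writing the proof down is to assert it on every iteration. The following is only a sketch under my own assumptions: `disjoint_checked` is a hypothetical name, the loop is the question's loop restructured so the equality test sits inside the body, and the `common` set is a brute-force reference computed purely for the assertion.

```python
def disjoint_checked(A, B, C):
    """The question's loop, with the invariant asserted on every iteration."""
    A, B, C = sorted(A), sorted(B), sorted(C)
    # Brute-force reference, used only to check the invariant.
    common = set(A) & set(B) & set(C)
    i = j = k = 0
    while i < len(A) and j < len(B) and k < len(C):
        # Invariant: every common element is >= max(A[i], B[j], C[k]).
        assert all(x >= max(A[i], B[j], C[k]) for x in common)
        if A[i] == B[j] == C[k]:
            return False  # all three pointers agree: a common element exists
        # Advance a pointer that holds a smallest value; it is below the max.
        if A[i] <= B[j] and A[i] <= C[k]:
            i += 1
        elif B[j] <= A[i] and B[j] <= C[k]:
            j += 1
        else:
            k += 1
    # One list is exhausted, so nothing >= the current max can be in all three.
    return True
```

If the invariant were wrong, the `assert` would fail on some input instead of the function quietly returning a wrong answer, which makes this version convenient for experiments.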
As often happens, the correctness proof allows you to simplify your code, because you don't have to find the smallest of the 3 elements, just one that is smaller than the max:
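    def disjoint(A, B, C):
        """Solves (?) three-way disjointness problem in n*log(n) time.
        """
        A = sorted(A[:])
        B = sorted(B[:])
        C = sorted(C[:])
        i = j = k = 0
        while i < len(A) and j < len(B) and k < len(C):
            # No common element can be smaller than the largest current value.
            m = max(A[i], B[j], C[k])
            if A[i] < m:
                i += 1
            elif B[j] < m:
                j += 1
            elif C[k] < m:
                k += 1
            else:
                # All three values equal m, so m is a common element.
                return False
        # One list is exhausted without finding a common element.
        return True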
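As a sanity check alongside the proof, either version can be compared with a brute-force set intersection on random inputs. This is a minimal sketch, not part of the original algorithm: `reference_disjoint` and `check_against_reference` are names I made up, and the value range and list lengths are arbitrary choices that make duplicates and common elements frequent.

```python
import random

def reference_disjoint(A, B, C):
    """Brute-force oracle: True when no value occurs in all three lists."""
    return not (set(A) & set(B) & set(C))

def check_against_reference(candidate, trials=1000, seed=0):
    """Compare candidate(A, B, C) with the oracle on small random inputs.

    Returns the first failing triple (A, B, C), or None if all trials agree.
    """
    rng = random.Random(seed)
    for _ in range(trials):
        # Three nonempty lists drawn from a small range of values.
        A, B, C = ([rng.randint(0, 9) for _ in range(rng.randint(1, 6))]
                   for _ in range(3))
        if candidate(A, B, C) != reference_disjoint(A, B, C):
            return (A, B, C)
    return None
```

Calling `check_against_reference(disjoint)` with the function defined above should return `None` if the two agree on every trial; a returned triple would be a concrete counterexample. Agreement on random inputs is evidence only, so the invariant argument remains the actual proof.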