Search code examples
Tags: python, algorithm, time-complexity, ranking, lexicographic-ordering

Can this algorithm be reversed in better than O(N^2) time?


Say I've got the following algorithm to biject bytes -> Natural:

def rank(s: bytes) -> int:
    """Return the natural number that the byte string *s* is mapped to.

    The result is the big-endian base-256 value of *s* plus the number of
    byte strings strictly shorter than *s* (1 + 256 + ... + 256**(len(s)-1)),
    so shorter strings always rank below longer ones.
    """
    base = 1 << 8
    digits_value = 0   # s read as a big-endian base-256 number
    shorter_count = 0  # how many strings are shorter than s
    place = 1          # base**i for the current position i
    for byte in s:
        digits_value = digits_value * base + byte
        shorter_count += place
        place *= base
    return digits_value + shorter_count

The decoder (at least to the best of my limited abilities) is as follows:

def unrank(value: int) -> bytes:
    """Return the byte string whose rank is *value* (inverse of ``rank``).

    First finds the output length by accumulating 256**length until the
    running total exceeds *value*, then writes the remainder out as a
    fixed-width big-endian base-256 number.
    """
    import itertools

    base = 2**8

    # Step 1: determine the length.
    # The loop runs once per output byte, and every pass does arithmetic on
    # integers about as long as the input -- this is the quadratic part.
    shorter_count = 0  # number of strings shorter than `length`
    for length in itertools.count():
        block = base**length  # number of strings of exactly `length` bytes
        if shorter_count + block > value:
            # value falls inside this length's block; keep the offset into it
            value -= shorter_count
            break
        shorter_count += block

    # Step 2: emit the offset as `length` big-endian base-256 digits.
    out = bytearray(length)
    for pos in range(length - 1, -1, -1):
        value, out[pos] = divmod(value, base)
    return bytes(out)

Letting N ≈ len(bytes) ≈ log(int), this decoder clearly has a worst-case runtime of O(N^2). Granted, it performs well (<2s runtime) for practical cases (≤32KiB of data), but I'm still curious if it's fundamentally possible to beat that into something that swells less as the inputs get bigger.


# Example / test cases:

# rank: the empty string maps to 0, the 1-byte strings to 1..256,
# the 2-byte strings to 257..65792, and the 3-byte strings start at 65793.
assert rank(b"") == 0
assert rank(b"\x00") == 1
assert rank(b"\x01") == 2
...
assert rank(b"\xFF") == 256
assert rank(b"\x00\x00") == 257
assert rank(b"\x00\x01") == 258
...
assert rank(b"\xFF\xFF") == 65792
assert rank(b"\x00\x00\x00") == 65793

# unrank: must invert rank on the same boundary values.
assert unrank(0) == b""
assert unrank(1) == b"\x00"
assert unrank(2) == b"\x01"
# ...
assert unrank(256) == b"\xFF"
assert unrank(257) == b"\x00\x00"
assert unrank(258) == b"\x00\x01"
# ...
assert unrank(65792) == b"\xFF\xFF"
assert unrank(65793) == b"\x00\x00\x00"

# A larger example that decodes to a 6-byte string.
assert unrank(2**48+1) == b"\xFE\xFE\xFE\xFE\xFF\x00"

Solution

  • It's clearer if you write your rank function like this:

    def rank(s: bytes) -> int:
        """Return the rank of *s*, reading it as base-256 digits 1..256.

        Each byte ``w`` contributes the digit ``w + 1``, so no digit is ever
        zero and the empty string is the only input that maps to 0.
        """
        base = 1 << 8
        total = 0
        for byte in s:
            total = total * base + (byte + 1)
        return total
    

    ... because it then becomes apparent that you can write unrank like this:

    def unrank(value: int) -> bytes:
        """Return the byte string whose rank is *value*.

        Peels off one digit in the range 1..256 per iteration, least
        significant first: subtracting 1 before the division shifts the
        digit into 0..255 so it can be stored as a byte. Values <= 0 give
        the empty string.
        """
        base = 1 << 8
        digits = bytearray()
        remaining = value
        while remaining > 0:
            remaining, low = divmod(remaining - 1, base)
            digits.append(low)
        # Digits were collected least-significant first; output is big-endian.
        return bytes(reversed(digits))
    

    (thanks for providing test cases)

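    The above version of unrank is still quadratic, because every subtraction and divmod operates on a long integer and the loop runs once per output byte. Here is a less readable version of the same algorithm that is actually O(n) in Python; it works directly on the little-endian bytes of the value instead of on the integer: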

    def unrank(value: int) -> bytes:
        """Return the byte string whose rank is *value*, without big-int math.

        The value is expanded once into its little-endian bytes. Then, for
        each position in turn, 1 is subtracted from the number formed by the
        bytes at that position and above, borrowing from higher bytes when
        needed. When a borrow runs past the top byte, the number is
        exhausted and the output ends just before that position.
        """
        width = value.bit_length() // 8 + 1
        digits = bytearray(value.to_bytes(width, 'little'))
        size = len(digits)
        for start in range(size):
            # Subtract 1 at `start`: zero bytes wrap to 255 and pass the
            # borrow upward until a non-zero byte absorbs it.
            pos = start
            while pos < size and digits[pos] == 0:
                digits[pos] = 255
                pos += 1
            if pos == size:
                # The borrow went off the end: drop this position and above.
                del digits[start:]
                break
            digits[pos] -= 1
        # Bytes were processed least-significant first; output is big-endian.
        digits.reverse()
        return bytes(digits)