
Python string comparison doesn't short circuit?


The usual saying is that strings must be compared in constant time when checking things like passwords or hashes, and thus that a == b should be avoided. However, I ran the following script, and the results don't support the hypothesis that a == b short-circuits on the first non-identical character.

from time import perf_counter_ns
import random

def timed_cmp(a, b):
    start = perf_counter_ns()
    a == b
    end = perf_counter_ns()
    return end - start

def n_timed_cmp(n, a, b):
    "average time for a==b done n times"
    ts = [timed_cmp(a, b) for _ in range(n)]
    return sum(ts) / len(ts)

def check_cmp_time():
    random.seed(123)
    # generate a random string of n characters
    n = 2 ** 8
    s = "".join([chr(random.randint(ord("a"), ord("z"))) for _ in range(n)])

    # generate a list of strings, which all differs from the original string
    # by one character, at a different position
    # only do that for the first 50 char, it's enough to get data
    diffs = [s[:i] + "A" + s[i+1:] for i in range(min(50, n))]

    timed = [(i, n_timed_cmp(10000, s, d)) for (i, d) in enumerate(diffs)]
    sorted_timed = sorted(timed, key=lambda t: t[1])

    # print the 10 fastest
    for x in sorted_timed[:10]:
        i, t = x
        print("{}\t{:3f}".format(i, t))

    print("---")
    i, t = timed[0]
    print("{}\t{:3f}".format(i, t))

    i, t = timed[1]
    print("{}\t{:3f}".format(i, t))

if __name__ == "__main__":
    check_cmp_time()

Here is the result of one run. Re-running the script gives slightly different numbers, but none of the runs show the expected pattern.

# ran with CPython 3.8.3

6   78.051700
1   78.203200
15  78.222700
14  78.384800
11  78.396300
12  78.441800
9   78.476900
13  78.519000
8   78.586200
3   78.631500
---
0   80.691100
1   78.203200

I would have expected the fastest comparison to be the one where the first differing character is at the beginning of the string, but that's not what I get. Any idea what's going on?


Solution

  • There's a difference; you just don't see it on such small strings. Here's a small patch to apply to your code so that it uses longer strings and does 11 checks, putting the A at evenly spaced positions in the original string, from the beginning to the end, like this:

    A_______________________________________________________________
    ______A_________________________________________________________
    ____________A___________________________________________________
    __________________A_____________________________________________
    ________________________A_______________________________________
    ______________________________A_________________________________
    ____________________________________A___________________________
    __________________________________________A_____________________
    ________________________________________________A_______________
    ______________________________________________________A_________
    ____________________________________________________________A___
    
    @@ -15,13 +15,13 @@ def n_timed_cmp(n, a, b):
     def check_cmp_time():
         random.seed(123)
         # generate a random string of n characters
    -    n = 2 ** 8
    +    n = 2 ** 16
         s = "".join([chr(random.randint(ord("a"), ord("z"))) for _ in range(n)])
    
         # generate a list of strings, which all differs from the original string
         # by one character, at a different position
         # only do that for the first 50 char, it's enough to get data
    -    diffs = [s[:i] + "A" + s[i+1:] for i in range(min(50, n))]
    +    diffs = [s[:i] + "A" + s[i+1:] for i in range(0, n, n // 10)]
    
         timed = [(i, n_timed_cmp(10000, s, d)) for (i, d) in enumerate(diffs)]
         sorted_timed = sorted(timed, key=lambda t: t[1])
    

    and you'll get:

    0   122.621000
    1   213.465700
    2   380.214100
    3   460.422000
    5   694.278700
    4   722.010000
    7   894.630300
    6   1020.722100
    9   1149.473000
    8   1341.754500
    ---
    0   122.621000
    1   213.465700
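The roughly linear growth in the table above is what an early-exit comparison predicts: the work is proportional to the position of the first mismatch. As a pure-Python toy model only (CPython performs the real comparison in C; the helper name `chars_examined` is made up for illustration), counting how many characters an early-exit comparison inspects looks like this:

```python
def chars_examined(a: str, b: str) -> int:
    """Toy model of an early-exit string comparison (not CPython's code).

    Returns how many characters are inspected before the result is known.
    """
    if len(a) != len(b):
        return 0  # different lengths: decided without inspecting characters
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i + 1  # stop at the first mismatch
    return len(a)  # equal strings: every character had to be inspected


if __name__ == "__main__":
    n = 2 ** 16
    s = "a" * n
    # same evenly spaced positions as the patched benchmark
    for pos in range(0, n, n // 10):
        d = s[:pos] + "A" + s[pos + 1:]
        print(pos, chars_examined(s, d))
```

The model only captures the shape of the curve; the absolute timings also include interpreter and timer overhead, which is why the effect disappears into noise for short strings.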
    

    Note that even with your example's 2**8 characters, the difference is already noticeable. Apply this patch:

    @@ -21,7 +21,7 @@ def check_cmp_time():
         # generate a list of strings, which all differs from the original string
         # by one character, at a different position
         # only do that for the first 50 char, it's enough to get data
    -    diffs = [s[:i] + "A" + s[i+1:] for i in range(min(50, n))]
    +    diffs = [s[:i] + "A" + s[i+1:] for i in [0, n - 1]]
     
         timed = [(i, n_timed_cmp(10000, s, d)) for (i, d) in enumerate(diffs)]
         sorted_timed = sorted(timed, key=lambda t: t[1])
    

    to only keep the two extreme cases (first letter change vs last letter change) and you'll get:

    $ python3 cmp.py
    0   124.131800
    1   135.566000
    

    Numbers may vary, but most of the time test 0 is a tad faster than test 1.
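Because single measurements this small are noisy, one option (a sketch, not part of the original answer; `compare_extremes` is a hypothetical helper) is to time only the two extreme cases with the standard `timeit` module, which runs the statement many times in a tight loop and temporarily disables garbage collection, and to keep the minimum of several repeats as the least disturbed estimate:

```python
import random
import timeit


def compare_extremes(n=2 ** 8, number=100_000, repeat=5, seed=123):
    """Return ns per `s == d` for a mismatch at the first vs. the last position."""
    random.seed(seed)
    s = "".join(chr(random.randint(ord("a"), ord("z"))) for _ in range(n))
    results = {}
    for pos in (0, n - 1):
        d = s[:pos] + "A" + s[pos + 1:]
        # each entry of `runs` is the total time (seconds) for `number` comparisons
        runs = timeit.repeat("s == d", globals={"s": s, "d": d},
                             repeat=repeat, number=number)
        # the minimum over repeats is the run least affected by other activity
        results[pos] = min(runs) / number * 1e9
    return results


if __name__ == "__main__":
    for pos, ns in compare_extremes().items():
        print(f"{pos}\t{ns:.2f} ns")
```

Amortizing the timer call over many iterations removes most of the perf_counter_ns overhead that dominates the per-call measurements in the question.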

    Pinpointing which character was modified is only possible where memcmp compares character by character rather than using integer (word-sized) comparisons. That typically happens on the last characters when they are misaligned, or on really short strings, such as an 8-character string, as I demonstrate here:

    from time import perf_counter_ns
    from statistics import median
    import random
    
    
    def check_cmp_time():
        random.seed(123)
        # generate a random string of n characters
        n = 8
        s = "".join([chr(random.randint(ord("a"), ord("z"))) for _ in range(n)])
    
        # generate a list of strings, which all differ from the original string
        # by one character, at a different position
        # one variant per position: n is small, so cover every character
        diffs = [s[:i] + "A" + s[i + 1 :] for i in range(n)]
    
        values = {x: [] for x in range(n)}
        for _ in range(10_000_000):
            for i, diff in enumerate(diffs):
                start = perf_counter_ns()
                s == diff
                values[i].append(perf_counter_ns() - start)
    
        timed = [[k, median(v)] for k, v in values.items()]
        sorted_timed = sorted(timed, key=lambda t: t[1])
    
        # print the 10 fastest (here that is all n == 8 positions)
        for x in sorted_timed[:10]:
            i, t = x
            print("{}\t{:3f}".format(i, t))
    
        print("---")
        i, t = timed[0]
        print("{}\t{:3f}".format(i, t))
    
        i, t = timed[1]
        print("{}\t{:3f}".format(i, t))
    
    
    if __name__ == "__main__":
        check_cmp_time()
    
    

    Which gives me:

    1   221.000000
    2   222.000000
    3   223.000000
    4   223.000000
    5   223.000000
    6   223.000000
    7   223.000000
    0   241.000000
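Why positions 1 to 7 can cost the same: as a toy model only, suppose the comparison routine checks 8 bytes at a time (optimized memcmp implementations may do this; whether yours does is an assumption). An 8-character ASCII str is stored as 8 bytes in CPython, so the whole string fits in one such chunk, and the cost depends on which chunk holds the first mismatch rather than on the exact byte. The hypothetical `word_steps` helper below counts chunk comparisons; it does not explain why position 0 measured slower.

```python
def word_steps(a: bytes, b: bytes, word: int = 8) -> int:
    """Toy model of a word-at-a-time comparison (not your libc's memcmp).

    Counts how many `word`-byte chunks are compared before a mismatch is found.
    """
    if len(a) != len(b):
        raise ValueError("this model only handles equal-length inputs")
    steps = 0
    for start in range(0, len(a), word):
        steps += 1
        # compare a whole chunk at once, as a single integer
        wa = int.from_bytes(a[start:start + word], "little")
        wb = int.from_bytes(b[start:start + word], "little")
        if wa != wb:
            break  # mismatch is somewhere in this chunk; its exact byte is never located
    return steps


if __name__ == "__main__":
    s = b"qwertyui"  # 8 bytes: fits in a single 8-byte chunk
    for pos in range(len(s)):
        d = s[:pos] + b"A" + s[pos + 1:]
        print(pos, word_steps(s, d))  # every position costs the same single step
```

With a longer input, the model's cost rises in steps of 8 bytes instead of one character at a time.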
    

    The differences are so small that Python and perf_counter_ns may no longer be the right tools here.
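Since == can exit early, the advice mentioned in the question still applies to secrets such as passwords or tokens. The standard library provides `hmac.compare_digest` (also available as `secrets.compare_digest`) for this purpose; its documentation notes that inputs of different lengths may still leak their lengths, but not their values. A minimal sketch, where `check_token` is a hypothetical helper name:

```python
import hmac


def check_token(supplied: str, expected: str) -> bool:
    """Hypothetical helper: compare secrets without an early exit on the first mismatch."""
    # compare_digest accepts two ASCII-only str or two bytes-like objects;
    # encoding to bytes avoids the ASCII-only restriction on str arguments
    return hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))


if __name__ == "__main__":
    print(check_token("s3cret", "s3cret"))
    print(check_token("s3cret", "s3creT"))
```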