
Checking for equality if either input can be `str` or `bytes`


I am trying to write a function that checks whether two values are equal when each of them may be either a str (with ASCII-only content) or bytes.

Right now I have:

import typing as typ


def is_equal_str_bytes(
    a: typ.Union[str, bytes],
    b: typ.Union[str, bytes],
) -> bool:
    if isinstance(a, str):
        a = a.encode()
    if isinstance(b, str):
        b = b.encode()
    return a == b

This works with any combination of str and bytes types, whereas the == operator (rightfully) returns False when the two types differ.

import itertools


ss = "ciao", b"ciao"
for a, b in itertools.product(ss, repeat=2):
    print(f"{a!r:<8} {b!r:<8} {is_equal_str_bytes(a, b)} {a == b}")
# 'ciao'   'ciao'   True True
# 'ciao'   b'ciao'  True False
# b'ciao'  'ciao'   True False
# b'ciao'  b'ciao'  True True

Is there a simpler / faster way?


Solution

  • Some benchmarks with random equal strings/bytes of a million characters (on TIO with Python 3.8 pre-release, but I got similar times with 3.10.2):

      186.88 us  s.encode()
      187.39 us  s.encode("utf-8")
      183.85 us  s.encode("ascii")
       94.62 us  b.decode()
       94.27 us  b.decode("utf-8")
      137.91 us  b.decode("ascii")
       79.93 us  s == s2
       82.69 us  b == b2
      182.72 us  s + "a"
      177.06 us  b + b"a"
        0.08 us  len(s)
        0.07 us  len(b)
        1.14 us  s[:1000].encode()
        0.97 us  b[:1000].decode()
        2.06 us  s[::1000].encode()
        1.45 us  b[::1000].decode()
        1.91 us  hash(s)
        1.56 us  hash(b)
      508.62 us  hash(s2)
      546.00 us  hash(b2)
        2.85 us  str(s)
     9142.59 us  str(b)
    13541.64 us  repr(s)
     9100.34 us  repr(b)
    

    Thoughts based on that:

    • I thought that, for simpler code, we could apply str or repr to both of them and then compare the resulting strings (for example after removing the b prefix), but the benchmark shows that this would be very slow.
    • Getting the lengths is very cheap, so I'd compare those first. Return False if different, otherwise continue.
    • If you've hashed them already or are going to afterwards anyway, then you could compare the hashes (and return False if different, otherwise continue). See ASCII str / bytes hash collision for why an equal ASCII string and ASCII bytes have the same hash. (But I'm not sure that's guaranteed by the language, so it might not be safe.) Note that hashing the first time is slow (see the times for hashing s2/b2), but subsequent lookups of the stored hash are fast (see the times for hashing s/b).
    • Decoding seems faster than encoding, so do that instead.
    • Only decode if the types differ (one is string and one is bytes), otherwise just use ==.
    • It's wasteful to decode a million bytes if the very first one is already a mismatch. So it might be worth decoding/comparing chunks of shorter length instead of the whole thing, or testing some short prefix or cross section before testing the whole thing.
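The hash point in the list above can be checked directly. The following is a minimal sketch, assuming CPython: the agreement of `hash()` for an ASCII-only str and the equal bytes object is an implementation detail, not something the language promises. The helper name `may_be_equal` is made up for illustration.

```python
import platform

s = "ciao"
b = s.encode("ascii")

# In CPython, an ASCII-only str and the equal bytes object hash alike within
# one process. Hash randomization changes the value between runs, not the
# agreement. This is an implementation detail, not a language guarantee.
same_hash = hash(s) == hash(b)
print(platform.python_implementation(), same_hash)


def may_be_equal(x, y):
    """Cheap pre-check (hypothetical helper): False means definitely unequal.

    True proves nothing, so a full comparison is still needed afterwards.
    Only meaningful where the hash agreement above actually holds.
    """
    return len(x) == len(y) and hash(x) == hash(y)
```

On an implementation where the two hashes do not agree, this pre-check would wrongly reject equal mixed-type inputs, so it seems safest to use it only when `same_hash` is true.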

    So here's a potentially faster version using the above optimizations (not tested/benchmarked, partly because it depends on your data):

    import typing as typ
    
    def is_equal_str_bytes(
        a: typ.Union[str, bytes],
        b: typ.Union[str, bytes],
    ) -> bool:
        if len(a) != len(b):
            return False
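        # Relies on equal ASCII str/bytes hashing alike (not guaranteed by the language);
        # only worthwhile if the hashes are already stored or needed anyway.
        if hash(a) != hash(b):
            return False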
        if type(a) is type(b):
            return a == b
        if isinstance(a, bytes):  # make a=str, b=bytes
            a, b = b, a
        if a[:1000] != b[:1000].decode():
            return False
        if a[::1000] != b[::1000].decode():
            return False
        return a == b.decode()
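The chunking idea from the thoughts above (decode and compare shorter pieces so that an early mismatch stops the work) could look roughly like the sketch below. It is not part of the original answer: the name `is_equal_chunked` and the default `chunk_size` are assumptions for illustration, and it relies on the question's premise that any str input is ASCII-only.

```python
import typing as typ


def is_equal_chunked(
    a: typ.Union[str, bytes],
    b: typ.Union[str, bytes],
    chunk_size: int = 4096,  # hypothetical default, tune for your data
) -> bool:
    if len(a) != len(b):
        return False
    if type(a) is type(b):
        return a == b
    if isinstance(a, bytes):  # make a=str, b=bytes
        a, b = b, a
    for start in range(0, len(b), chunk_size):
        stop = start + chunk_size
        try:
            decoded = b[start:stop].decode("ascii")
        except UnicodeDecodeError:
            return False  # a non-ASCII byte cannot match ASCII-only text
        if a[start:stop] != decoded:
            return False  # stop at the first mismatching chunk
    return True
```

For equal inputs this does at least as much work as a single `b.decode()`, so whether it pays off depends on how often and how early your data mismatches.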
    

    My benchmark code:

    import os
    from timeit import repeat
    
    n = 10**6
    b = bytes(x & 127 for x in os.urandom(n))
    s = b.decode()
    assert hash(s) == hash(b)
    
    setup = '''
    from __main__ import s, b
    s2 = b.decode()  # Always fresh so it doesn't have a hash stored already 
    b2 = s.encode()
    assert s2 is not s and b2 is not b
    '''
    
    exprs = [
        's.encode()',
        's.encode("utf-8")',
        's.encode("ascii")',
        'b.decode()',
        'b.decode("utf-8")',
        'b.decode("ascii")',
        's == s2',
        'b == b2',
        's + "a"',
        'b + b"a"',
        'len(s)',
        'len(b)',
        's[:1000].encode()',
        'b[:1000].decode()',
        's[::1000].encode()',
        'b[::1000].decode()',
        'hash(s)',
        'hash(b)',
        'hash(s2)',
        'hash(b2)',
        'str(s)',
        'str(b)',
        'repr(s)',
        'repr(b)',
    ]
    
    for _ in range(3):
        for e in exprs:
            number = 100 if exprs.index(e) < exprs.index('hash(s)') else 1
            t = min(repeat(e, setup, number=number)) / number
            print('%8.2f us ' % (t * 1e6), e)
        print()
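The benchmark above times individual operations. To compare whole functions on your own data, a sketch along the following lines might help. The function names, the input size, and the two test cases are assumptions chosen for illustration, and the timings will vary by machine and Python version.

```python
from timeit import repeat


def encode_both(a, b):
    # The approach from the question: normalize both to bytes, then compare.
    if isinstance(a, str):
        a = a.encode()
    if isinstance(b, str):
        b = b.encode()
    return a == b


def decode_if_mixed(a, b):
    # Length check first, and decode only when the types differ.
    if len(a) != len(b):
        return False
    if type(a) is type(b):
        return a == b
    if isinstance(a, bytes):  # make a=str, b=bytes
        a, b = b, a
    return a == b.decode()


n = 10**5  # smaller than the benchmark above to keep the run short
s = "a" * n
b = s.encode()
cases = {"equal": (s, b), "shorter": (s, b[:-1])}

for name, (x, y) in cases.items():
    # Both functions should agree for ASCII-only input.
    assert encode_both(x, y) == decode_if_mixed(x, y)
    for func in (encode_both, decode_if_mixed):
        t = min(repeat(lambda: func(x, y), number=100)) / 100
        print("%8.2f us  %-16s %s" % (t * 1e6, func.__name__, name))
```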