Tags: python, string, algorithm, string-concatenation

How is the str.join(iterable) method implemented in Python? / Linear-time string concatenation


I am trying to implement my own str.join method in Python; for example, ''.join(['aa','bbb','cccc']) returns 'aabbbcccc'. I know that string concatenation with the join method runs in linear time (in the number of characters of the result), and I want to know how to achieve that, since using the + operator in a for loop results in quadratic complexity, e.g.:

res = ''
for word in ['aa', 'bbb', 'cccc']:
    res = res + word

Because strings are immutable, each iteration creates a new string and copies everything accumulated so far into it, resulting in quadratic running time. I want to know how to do this in linear time, or how ''.join works exactly.
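To make the quadratic claim concrete, here is a rough cost model (the helper names naive_copy_count and join_copy_count are hypothetical, for illustration only). It assumes each + allocates a fresh string and copies both operands into it; this ignores CPython's in-place concatenation optimization, which PEP 8 advises not to rely on.

```python
from typing import Sequence


def naive_copy_count(words: Sequence[str]) -> int:
    """Characters copied by `res = res + word`, assuming every + builds a
    brand-new string holding the old prefix plus `word` (ignores CPython's
    occasional in-place resize optimization)."""
    copied = 0
    prefix_len = 0
    for word in words:
        prefix_len += len(word)  # length of the new string after this +
        copied += prefix_len     # all of it is written into fresh memory
    return copied


def join_copy_count(words: Sequence[str]) -> int:
    """Characters copied when the result is allocated once at its final
    size and each word is copied into it exactly once."""
    return sum(len(word) for word in words)


words = ["aa", "bbb", "cccc"]
print(naive_copy_count(words), join_copy_count(words))  # prints: 16 9
```

With n one-character words the naive count is n(n+1)/2 while the pre-allocated count is n, which is the quadratic-versus-linear gap in question.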

I could not find a linear-time algorithm or the implementation of str.join(iterable) anywhere. Any help is much appreciated.


Solution

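  • Trying to join str objects as actual str objects is a red herring: Python code has no mutable str type, so it cannot fill a string in place. CPython's str.join (written in C) works below the str abstraction instead: it first computes the total length of the result (and the widest character it will contain), allocates one new string of exactly that size, and then copies each element and separator into it. Mutating that buffer is safe because the new string is not visible to any other code until join returns. A pure-Python emulation can get the same effect by operating on mutable bytes instead of str, which also removes the need to know string internals: encode the arguments to bytes, pre-allocate a bytearray, fill it, and decode the result.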

    This directly corresponds to:

    1. a wrapper to encode/decode str arguments to/from bytes
    2. summing the lengths of the elements and separators
    3. allocating a mutable bytearray to construct the result
    4. copying each element/separator directly into the result
    # helper to convert str to/from joinable bytes (UTF-8)
    def str_join(sep: "str", elements: "list[str]") -> "str":
        joined_bytes = bytes_join(
            sep.encode(),
            [elem.encode() for elem in elements],
        )
        return joined_bytes.decode()
    
    # actual joining at bytes level
    def bytes_join(sep: "bytes", elements: "list[bytes]") -> "bytes":
        # create a mutable buffer that is exactly long enough for the result;
        # max() avoids a negative separator count for an empty list
        total_length = sum(len(elem) for elem in elements)
        total_length += max(len(elements) - 1, 0) * len(sep)
        result = bytearray(total_length)
        # copy all bytes from the inputs into the buffer, each exactly once
        insert_idx = 0
        for i, elem in enumerate(elements):
            if i > 0:
                # separators go between elements, never before the first
                result[insert_idx:insert_idx+len(sep)] = sep
                insert_idx += len(sep)
            result[insert_idx:insert_idx+len(elem)] = elem
            insert_idx += len(elem)
        return bytes(result)
    
    print(str_join(" ", ["Hello", "World!"]))  # prints: Hello World!
    

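    Notably, although iterating over the elements and copying each element's bytes amount to two nested loops, they iterate over separate things: the outer loop runs once per element, and the inner copy runs once per byte of that element. Across the whole call, every byte is handled a constant number of times (encoded, copied into the buffer, copied again by bytes(), decoded), so the running time is linear in the size of the input plus the result.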
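    The code above takes a list, but str.join accepts any iterable. Because the sizing pass and the copying pass each walk the elements, a one-shot iterator such as a generator has to be materialized first; CPython's str.join does this too (internally via PySequence_Fast), so passing it a generator saves no memory. The sketch below (the name join_iterable is hypothetical) adds that step plus an explicit str type check with an error message modeled on the built-in one. It keeps the strict UTF-8 round trip, so, unlike the real str.join, strings containing lone surrogates raise UnicodeEncodeError.

```python
from typing import Iterable


def join_iterable(sep: str, items: Iterable[str]) -> str:
    """Hypothetical pure-Python join accepting any iterable of str.

    Sizing and copying are two separate passes, so a one-shot iterator
    (e.g. a generator) is materialized into a list first.
    """
    pieces = list(items)
    for i, piece in enumerate(pieces):
        if not isinstance(piece, str):
            # message modeled on the TypeError raised by str.join
            raise TypeError(
                f"sequence item {i}: expected str instance, "
                f"{type(piece).__name__} found"
            )
    encoded_sep = sep.encode()  # strict UTF-8: lone surrogates raise here
    chunks = [piece.encode() for piece in pieces]
    # pass 1: exact size of the result; max() handles an empty input
    size = sum(len(chunk) for chunk in chunks)
    size += max(len(chunks) - 1, 0) * len(encoded_sep)
    buffer = bytearray(size)
    # pass 2: write every byte of every chunk and separator exactly once
    pos = 0
    for i, chunk in enumerate(chunks):
        if i > 0:
            buffer[pos:pos + len(encoded_sep)] = encoded_sep
            pos += len(encoded_sep)
        buffer[pos:pos + len(chunk)] = chunk
        pos += len(chunk)
    return buffer.decode()


print(join_iterable(", ", (word.upper() for word in ["aa", "bbb", "cccc"])))
# prints: AA, BBB, CCCC
```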