I am trying to implement my own str.join method in Python, e.g. ''.join(['aa','bbb','cccc']) returns 'aabbbcccc'. I know that string concatenation using the join method has linear complexity (in the number of characters of the result), and I want to know how to do it, as using the '+' operator in a for loop results in quadratic complexity, e.g.:
res = ''
for word in ['aa','bbb','cccc']:
    res = res + word
As strings are immutable, each iteration builds a new string and copies everything accumulated so far into it, resulting in quadratic running time. However, I want to know how to do it in linear time, or find out how ''.join works exactly.
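To make the quadratic cost concrete, here is a rough cost model rather than a measurement. It assumes every `res + word` copies the whole new string, and compares that with writing each character once into a pre-sized buffer. The function names `copies_with_concat` and `copies_with_preallocation` are hypothetical and only serve this illustration; an actual interpreter may avoid some of these copies in special cases.

```python
def copies_with_concat(words):
    """Count characters copied if each `res + word` builds a fresh string."""
    copied = 0
    length = 0
    for word in words:
        length += len(word)  # the new string holds everything so far
        copied += length     # assumption: all of it is copied into place
    return copied


def copies_with_preallocation(words):
    """Count characters copied if each word is written once into a buffer."""
    return sum(len(word) for word in words)


words = ["ab"] * 1000
print(copies_with_concat(words))         # 1001000
print(copies_with_preallocation(words))  # 2000
```

Under this model, the work for the loop grows with the square of the number of words, while the pre-allocated variant grows with the length of the result.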
I could not find a linear-time algorithm or the implementation of str.join(iterable) anywhere. Any help is much appreciated.
Joining str as actual str is a red herring and not what Python itself does: internally, Python operates on mutable buffers of raw character data, not on str objects. Emulating this with bytes also removes the need to know string internals. In essence, str.join takes the raw data of its arguments, pre-allocates its result, and then mutates it in place.
This directly corresponds to:
- converting str arguments to/from bytes
- summing the len of elements and separators
- allocating a mutable bytearray to construct the result
# helper to convert to/from joinable bytes
def str_join(sep: "str", elements: "list[str]") -> "str":
    joined_bytes = bytes_join(
        sep.encode(),
        [elem.encode() for elem in elements],
    )
    return joined_bytes.decode()

# actual joining at bytes level
def bytes_join(sep: "bytes", elements: "list[bytes]") -> "bytes":
    # create a mutable buffer that is long enough to hold the result
    total_length = sum(len(elem) for elem in elements)
    # one separator between each pair of elements (none for an empty list)
    total_length += max(len(elements) - 1, 0) * len(sep)
    result = bytearray(total_length)
    # copy all characters from the inputs to the result
    insert_idx = 0
    for elem in elements:
        result[insert_idx:insert_idx+len(elem)] = elem
        insert_idx += len(elem)
        # add a separator unless the result is already complete
        if insert_idx < total_length:
            result[insert_idx:insert_idx+len(sep)] = sep
            insert_idx += len(sep)
    return bytes(result)

print(str_join(" ", ["Hello", "World!"]))
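Notably, while the element iteration and the element copying are basically two nested loops, they iterate over separate things: the outer loop visits each element once, and the inner copy visits only that element's bytes. The algorithm therefore touches each character only three times (encoding, copying, decoding) and each byte only once during the copy, so the running time stays linear in the length of the result.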
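One way to gain some confidence in this approach is to compare it against the built-in on a few inputs, including edge cases such as an empty list, empty elements, and non-ASCII text. The sketch below restates the pre-allocating join in compact form so that it runs on its own; the name `prealloc_join` is hypothetical, and the final position check is only there to illustrate that every byte of the buffer is written exactly once.

```python
def prealloc_join(sep: str, elements: list) -> str:
    """Join `elements` with `sep` by filling a pre-sized bytearray."""
    raw_sep = sep.encode()
    raw_elems = [elem.encode() for elem in elements]
    # total size: all element bytes plus one separator between each pair
    total = sum(len(elem) for elem in raw_elems)
    total += max(len(raw_elems) - 1, 0) * len(raw_sep)
    buffer = bytearray(total)
    pos = 0
    for i, elem in enumerate(raw_elems):
        if i:  # a separator goes before every element except the first
            buffer[pos:pos + len(raw_sep)] = raw_sep
            pos += len(raw_sep)
        buffer[pos:pos + len(elem)] = elem
        pos += len(elem)
    # the write position ends exactly at the end of the buffer
    assert pos == total
    return buffer.decode()


cases = [
    (" ", ["Hello", "World!"]),
    ("", ["aa", "bbb", "cccc"]),
    (", ", []),
    ("-", ["", "x", ""]),
    ("→", ["é", "ü"]),
]
for sep, elements in cases:
    assert prealloc_join(sep, elements) == sep.join(elements)
```

Because the buffer size is computed up front and the write position only ever moves forward, no byte is copied twice inside the loop.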