I am looking for a number below 100,000 that fulfills a certain condition f(x).
So far I have written the following code:
#!/usr/bin/env python
import itertools
import multiprocessing.pool

paramlist = itertools.product("0123456789", repeat=5)

def function(word):
    number = int(''.join(word))
    # some code

with multiprocessing.pool.ThreadPool(processes=8) as pool:
    pool.imap_unordered(function, paramlist)
    pool.close()
    pool.join()
Is there a way to improve the efficiency of the code?
The comment posted by jasonharper is on point. Unless the work done by function is sufficiently CPU-intensive, the time saved by running it in parallel will not make up for the additional overhead incurred by creating child processes.
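Before reaching for a pool at all, a plain serial scan makes a useful baseline. The following is only a minimal sketch: find_first and is_match are hypothetical names, and is_match stands in for whatever the real condition is. It also iterates over range(100_000) directly, which covers the same values as the five-digit itertools.product without the join and int round trip.

```python
def find_first(condition, limit=100_000):
    """Return the first x in [0, limit) for which condition(x) is true, else None."""
    # range() yields the integers directly, so there is no need to build
    # digit tuples with itertools.product and convert them back with int().
    return next((x for x in range(limit) if condition(x)), None)


# Hypothetical condition, used only for illustration.
def is_match(x):
    return x ** 2 == 9801198001


print(find_first(is_match))  # prints 99001
```

If this baseline already finishes in a few milliseconds, a pool is unlikely to improve on it.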
If you do use multiprocessing, I would simply break up the half-open interval [0, 100_000) into N smaller, non-overlapping half-open intervals, where N is the number of CPU cores you have. I have chosen the f(x) function (x ** 2) so that the value of x that satisfies a specific value of f(x) (9801198001) is skewed: a serial solution will not find the result (99001) until it has checked almost all possible values of x. Even so, for such a simple function the multiprocessing solution runs roughly 8 times more slowly than the serial solution in the timings shown below.
If the f(x) function is monotonically increasing or decreasing, then the serial solution can be sped up further using a binary search, which I have also included:
from multiprocessing import Pool, cpu_count

def f(x):
    return x ** 2

def search(r):
    # Linear scan of one sub-range; returns the matching x or None.
    for x in r:
        if f(x) == 9801198001:
            return x
    return None

def main():
    import time

    # Parallel processing:
    pool_size = cpu_count()
    interval_size = 100_000 // pool_size
    lower_bound = 0
    args = []
    for _ in range(pool_size - 1):
        args.append(range(lower_bound, lower_bound + interval_size))
        lower_bound += interval_size
    # Final interval (also absorbs any remainder):
    args.append(range(lower_bound, 100_000))
    t = time.time()
    with Pool(pool_size) as pool:
        for result in pool.imap_unordered(search, args):
            if result is not None:
                break
        # Leaving the with block implicitly calls pool.terminate(),
        # which terminates any remaining submitted tasks
    elapsed = time.time() - t
    print(f'result = {result}, parallel elapsed time = {elapsed}')

    # Serial processing:
    t = time.time()
    result = search(range(100_000))
    elapsed = time.time() - t
    print(f'result = {result}, serial elapsed time = {elapsed}')

    # Serial processing using a binary search
    # for monotonically increasing function:
    t = time.time()
    lower = 0
    upper = 100_000
    result = None
    while lower < upper:
        x = lower + (upper - lower) // 2
        sq = f(x)
        if sq == 9801198001:
            result = x
            break
        if sq < 9801198001:
            lower = x + 1
        else:
            upper = x
    elapsed = time.time() - t
    print(f'result = {result}, serial binary search elapsed time = {elapsed}')

if __name__ == '__main__':
    main()
Prints:
result = 99001, parallel elapsed time = 0.24248361587524414
result = 99001, serial elapsed time = 0.029256343841552734
result = 99001, serial binary search elapsed time = 0.0