I'm experiencing strange behavior with the multiprocessing module. Can anyone explain what is going on here?
The following MWE stalls (it runs forever without raising an error):
#!/usr/bin/env python3
import multiprocessing

import numpy as np
from skimage import io
from sklearn.cluster import KMeans

def create_model():
    sampled_pixels = np.random.randint(0, 255, (800, 3))
    kmeans_model = KMeans(n_clusters=8, random_state=0).fit(sampled_pixels)

def process_image(test, test2):
    image = np.random.randint(0, 255, (800, 3))
    kmeans_model = KMeans(n_clusters=8, random_state=0).fit(image)
    image = kmeans_model.predict(image)

def main():
    create_model()
    with multiprocessing.Pool(1) as pool:
        pool.apply_async(process_image, args=('test', 'test'))
        pool.close()
        pool.join()

if __name__ == "__main__":
    main()
However, if I either remove the line

create_model()

or change

def process_image(test, test2):

as well as

pool.apply_async(process_image, args=('test', 'test'))

to

def process_image(test):

and

pool.apply_async(process_image, args=('test',))

the code runs successfully, as it should, since both the arguments and the call to create_model() are completely redundant. (Note the trailing comma in args=('test',); without it, args is the four-character string 'test' rather than a one-element tuple.)
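As an aside, be aware that pool.apply_async swallows any exception raised in the worker unless you fetch the result. A small sketch of mine (the 30-second timeout is an arbitrary choice) that turns both hidden errors and silent hangs into visible failures:

import multiprocessing

import numpy as np
from sklearn.cluster import KMeans

def process_image(test, test2):
    image = np.random.randint(0, 255, (800, 3))
    KMeans(n_clusters=8, random_state=0).fit(image)

if __name__ == "__main__":
    with multiprocessing.Pool(1) as pool:
        result = pool.apply_async(process_image, args=('test', 'test'))
        # get() re-raises any exception from the worker; with a timeout,
        # a silent hang becomes a multiprocessing.TimeoutError instead.
        result.get(timeout=30)
        pool.close()
        pool.join()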
Appendix
> pip list
Package Version
------------- ---------
imageio 2.34.0
joblib 1.4.0
lazy_loader 0.4
networkx 3.3
numpy 1.26.4
packaging 24.0
pillow 10.3.0
pip 23.2.1
scikit-image 0.23.1
scikit-learn 1.4.2
scipy 1.13.0
threadpoolctl 3.4.0
tifffile 2024.2.12
> python --version
Python 3.12.2
I think I have tracked it down to a bug/shortcoming in GNU OpenMP. In short, the GNU OpenMP runtime used by scikit-learn is not fork-safe. From its documentation:
If your program intends to become a background process [...], you must not use the OpenMP features before the fork. After OpenMP features are utilized, a fork is only allowed if the child process does not use OpenMP features, or it does so as a completely new process (such as after exec()).
A minimal way to trigger this (on platforms where multiprocessing defaults to the fork start method, e.g. Linux) is:
import multiprocessing

from sklearn.cluster import KMeans

def kmeans():
    KMeans(n_clusters=2, random_state=0).fit([[1, 1], [2, 2]])

if __name__ == "__main__":
    kmeans()
    with multiprocessing.Pool(1) as pool:
        pool.apply_async(kmeans)
        pool.close()
        pool.join()
For possible mitigations, such as using the "spawn" or "forkserver" start method instead of "fork", see this link.
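To illustrate one such mitigation (my own sketch, not taken from the linked page): starting workers with the "spawn" method gives each child a fresh interpreter, so it never inherits the parent's post-fork OpenMP state, and the minimal trigger above completes instead of hanging:

import multiprocessing

from sklearn.cluster import KMeans

def kmeans():
    KMeans(n_clusters=2, random_state=0).fit([[1, 1], [2, 2]])

if __name__ == "__main__":
    kmeans()
    # "spawn" launches each worker as a brand-new Python process,
    # so no OpenMP state from the parent survives into the child.
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(1) as pool:
        pool.apply_async(kmeans)
        pool.close()
        pool.join()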