Consider the following Python function to be parallelized, which uses a georeferenced ndarray (assembled with rioxarray) and a shapefile. The function uses both datasets to generate map plots with Matplotlib/CartoPy; the dependent variable is the change in map domain extent. Note that the code governing cosmetic alterations to the plot (titles, etc.) has been removed to keep this example as straightforward as possible:
def plotter(data, xgrid, ygrid, region) -> 'Graphics Plotter':
    """Render one map panel of *data* clipped to *region*.

    NOTE(review): relies on module-level globals (cmap, norm, ticks and the
    *_feature shapefile features) defined elsewhere in the script; cosmetic
    code (titles, saving) has been stripped from this example.
    """
    fig = plt.figure(figsize=(14, 9))
    # Two stacked rows: a slim header strip on top, the map below.
    gs = gridspec.GridSpec(ncols=1, nrows=2, width_ratios=[1], height_ratios=[0.15, 3.00])
    gs.update(wspace=0.00, hspace=0.00)
    bar_width = 0.40
    # Header axes: all ticks and spines hidden so only its content shows.
    ax1 = fig.add_subplot(gs[0, :])
    ax1.axes.get_xaxis().set_visible(False)
    ax1.axes.get_yaxis().set_visible(False)
    for side in ('top', 'right', 'left', 'bottom'):
        ax1.spines[side].set_visible(False)
    # Map axes in the Lambert Conformal projection, clipped to *region*
    # (extent is interpreted in the same projection's coordinates).
    ax2 = fig.add_subplot(gs[1, :], projection=crs.LambertConformal())
    ax2.set_extent(region, crs=crs.LambertConformal())
    ax2.set_adjustable('datalim')
    # First band of the raster, drawn on the 2-D coordinate grids.
    mesh = ax2.pcolormesh(xgrid, ygrid, data.variable.data[0], cmap=cmap, norm=norm)
    cb = plt.colorbar(mesh, ax=ax2, pad=0.01, ticks=ticks, aspect=80, orientation='horizontal')
    # Political boundaries drawn on top of the field.
    ax2.add_feature(counties_feature, linewidth=0.45)
    ax2.add_feature(states_feature, linewidth=1.25)
    ax2.add_feature(canada_feature, linewidth=1.25)
This plotting function is passed data, grid extents, and region constraints from the main function, where the parallel execution is also defined. Note that da, x, y, and all shapefiles are static and are never altered for the duration of this script's execution.
import multiprocess as mpr
import matplotlib as mpl
import cartopy.crs as crs
import cartopy.feature as cfeature
from cartopy.io.shapereader import Reader
from cartopy.feature import ShapelyFeature
import rioxarray as rxr
def main():
    """Build the static plotting inputs once, then fan out one process per region.

    NOTE(review): `canada`, `states`, `counties` (shapefile paths) and the
    globals consumed by plotter() are assumed to be defined elsewhere.
    """
    canada_feature = ShapelyFeature(Reader(canada).geometries(), crs.LambertConformal(), facecolor='none', edgecolor='black')
    states_feature = ShapelyFeature(Reader(states).geometries(), crs.LambertConformal(), facecolor='none', edgecolor='black')
    counties_feature = ShapelyFeature(Reader(counties).geometries(), crs.LambertConformal(), facecolor='none', edgecolor='black')
    regions = pd.read_csv('/path/to/defined_regions.txt')
    da = rxr.open_rasterio('path/to/somefile.tif', lock=False, mask_and_scale=True)
    # 2-D coordinate grids for pcolormesh (the unused Y, X aliases were dropped).
    x, y = np.meshgrid(da['x'], da['y'])

    def parallel() -> 'Parallel Execution':
        """Start one worker per region row, then wait for all of them."""
        # One Process per region; a comprehension replaces the old
        # append-via-extend([pro]) loop, which built the same list awkwardly.
        processes = [
            mpr.Process(target=plotter, args=(da, x, y, g['region']))
            for _, g in regions.iterrows()
        ]
        for p in processes:
            p.start()
        for p in processes:
            p.join()

    parallel()
The regions file contains 12 unique regions, each of which is passed into a new process in the parallel function and executed. I'm noticing higher RAM usage while the processes execute, which I suspect comes from inefficient use of memory when the ndarrays (da, x, y) and shapefiles are consumed by the parallel function.
Is there an effective way to share this data across the worker processes so that the RAM usage is less expensive?
If you're on a POSIX OS (macOS, Linux) where you can set the multiprocessing start method to "fork", you can take advantage of copy-on-write memory (aside from Python object headers, where reference counts get updated — but if your data is big or slow to load, that cost is peanuts).
I've wrapped all of your shared data into a dataclass here; the idea is that the parent process initializes it, and when the subprocesses fork they can use the same global data (but they will load separate copies if they need to — watch for the indicator print when you try this).
from __future__ import annotations

import dataclasses
import multiprocessing
import os

import numpy as np
import pandas as pd

import cartopy.crs as crs
from cartopy.feature import ShapelyFeature
from cartopy.io.shapereader import Reader
import rioxarray as rxr
@dataclasses.dataclass
class SharedData:
    """All read-only inputs needed by plotter(), loaded once in the parent.

    After a POSIX fork, children inherit this object via copy-on-write, so
    the large arrays and shapefile geometries are neither re-read nor pickled.
    """
    canada_feature: ShapelyFeature
    states_feature: ShapelyFeature
    counties_feature: ShapelyFeature
    # rioxarray returns an xarray.DataArray (rioxarray itself exposes no
    # DataArray type); kept as a string annotation — PEP 563 lazy annotations
    # are enabled above, so xarray need not be imported here.
    da: "xarray.DataArray"
    # np.meshgrid returns a list of two 2-D arrays: [x, y].
    da_meshgrid: list[np.ndarray]

    @classmethod
    def load(cls) -> "SharedData":
        """Read the shapefiles and the raster; call once in the parent process."""
        proj = crs.LambertConformal()
        canada_feature = ShapelyFeature(Reader(canada).geometries(), proj, facecolor="none", edgecolor="black")
        states_feature = ShapelyFeature(Reader(states).geometries(), proj, facecolor="none", edgecolor="black")
        counties_feature = ShapelyFeature(Reader(counties).geometries(), proj, facecolor="none", edgecolor="black")
        da = rxr.open_rasterio("path/to/somefile.tif", lock=False, mask_and_scale=True)
        da_meshgrid = np.meshgrid(da["x"], da["y"])
        return cls(
            canada_feature=canada_feature,
            states_feature=states_feature,
            counties_feature=counties_feature,
            da=da,
            da_meshgrid=da_meshgrid,
        )
shared_data: SharedData | None = None
def plotter(region) -> "Graphics Plotter":
    """Plot one *region*, lazily loading the shared inputs when this process
    did not inherit them from a fork of the parent."""
    global shared_data
    if shared_data is None:
        # Only fires when copy-on-write inheritance was unavailable
        # (e.g. the "spawn" start method) — the indicator print mentioned above.
        print(f"Loading shared data in {os.getpid()}")
        shared_data = SharedData.load()
    sd = shared_data
    data = sd.da
    x, y = sd.da_meshgrid
    # ...
def main():
    """Load everything once in the parent, then fork one worker per region."""
    global shared_data
    shared_data = SharedData.load()
    regions = pd.read_csv("/path/to/defined_regions.txt")
    # One worker process per region row.
    workers = []
    for _, row in regions.iterrows():
        workers.append(multiprocessing.Process(target=plotter, args=(row["region"],)))
    for w in workers:
        w.start()
    for w in workers:
        w.join()