Search code examples
pythonmatplotlibmultiprocessingpython-multiprocessingshared-memory

Multiprocess Sharing Static Data Between Process Jobs


Consider the following Python function to be parallelized, which utilizes a georeferenced ndarray (assembled from rioxarray) and a shapefile. This function uses both these datasets to generate map plots with Matplotlib/CartoPy, the dependent variable being changes in map domain extent. Note that code to govern cosmetic alterations to the plot for titles, etc. has been removed to make this example as straightforward as possible:

def plotter(data, xgrid, ygrid, region) -> 'Graphics Plotter':
    """Render one map panel for *region* from the shared data/grid arrays.

    Relies on module-level globals defined elsewhere in the script:
    cmap, norm, ticks, counties_feature, states_feature, canada_feature.
    """
    figure = plt.figure(figsize=(14, 9))
    layout = gridspec.GridSpec(ncols=1, nrows=2, width_ratios=[1], height_ratios=[0.15, 3.00])
    layout.update(wspace=0.00, hspace=0.00)
    bar_width = 0.40

    # Top strip: a fully hidden axes reserved for title/legend space.
    header_ax = figure.add_subplot(layout[0, :])
    header_ax.axes.get_xaxis().set_visible(False)
    header_ax.axes.get_yaxis().set_visible(False)
    for side in ('top', 'right', 'left', 'bottom'):
        header_ax.spines[side].set_visible(False)

    # Main panel: Lambert Conformal map clipped to the requested extent.
    map_ax = figure.add_subplot(layout[1, :], projection=crs.LambertConformal())
    map_ax.set_extent(region, crs=crs.LambertConformal())
    map_ax.set_adjustable('datalim')

    mesh = map_ax.pcolormesh(xgrid, ygrid, data.variable.data[0], cmap=cmap, norm=norm)
    colorbar = plt.colorbar(mesh, ax=map_ax, pad=0.01, ticks=ticks, aspect=80, orientation='horizontal')
    map_ax.add_feature(counties_feature, linewidth=0.45)
    map_ax.add_feature(states_feature, linewidth=1.25)
    map_ax.add_feature(canada_feature, linewidth=1.25)

This plotting function is passed data, grid extents, and region constraints from the main function, where the parallel execution is also defined. Note that da, x, y, and all shapefiles are static and are never altered through the duration of this script execution.

import multiprocess as mpr
import matplotlib as mpl
import cartopy.crs as crs
import cartopy.feature as cfeature
from cartopy.io.shapereader import Reader
from cartopy.feature import ShapelyFeature
import rioxarray as rxr

def main():
    """Load the static datasets once, then plot every region in parallel.

    Fixes vs. the snippet as posted:
    - the loop body building each Process was not indented under its
      ``for`` statement (a SyntaxError as written);
    - ``processes.extend([pro])`` replaced with the idiomatic ``append``;
    - removed the unused ``Y, X`` locals.

    NOTE(review): the feature objects below are locals of main(), but
    plotter() references them as globals — confirm how they are meant to
    reach the workers.
    """
    canada_feature = ShapelyFeature(Reader(canada).geometries(), crs.LambertConformal(), facecolor='none', edgecolor='black')
    states_feature = ShapelyFeature(Reader(states).geometries(), crs.LambertConformal(), facecolor='none', edgecolor='black')
    counties_feature = ShapelyFeature(Reader(counties).geometries(), crs.LambertConformal(), facecolor='none', edgecolor='black')
    regions = pd.read_csv('/path/to/defined_regions.txt')
    da = rxr.open_rasterio('path/to/somefile.tif', lock=False, mask_and_scale=True)
    x, y = np.meshgrid(da['x'], da['y'])

    def parallel() -> 'Parallel Execution':
        """Spawn one plotting process per region row, then wait for all."""
        processes = []
        for i, g in regions.iterrows():
            # da/x/y travel to every child — the memory cost the question asks about.
            pro = mpr.Process(target=plotter, args=(da, x, y, g['region']))
            processes.append(pro)
        for p in processes:
            p.start()
        for p in processes:
            p.join()

    parallel()

The regions file contains 12 unique regions, which are each passed into a new process in the parallel function and executed. I'm noticing higher RAM usage when the processes execute, which I suspect is from inefficient utilization of memory when the ndarrays da, x, & y and shapefiles are utilized by the parallel function.

Is there an effective way to share these data across the Multiprocess pool such that the RAM use is less expensive?


Solution

  • If you're on a POSIX OS (macOS, Linux) where you can set the Multiprocessing start method to fork, you can take advantage of copy-on-write memory (aside from Python object headers where refcounts will get updated, but if your data is big or loading it takes a while, that's peanuts).

    I've wrapped all of your shared data into a dataclass here; the idea is the parent process initializes it, and when the subprocesses fork, they can use the same global data (but will load separate copies if they need to; look at the indicative print statement when you try this).

    from __future__ import annotations

    import dataclasses
    import multiprocessing
    import os

    import cartopy.crs as crs
    import numpy as np
    import pandas as pd
    import rioxarray as rxr
    from cartopy.feature import ShapelyFeature
    from cartopy.io.shapereader import Reader
    
    
    @dataclasses.dataclass
    class SharedData:
        """All read-only data the plotting workers need, loaded once in the parent.

        Under the ``fork`` start method, forked children inherit the parent's
        instance via copy-on-write instead of re-loading it from disk.
        """

        # Cartopy features built from the three shapefiles; drawn unchanged by every worker.
        canada_feature: ShapelyFeature
        states_feature: ShapelyFeature
        counties_feature: ShapelyFeature
        # NOTE(review): rioxarray does not export a ``DataArray`` class; this
        # annotation only "works" because ``from __future__ import annotations``
        # keeps it unevaluated — the intended type is presumably
        # xarray.DataArray, confirm and correct.
        da: rxr.DataArray
        # np.meshgrid returns a list of coordinate arrays ([x, y] here).
        da_meshgrid: list[np.ndarray]

        @classmethod
        def load(cls) -> SharedData:
            """Read the shapefiles and raster, build the meshgrid, and return a populated instance.

            NOTE(review): ``canada``/``states``/``counties`` and the .tif path are
            placeholders taken from the question — they must be defined/replaced
            before this runs.
            """
            proj = crs.LambertConformal()
            canada_feature = ShapelyFeature(Reader(canada).geometries(), proj, facecolor="none", edgecolor="black")
            states_feature = ShapelyFeature(Reader(states).geometries(), proj, facecolor="none", edgecolor="black")
            counties_feature = ShapelyFeature(Reader(counties).geometries(), proj, facecolor="none", edgecolor="black")
            da = rxr.open_rasterio("path/to/somefile.tif", lock=False, mask_and_scale=True)
            da_meshgrid = np.meshgrid(da["x"], da["y"])
            return cls(
                canada_feature=canada_feature,
                states_feature=states_feature,
                counties_feature=counties_feature,
                da=da,
                da_meshgrid=da_meshgrid,
            )
    
    
    # Per-process cache: set in the parent before forking so children inherit it
    # copy-on-write; stays None in processes started with spawn.
    shared_data: SharedData | None = None
    
    
    def plotter(region) -> "Graphics Plotter":
        """Plot one region using the process-wide SharedData.

        If this process did not inherit the parent's data (e.g. under the
        ``spawn`` start method), it loads its own copy — the print makes that
        fallback visible.
        """
        global shared_data
        if shared_data is None:
            print(f"Loading shared data in {os.getpid()}")
            shared_data = SharedData.load()
        data = shared_data.da
        x, y = shared_data.da_meshgrid
        # ... plotting code as in the question's plotter ...
    
    
    def main():
        """Load shared data once, then fan out one plotting process per region.

        NOTE(review): copy-on-write sharing requires the ``fork`` start method
        (POSIX only). Nothing here calls multiprocessing.set_start_method("fork"),
        so on platforms defaulting to ``spawn`` each child re-loads the data via
        plotter()'s fallback — confirm the intended start method is configured.
        """
        global shared_data
        shared_data = SharedData.load()
        regions = pd.read_csv("/path/to/defined_regions.txt")

        # One process per region row; each child receives only the small region spec.
        processes = [multiprocessing.Process(target=plotter, args=(g["region"],)) for i, g in regions.iterrows()]
        for p in processes:
            p.start()
        for p in processes:
            p.join()