Tags: python, python-3.x, numpy, dask

How to sum radially in a dask array?


I am trying to radially sum the values of a dask array, keeping the data chunked and summing the values at each radius. It may also be useful to normalize each sum by the number of "pixels" summed.

This is the structure of the dask array: dask.array<where, shape=(264, 256, 1500), dtype=int16, chunksize=(264, 256, 992), chunktype=numpy.ndarray>. The data contains intensity (counts), with dimensions:

x: distance (nm) of size (264,)
y: distance (nm) of size (256,)
energy_scale: energy (eV) of size (1500,)

I have tried this code to mask a circle and create a new array, and it works great, but I need to sum the points radially.

import numpy as np

def cartesian_to_polar(x, y, center):

    x_rel = x - center[1]
    y_rel = y - center[0]
    radius = np.sqrt(x_rel**2 + y_rel**2)
    theta = np.arctan2(y_rel, x_rel)
    return radius, theta


def get_polar_coordinates_grid(center, max_radius, dimensions):
    x = np.arange(dimensions[1])
    y = np.arange(dimensions[0])
    xv, yv = np.meshgrid(x, y)
    rv, thetav = cartesian_to_polar(xv, yv, center)
    
    # Mask to select points within the specified radius
    mask = rv <= max_radius
    
    mask_3D = np.repeat(mask[:, :, np.newaxis], dimensions[2], axis=2)
    
    print('mask_3D shape',mask_3D.shape)
    
    return mask_3D, rv, thetav


def polar_profile(dataset, center, max_radius):
    print('center', center[0],center[1])

    mask_3D, rv, thetav = get_polar_coordinates_grid(center, max_radius, dataset.shape[:3])
    radial_dataset_array = np.repeat(rv[:, :, np.newaxis], dataset.shape[2], axis=2)

    masked_dataset = np.where(mask_3D, dataset, 0)

    masked_r_dataset = np.where(mask_3D, radial_dataset_array, 0)

From here I am not sure what the best way is to index the pixels that need to be summed. An example with a simple numpy array would also be helpful.


Solution

  • It is actually far simpler than that. One approach is the following (I created sample data here; replace it with your own):

    import dask.array as da
    import numpy as np
    
    # Sample data with the same shape and chunking as in the question
    dask_data = da.random.random((264, 256, 1500), chunks=(264, 256, 992))
    
    center = (dask_data.shape[0] // 2, dask_data.shape[1] // 2)
    max_radius = min(center)
    
    # Integer (truncated) distance of every pixel from the center
    x = da.arange(dask_data.shape[0])
    y = da.arange(dask_data.shape[1])
    xv, yv = da.meshgrid(x - center[0], y - center[1], indexing='ij')
    radii = da.sqrt(xv**2 + yv**2).astype(np.int32)
    
    def sum_radial(image, radii, max_radius):
        results = []
        for r in range(max_radius+1):
            mask = radii == r  # pixels whose truncated radius equals r
            masked_data = da.where(mask[..., None], image, 0)
            sum_data = masked_data.sum(axis=(0, 1))  # sum over x and y, one value per energy
            count = mask.sum()  # number of pixels in this ring
            results.append((sum_data, count))
        return results
    
    results = sum_radial(dask_data, radii, max_radius)
    # Normalize each radial sum by its pixel count; rings with no pixels are skipped
    final_results = [(res[0] / res[1]).compute() for res in results if res[1] > 0]
    

    which returns

    [array([0.56819324, 0.48555518, 0.07560507, ..., 0.57646385, 0.54231102,
           0.24472355]), array([0.69557989, 0.55792832, 0.37600816, ..., 0.59541941, 0.47944983,
           0.44399378]), array([0.54944245, 0.37145509, 0.57202096, ..., 0.5346811 , 0.46998448,
           0.48209734]), array([0.45247991, 0.48744803, 0.56118335, ..., 0.40804105, 0.57898455,
           0.48248958]), array([0.45528472, 0.43234064, 0.50973776, ..., 0.43357677, 0.54758867,
           0.48335612]), array([0.50817523, 0.54346835, 0.48714012, ..., 0.46681618, 0.58126478,
           0.41308635]), array([0.4513112 , 0.58265886, 0.51799577, ..., 0.50165861, 0.52393633,
           0.52736444]), array([0.47323004, 0.48509912, 0.50154742, ..., 0.44656244, 0.45017714,
           0.45944342]), array([0.51737857, 0.43069107, 0.50277041, ..., 0.50754802, 0.45124549,
           0.48241854]), array([0.47561275, 0.47823851, 0.45425869, ..., 0.48548346, 0.46774252,
           0.53712027]), array([0.45032079, 0.46940781, 0.54524277, ..., 0.54348768, 0.47486265,
           0.57205173]), array([0.41538367, 0.57499891, 0.53469115, ..., 0.49469812, 0.52803723,
           0.45382559]), array([0.46701844, 0.52337592, 0.52955901, ..., 0.50932645, 0.5619902 ,
           0.5217232 ]), array([0.51389081, 0.45033721, 0.52085882, ..., 0.46968577, 0.44170561,
           0.48609703]), array([0.52183791, 0.48536374, 0.46433651, ..., 0.47977694, 0.51830913,
    ...
           0.50629353]), array([0.50827807, 0.50630572, 0.48872231, ..., 0.50363937, 0.49413712,
           0.5133742 ]), array([0.51127133, 0.51176922, 0.48647271, ..., 0.49598711, 0.50950352,
           0.51000234]), array([0.51796942, 0.46370582, 0.48640914, ..., 0.51021386, 0.5104863 ,
           0.49745391])]
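
The question also asks for an example with a plain NumPy array. The following is a minimal sketch under stated assumptions, not part of the original answer: the function name `radial_profile` and the small all-ones test array are made up for illustration, and radii are truncated to integers in the same way as in the dask code above. It bins every pixel by its integer radius with `np.bincount` and accumulates the spectra of each bin with `np.add.at`, which avoids looping over radii.

```python
import numpy as np


def radial_profile(data, center):
    """Sum and average a (nx, ny, n_energy) array over integer radius bins.

    `radial_profile` is a hypothetical helper name, not from the original
    post. Radii are truncated to integers, so bin r holds r <= radius < r + 1.
    """
    nx, ny = data.shape[:2]
    # Pixel offsets from the center along each spatial axis.
    xv, yv = np.meshgrid(np.arange(nx) - center[0],
                         np.arange(ny) - center[1], indexing="ij")
    # Integer radius bin of every pixel, flattened to 1D.
    r = np.sqrt(xv**2 + yv**2).astype(np.intp).ravel()
    n_bins = int(r.max()) + 1
    # Number of pixels that fall into each radius bin.
    counts = np.bincount(r, minlength=n_bins)
    # Accumulate each pixel's spectrum into the row of its radius bin.
    flat = data.reshape(nx * ny, -1)
    sums = np.zeros((n_bins, flat.shape[1]), dtype=np.float64)
    np.add.at(sums, r, flat)
    # Normalize by pixel count; a bin with no pixels keeps a mean of zero.
    means = sums / np.maximum(counts, 1)[:, None]
    return sums, counts, means


# Small check: a 5x5x2 array of ones centered at (2, 2).
data = np.ones((5, 5, 2))
sums, counts, means = radial_profile(data, center=(2, 2))
print(counts.tolist())  # [1, 8, 16]
```

Here `sums[r]` is the summed spectrum for radius bin `r` and `means[r]` is that sum divided by the number of pixels in the bin. `np.add.at` is chosen for clarity; for an array as large as the one in the question it may be slow, so it is worth timing it against the loop-based version.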