python, python-2.7, multiprocessing, python-multiprocessing, pysal

Parallel processing - nearest neighbour search using pysal in Python?


I have this data frame df1:

          id      lat_long
400743  2504043 (175.0976323, -41.1141412)
43203   1533418 (173.976683, -35.2235338)
463952  3805508 (174.6947496, -36.7437555)
1054906 3144009 (168.0105269, -46.36193)
214474  3030933 (174.6311167, -36.867717)
1008802 2814248 (169.3183615, -45.1859095)
988706  3245376 (171.2338968, -44.3884099)
492345  3085310 (174.740957, -36.8893026)
416106  3794301 (174.0106383, -35.3876921)
937313  3114127 (174.8436185, -37.80499)

I have constructed the tree for the search as follows:

import pysal
from pysal.cg import KDTree

def construct_geopoints(s):
    # build an arc-distance KD-tree over (longitude, latitude) pairs, with distances in km
    data_geopoints = [tuple(x) for x in s[['longitude','latitude']].to_records(index=False)]
    tree = KDTree(data_geopoints, distance_metric='Arc', radius=pysal.cg.RADIUS_EARTH_KM)
    return tree

tree = construct_geopoints(actualdata)

Now, I am trying to find all the geopoints which are within 1 km of every geopoint in my data frame df1. Here is how I am doing it:

dfs = []
# process df1 in chunks of 10,000 rows
for name, group in df1.groupby(np.arange(len(df1)) // 10000):
    s = group.reset_index(drop=True).copy()
    pts = list(s['lat_long'])
    # indices of all points within 1 km of each point in the chunk
    neighbours = tree.query_ball_point(pts, 1)
    s['neighbours'] = pd.Series(neighbours)
    dfs.append(s)

output = pd.concat(dfs, axis=0)

Everything here works fine. However, I am trying to parallelise this task, since df1 has about 2M records and this process runs for more than 8 hours. Can anyone help me with this? Another issue is that the result returned by query_ball_point is a list, so it throws a memory error when I process it for this huge number of records. Is there any way to handle this?

EDIT: Memory issue, look at the VIRT size.

(screenshot of process memory usage showing the VIRT size)


Solution

  • It should be possible to parallelize your last segment of code with something like this:

    from multiprocessing import Pool
    ...
    
    def process_group(group):
        # each item yielded by the groupby is a (name, group) tuple
        s = group[1].reset_index(drop=True)  # .copy() is implicit
        pts = list(s['lat_long'])
        neighbours = tree.query_ball_point(pts, 1)
        s['neighbours'] = pd.Series(neighbours)
        return s
    
    groups = df1.groupby(np.arange(len(df1))//10000)
    
    p = Pool(5)
    dfs = p.map(process_group, groups)
    
    output = pd.concat(dfs, axis=0)
    

    But watch out, because the multiprocessing library pickles all the data on its way to and from the workers, and that can add a lot of overhead for data-intensive tasks, possibly cancelling the savings due to parallel processing.
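
    One rough sketch of how to reduce that overhead, assuming the same global tree and chunk size as above (the helper name neighbours_for_chunk is just illustrative), is to send each worker only the list of coordinate tuples and get back only the neighbour lists, rather than pickling whole DataFrame chunks in both directions:

    from multiprocessing import Pool
    import numpy as np

    def neighbours_for_chunk(pts):
        # `tree` is the module-level KDTree built above; forked workers inherit
        # it, so only the small list of points is pickled for each task
        return tree.query_ball_point(pts, 1)

    # one list of (lon, lat) tuples per chunk of 10,000 rows
    chunks = [list(s['lat_long'])
              for _, s in df1.groupby(np.arange(len(df1)) // 10000)]

    p = Pool(5)
    results = p.map(neighbours_for_chunk, chunks)

    # chunks come back in order, so a flat positional assignment lines up with df1
    df1['neighbours'] = [nbrs for chunk in results for nbrs in chunk]

    (This relies on the fork start method, the default on Linux, so the workers inherit tree; under spawn, e.g. on Windows, each worker would have to rebuild the tree itself.)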

    I can't see where you'd be getting out-of-memory errors from. Two million records is not that much for pandas. Maybe if your searches are producing hundreds of matches per row, that could be a problem. If you say more about that, I might be able to give some more advice.

    It also sounds like pysal may be taking longer than necessary to do this. You might be able to get better performance by using GeoPandas, or by rolling your own solution like this (a rough sketch follows the list):

    1. assign each point to a surrounding 1-km grid cell (e.g., calculate UTM coordinates x and y, then create columns cx=x//1000 and cy=y//1000);
    2. create an index on the grid cell coordinates cx and cy (e.g., df=df.set_index(['cx', 'cy']));
    3. for each point, find the points in the 9 surrounding cells; you can select these directly from the index via df.loc[[(cx-1,cy-1),(cx-1,cy),(cx-1,cy+1),(cx,cy-1),...(cx+1,cy+1)], :];
    4. filter the points you just selected to find the ones within 1 km.
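
    A minimal sketch of that grid-cell idea, assuming the points have already been projected to metric coordinates stored in columns x and y (e.g. with pyproj); the x and y columns and the helper name neighbours_within_1km are assumptions, not part of your data:

    import numpy as np

    df = df1.copy()                            # assumes metric 'x', 'y' columns have been added
    df['cx'] = (df['x'] // 1000).astype(int)   # 1-km grid-cell column
    df['cy'] = (df['y'] // 1000).astype(int)   # 1-km grid-cell row

    # index on the cell coordinates so candidate lookups are cheap
    cells = df.set_index(['cx', 'cy']).sort_index()

    def neighbours_within_1km(row):
        cx, cy = int(row['x'] // 1000), int(row['y'] // 1000)
        # candidates: every point in the 3x3 block of cells around this point's cell
        keys = [(cx + i, cy + j) for i in (-1, 0, 1) for j in (-1, 0, 1)]
        keys = [k for k in keys if k in cells.index]
        candidates = cells.loc[keys]
        # exact filter: keep only the candidates within 1 km
        d = np.hypot((candidates['x'] - row['x']).values,
                     (candidates['y'] - row['y']).values)
        return candidates.loc[d <= 1000.0, 'id'].tolist()

    df['neighbours'] = df.apply(neighbours_within_1km, axis=1)

    The per-row lookups are independent, so the same chunk-and-Pool pattern from above can be wrapped around the apply step if a single process is still too slow.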