Search code examples
python numpy scipy overlap packing

Spheres random motion without overlap in Python


I want to randomly move a list of ~1 million spheres, multiple times. I have an initial list of non-overlapping spheres (an array of centers), all with the same radius.

Therefore, the goal is to apply small translations to each of them, taking care that, at the end, the spheres are still within some space boundaries, but also that no overlap occurs.

I tried to do this myself by moving the spheres one by one and rejecting a move if it causes an overlap. However, even when testing only against neighbors, the script is incredibly slow and would take hours just for N_motions=1.

Here is the code:

import numpy as np
from scipy.spatial import cKDTree

def in_cylinder(all_points, Rmax, Zmin, Zmax):
    """Return a boolean mask of the points lying inside a z-aligned cylinder.

    A single point is promoted to a one-row array, so the result always has
    one entry per point. A point counts as inside when its distance from the
    z axis is at most ``Rmax`` and its z coordinate lies in ``[Zmin, Zmax]``
    (all bounds inclusive).
    """
    pts = np.atleast_2d(all_points)
    x, y, z = pts[:, 0], pts[:, 1], pts[:, 2]

    rho = np.sqrt(x**2 + y**2)
    within_radius = rho <= Rmax
    within_height = (Zmin <= z) & (z <= Zmax)
    return within_radius & within_height

def move_spheres(centers, r_spheres, motion_coef, N_motions):
    """Randomly displace equal-radius spheres, rejecting moves that overlap.

    Each pass visits every sphere once, proposes a random displacement of
    length at most ``motion_coef * r_spheres`` in a uniformly random
    direction, and keeps the move only if the new centre passes
    ``in_cylinder`` and stays at least ``2 * r_spheres`` away from every
    neighbouring centre.

    NOTE: the cylinder bounds ``Rmax``, ``Zmin`` and ``Zmax`` are read from
    module scope; they are not defined in the visible listing and must exist
    before this function is called.

    Parameters
    ----------
    centers : array-like
        Initial sphere centres, one row per sphere. The input is not
        modified.
    r_spheres : float
        Common sphere radius.
    motion_coef : float
        Maximum step length expressed as a multiple of ``r_spheres``.
    N_motions : int
        Number of passes over all spheres.

    Returns
    -------
    numpy.ndarray
        A copy of ``centers`` with the accepted moves applied.
    """
    n_spheres = len(centers)
    updated_centers = np.copy(centers)
    if n_spheres == 0:
        # Nothing to move; also avoids the 0/0 in the per-pass summary.
        return updated_centers

    motion_magnitude = motion_coef * r_spheres
    # Two spheres can only come into contact during a pass if their centres
    # start within one contact distance plus both of their maximum steps.
    search_radius = 2 * r_spheres + 2 * motion_magnitude

    for _ in range(N_motions):
        # Rebuild the tree from the *current* positions: a tree built from
        # the original centres would return stale neighbour lists from the
        # second pass on and let overlaps slip through.
        tree = cKDTree(updated_centers)
        # One batched query is much cheaper than one query per sphere.
        potential_neighbors = tree.query_ball_point(updated_centers, search_radius)

        # Running counter instead of sum(updated): summing a length-n array
        # for every accepted move made each pass quadratic.
        accepted = 0
        for i in range(n_spheres):
            # Generate a random direction
            direction = np.random.randn(3)
            direction /= np.linalg.norm(direction)

            # Generate a random magnitude
            magnitude = np.random.uniform(0, motion_magnitude)

            # Move the sphere
            new_center = updated_centers[i] + direction * magnitude

            # Check for space boundaries
            if not in_cylinder(new_center, Rmax, Zmin, Zmax):
                print('out of cylinder')
                continue

            neighbors_indices = [idx for idx in potential_neighbors[i] if idx != i]
            neighbors_centers = updated_centers[neighbors_indices]
            distances = np.linalg.norm(neighbors_centers - new_center, axis=1)

            # Update the center if no overlap
            if not np.any(distances < 2 * r_spheres):
                updated_centers[i] = new_center
                accepted += 1
                print(f'{accepted}/{i+1}')
        print(accepted, accepted / n_spheres)
    return updated_centers

Would you have any recommendations to speed this up?


Solution

  • I found a set of optimizations that add up to about a 5x improvement, depending on how you benchmark it. I don't think a 100x improvement is possible without completely changing the algorithm.

    I tried the following ideas to improve this:

    • Avoid running tree.query_ball_point() in a loop. query_ball_point() accepts an array of points, and it is nearly always faster to query with the whole array than to loop over it. This is about 3x faster.
    • Use the multicore mode of tree.query_ball_point(). This gave about a 30% speedup.
    • Profile and use numba to replace hotspots. I wrote functions that use numba to compute the parts that seem to be slow when run under line_profiler.

    Code:

    import numpy as np
    from scipy.spatial import cKDTree
    import numba as nb
    import math
    
    @nb.njit()
    def in_cylinder(all_points, Rmax, Zmin, Zmax):
        """Tell whether one (x, y, z) point lies inside a z-aligned cylinder.

        Squared lengths are compared, so no square root is taken. All bounds
        are inclusive.
        """
        x = all_points[0]
        y = all_points[1]
        z = all_points[2]
        inside_radius = (x**2 + y**2) <= (Rmax ** 2)
        inside_height = (Zmin <= z) & (z <= Zmax)
        return inside_radius & inside_height
    
    
    @nb.njit()
    def generate_random_vector(max_magnitude):
        """Draw a 3-vector with a random direction and a random length.

        The direction is a normalised vector of three standard-normal draws;
        the length is drawn uniformly between 0 and ``max_magnitude``.
        """
        # The direction is drawn before the length; keep this order so the
        # random stream is consumed the same way.
        unit = np.random.randn(3)
        unit = unit / np.linalg.norm(unit)

        length = np.random.uniform(0, max_magnitude)
        return length * unit
    
    
    @nb.njit()
    def euclidean_distance(vec_a, vec_b):
        """Return the Euclidean distance between two 1-D vectors.

        The loop runs over the length of ``vec_a``; ``vec_b`` is indexed with
        the same positions.
        """
        n_components = vec_a.shape[0]
        squared_sum = 0.0
        for k in range(n_components):
            delta = vec_a[k] - vec_b[k]
            squared_sum += delta ** 2
        return math.sqrt(squared_sum)
    
    
    @nb.njit()
    def any_neighbor_in_range(new_center, all_neighbors, neighbors_indices, threshold, ignore_idx):
        """Report whether any listed neighbour is closer than ``threshold``.

        ``neighbors_indices`` selects rows of ``all_neighbors``. The row
        ``ignore_idx`` (the sphere being moved) is skipped. The scan stops at
        the first neighbour found within range.
        """
        for idx in neighbors_indices:
            # Skip the moving sphere itself; it is always in its own list.
            if idx != ignore_idx and euclidean_distance(new_center, all_neighbors[idx]) < threshold:
                return True
        return False
    
    
    def move_spheres(centers, r_spheres, motion_coef, N_motions):
        """Randomly displace equal-radius spheres, rejecting moves that overlap.

        Each pass visits every sphere once, proposes a random displacement of
        length at most ``motion_coef * r_spheres``, and keeps the move only if
        the new centre passes ``in_cylinder`` and no neighbouring centre is
        closer than ``2 * r_spheres``. After each pass the number and
        fraction of accepted moves are printed.

        NOTE: the cylinder bounds ``Rmax``, ``Zmin`` and ``Zmax`` are read
        from module scope; they are not defined in the visible listing and
        must exist before this function is called.

        Parameters
        ----------
        centers : array-like
            Initial sphere centres, one row per sphere. The input is not
            modified.
        r_spheres : float
            Common sphere radius.
        motion_coef : float
            Maximum step length expressed as a multiple of ``r_spheres``.
        N_motions : int
            Number of passes over all spheres.

        Returns
        -------
        numpy.ndarray
            A copy of ``centers`` with the accepted moves applied.
        """
        n_spheres = len(centers)
        updated_centers = np.copy(centers)
        if n_spheres == 0:
            # Nothing to move; also avoids the 0/0 in the per-pass summary.
            return updated_centers

        motion_magnitude = motion_coef * r_spheres
        contact_distance = 2 * r_spheres
        # Two spheres can only come into contact during a pass if their
        # centres start within one contact distance plus both maximum steps.
        search_radius = contact_distance + 2 * motion_magnitude

        for _ in range(N_motions):
            # Rebuild the tree from the *current* positions: a tree built from
            # the original centres would return stale neighbour lists from the
            # second pass on and let overlaps slip through.
            tree = cKDTree(updated_centers)
            potential_neighbors_batch = tree.query_ball_point(
                updated_centers, search_radius, workers=-1
            )

            n_accepted = 0
            for i in range(n_spheres):
                # Move the sphere
                new_center = updated_centers[i] + generate_random_vector(motion_magnitude)

                # Check for space boundaries
                if not in_cylinder(new_center, Rmax, Zmin, Zmax):
                    continue

                # Explicit integer dtype: an empty list would otherwise become
                # a float64 array, which cannot be used for indexing.
                neighbors_indices = np.array(potential_neighbors_batch[i], dtype=np.intp)
                if any_neighbor_in_range(new_center, updated_centers, neighbors_indices, contact_distance, i):
                    continue

                # Update the center if no overlap
                updated_centers[i] = new_center
                n_accepted += 1
            print(n_accepted, n_accepted / n_spheres)
        return updated_centers