Tags: python, numpy, random, pytorch

Generate n random 2D points within a valid region


I want to generate a fixed number of uniformly distributed random (x, y) points within a valid mask. The generated points should be continuous, i.e. they may fall "between pixels". I'm looking for an efficient solution, since this will be part of a PyTorch training loop.

Say I need 100 valid points. My current solution looks like this:

  1. Generate 500 random points
  2. Sample the mask at each point and check whether the value is > 0.5 (i.e. within the valid region)
  3. Throw away all invalid points, then all surplus points, so that I end up with exactly 100

There must be a more efficient solution for this, right?
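For reference, the three steps above can be sketched as a rejection sampler that keeps drawing candidate batches until it has enough valid points, since a single batch of 500 is not guaranteed to contain 100 valid ones. This is a minimal sketch under a few assumptions: the mask is a PyTorch tensor of shape (H, W), the validity check is a nearest-pixel lookup (`mask > 0.5`) instead of the bilinear `grid_sample` check used below, and `rejection_sample` is a hypothetical helper name.

```python
from typing import Optional

import torch


def rejection_sample(mask: torch.Tensor, n: int, batch: int = 500,
                     generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Draw n points uniformly inside the valid area (mask > 0.5) of an (H, W) mask.

    Returns an (n, 2) tensor of continuous pixel coordinates (x, y), where
    pixel (row, col) covers [col, col + 1) x [row, row + 1).
    """
    H, W = mask.shape
    if not bool((mask > 0.5).any()):
        raise ValueError("mask has no valid pixels")
    scale = torch.tensor([W, H], dtype=torch.float32)
    kept, total = [], 0
    while total < n:
        # Step 1: uniform candidate points over the whole image
        xy = torch.rand(batch, 2, generator=generator) * scale
        # Step 2: nearest-pixel lookup (clamp guards against float rounding up to W or H)
        cols = xy[:, 0].long().clamp(max=W - 1)
        rows = xy[:, 1].long().clamp(max=H - 1)
        valid = mask[rows, cols] > 0.5
        # Step 3: keep the valid candidates; loop again if there are not enough yet
        kept.append(xy[valid])
        total += int(valid.sum())
    # Drop the surplus so that exactly n points are returned
    return torch.cat(kept)[:n]
```

Each loop iteration yields on average about `batch` × (valid fraction of the image) points, so the cost grows as the valid region shrinks.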

Here is some example code that demonstrates the problem:

import numpy as np
import cv2
import torch
import torch.nn.functional as F

# Create example mask
valid_mask = np.ones((300, 400), dtype=np.uint8) * 255
center = (valid_mask.shape[1] // 2, valid_mask.shape[0] // 2)
angle = 30
rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
rotated_image = cv2.warpAffine(valid_mask, rotation_matrix, (valid_mask.shape[1], valid_mask.shape[0]), flags=cv2.INTER_NEAREST)

# display image
cv2.imshow('mask', rotated_image)
cv2.waitKey(0)
cv2.destroyAllWindows()

# generate 500 random points
num_points = 500
points = np.random.rand(num_points, 2) * 2 - 1  # Scale to [-1, 1]

# Convert image and points to tensors
rotated_image_tensor = torch.tensor(rotated_image, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255  # Shape (1, 1, H, W)
points_tensor = torch.tensor(points, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # Shape (1, 1, N, 2)

# Use grid_sample to sample from the image at each point
sampled_values = F.grid_sample(rotated_image_tensor, points_tensor, mode='bilinear', padding_mode='zeros', align_corners=False)

# Squeeze to remove unnecessary dimensions and get actual values
sampled_values = sampled_values.squeeze().numpy()

# Keep only points where the sampled value is greater than 0.5
valid_points = points[sampled_values > 0.5]
print(len(valid_points))

# Display valid points
rotated_image = cv2.cvtColor(rotated_image, cv2.COLOR_GRAY2BGR)
for point in valid_points:
    x, y = point
    x = int((x + 1) * rotated_image.shape[1] / 2)
    y = int((y + 1) * rotated_image.shape[0] / 2)
    cv2.circle(rotated_image, (x, y), 1, (0, 255, 0), -1)

cv2.imshow('mask', rotated_image)
cv2.waitKey(0)
cv2.destroyAllWindows()

Any help would be appreciated.


Solution

  • You can directly treat your mask as the weights for a multinomial distribution and sample from it. Here is a minimal example:

    import numpy as np
    import matplotlib.pyplot as plt
    import torch
    
    generator = torch.Generator()
    generator.manual_seed(98237)
    
    # Create example mask
    mask = torch.zeros((256, 256))
    mask[32:187, 53:123] = 1.
    mask[150:220, 89:198] = 1.
    
    n_samples = 10_000
    
    # define the probability density function and sample
    pdf = mask / mask.sum()
    
    coords = torch.multinomial(pdf.flatten(), n_samples, replacement=True, generator=generator)
    
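    # unravel_index returns (row, col) pixel indices
    rows, cols = torch.unravel_index(coords, pdf.shape)
    # Add uniform noise in [-0.5, 0.5) so each sample lands anywhere inside its pixel
    rows_dithered = rows + torch.rand(rows.shape, generator=generator) - 0.5
    cols_dithered = cols + torch.rand(cols.shape, generator=generator) - 0.5
    
    # Create a 2D histogram of the samples (one bin per pixel, centered on integer coordinates)
    hist = np.histogram2d(rows_dithered.numpy(), cols_dithered.numpy(), bins=(np.arange(-0.5, 256), np.arange(-0.5, 256)))[0]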
    
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(mask, origin="lower", cmap="gray")
    axes[0].set_title("Mask")
    
    axes[1].imshow(hist, origin="lower", cmap="viridis")
    axes[1].set_title("Samples")
    plt.show()
    

    Which gives:

    [Figure: left panel "Mask", right panel "Samples" (2D histogram of the sampled points)]

    The trick here is to flatten the mask into a one-dimensional probability mass function (PMF) and sample integer indices from it. Those indices are then converted back to two-dimensional (row, column) indices by unravelling, and uniformly distributed noise is added to "smear" each position over the area of a single pixel.

    This approach generalizes to arbitrary PMFs (any non-negative weight map), not only binary masks.
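As an illustration of that generalization, `torch.multinomial` also accepts a 2D input in which each row is a separate distribution, so a whole batch of weight maps can be sampled in one call, which may suit a training loop. The sketch below assumes a (B, H, W) tensor of non-negative weights where every map has a positive sum; `sample_points_batched` is a hypothetical name. A binary mask gives uniform sampling over its valid area, while soft weights place proportionally more points where the weights are larger.

```python
import torch


def sample_points_batched(weights, n, generator=None, dtype=torch.float32):
    """Sample n points per map from a batch of non-negative weight maps.

    weights: (B, H, W) tensor; each map must have a positive sum.
    Returns a (B, n, 2) tensor of continuous pixel coordinates (x, y).
    """
    B, H, W = weights.shape
    # One distribution per row: (B, H * W) weights -> (B, n) flat pixel indices
    idx = torch.multinomial(weights.reshape(B, H * W).float(), n,
                            replacement=True, generator=generator)
    rows, cols = idx // W, idx % W
    # Dither uniformly within each selected pixel
    x = cols.to(dtype) + torch.rand(B, n, generator=generator, dtype=dtype)
    y = rows.to(dtype) + torch.rand(B, n, generator=generator, dtype=dtype)
    return torch.stack([x, y], dim=-1)
```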

    I hope this helps!