python numpy numpy-ndarray

Group elements in ndarray by index


I have a dataset of 1000 images, for which I have created embeddings. Each image has 512 embeddings, each a 256-d vector, stored as an ndarray of shape (512, 256), so the full array has shape (1000, 512, 256).

Now I want to group the observations by embedding index: take the first of the 512 embeddings from each of the 1000 images and collect them into one group. Then I want to do the same for the second embedding, the third, the fourth, and so on up to the 512th.

How would I go about creating these groups?


Solution

  • You can achieve that as follows:

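    import numpy as np

    groups = []

    for i in range(embeddings.shape[1]):  # 512 embeddings per image
        # Select the i-th embedding from every image: shape (1000, 256)
        group = embeddings[:, i, :]
        groups.append(group)

    # Stack the groups into one array of shape (512, 1000, 256)
    groups = np.array(groups)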
    

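    A more efficient alternative is to swap the first two axes directly. This gives the same (512, 1000, 256) result without a Python loop:

    # groups[i] holds the i-th embedding of every image
    groups = np.transpose(embeddings, (1, 0, 2))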
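
As a quick sanity check, the sketch below compares loop-based grouping with a direct axis swap on a small array. The shape (4, 3, 2), the random data, and the names groups_loop and groups_view are assumptions made for illustration. They stand in for the real (1000, 512, 256) embeddings.

```python
import numpy as np

# Hypothetical small stand-in for the real (1000, 512, 256) array:
# 4 images, 3 embeddings per image, 2 dimensions per embedding.
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(4, 3, 2))

# Loop-based grouping: one (n_images, dim) block per embedding index.
groups_loop = np.array(
    [embeddings[:, i, :] for i in range(embeddings.shape[1])]
)

# Axis swap: puts the embedding index first without a Python-level loop.
groups_view = np.transpose(embeddings, (1, 0, 2))

print(groups_loop.shape)                           # (3, 4, 2)
print(np.array_equal(groups_loop, groups_view))    # True
print(np.shares_memory(groups_view, embeddings))   # True
```

Because np.transpose returns a view here, writing into groups_view would also modify embeddings. If an independent array is needed, calling .copy() on the result is one option.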