
python, dimension subset of ndimage using indices stored in another image


I have two images with the following dimensions, x, y, z:

img_a: 50, 50, 100

img_b: 50, 50

I'd like to reduce the z dimension of img_a from 100 to 1 by taking, pixel by pixel, only the value at the z-index stored in img_b, since the indices vary across the image.

This should result in a third image with the dimension:

img_c: 50, 50

Is there already a function dealing with this issue?

Thanks, Peter
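To pin down the requested operation, here is a minimal sketch with tiny stand-in arrays (the 2 x 2 x 3 shape and the values are chosen here for illustration only; the question's arrays are 50 x 50 x 100 and 50 x 50). Each output pixel keeps the single z-slice that img_b points at:

```python
import numpy as np

# Tiny stand-ins for the question's arrays: img_a is (x, y, z), img_b is (x, y).
img_a = np.arange(2 * 2 * 3).reshape(2, 2, 3)
img_b = np.array([[0, 2],
                  [1, 0]])  # one z-index per pixel

img_c = np.empty(img_b.shape, dtype=img_a.dtype)
for i in range(img_b.shape[0]):
    for j in range(img_b.shape[1]):
        # keep only the z-value that img_b selects for pixel (i, j)
        img_c[i, j] = img_a[i, j, img_b[i, j]]

# img_c -> [[0, 5], [7, 9]]
```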


Solution

  • OK, updated with a vectorized method.

    There is a duplicate question, but its solution currently doesn't work when the number of rows and columns differ.
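A minimal sketch of why the meshgrid pattern used in the code below breaks for unequal sizes, assuming the duplicate's solution is that same pattern. With NumPy's default 'xy' indexing, meshgrid returns the grids with rows and columns swapped, so they no longer broadcast against b. Passing indexing='ij' (a standard numpy.meshgrid parameter) keeps the (row, column) order. The toy shapes are hypothetical:

```python
import numpy as np

rows, cols, channels = 3, 4, 5  # deliberately unequal rows and columns
a = np.random.randint(1, 1000, (rows, cols, channels))
b = np.random.randint(0, channels, (rows, cols))

# Pattern from the linked answer: default 'xy' indexing gives
# i1 shape (1, rows) and i0 shape (cols, 1), which cannot broadcast with b.
i1, i0 = np.meshgrid(range(rows), range(cols), sparse=True)
try:
    a[i0, i1, b]
    xy_failed = False
except IndexError:  # "shape mismatch: indexing arrays could not be broadcast"
    xy_failed = True

# 'ij' indexing keeps axis order: i0 is (rows, 1), i1 is (1, cols).
i0, i1 = np.meshgrid(range(rows), range(cols), indexing='ij', sparse=True)
c = a[i0, i1, b]  # broadcasts to (rows, cols)
```

Swapping the arguments instead, i.e. `i1, i0 = np.meshgrid(range(cols), range(rows), sparse=True)`, gives the same (rows, 1) and (1, cols) grids with the default 'xy' indexing.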

    The code below includes the method I added, which explicitly builds the lookup indices with numpy.indices() and then applies the loop logic in a vectorized way. It's roughly 2x slower than the numpy.meshgrid() method, but I think it's easier to understand, and it also works when the row and column sizes differ.
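To make "explicitly creates the indices" concrete, a small sketch of what numpy.indices returns and how its grids stand in for the two loop counters (the 2 x 3 x 4 toy array and the index values are chosen here for illustration):

```python
import numpy as np

rows, cols, channels = 2, 3, 4  # unequal rows and columns on purpose
a = np.arange(rows * cols * channels).reshape(rows, cols, channels)
b = np.array([[0, 3, 1],
              [2, 2, 0]])  # channel index per pixel

# np.indices stacks one integer grid per axis: shape (2, rows, cols).
# r[i, j] == i and k[i, j] == j, i.e. the loop counters as arrays.
r, k = np.indices(b.shape)
# r -> [[0, 0, 0], [1, 1, 1]]
# k -> [[0, 1, 2], [0, 1, 2]]

# Vectorized form of c[i, j] = a[i, j, b[i, j]]; b already has shape
# (rows, cols), so it can be passed directly as the third index.
c = a[r, k, b]
# c -> [[0, 7, 9], [14, 18, 20]]
```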

    Timings are approximate, but on my system (5000 x 5000 arrays with 3 channels) I get, in seconds:

    Meshgrid time: 0.319000005722
    Indices time: 0.704999923706
    Loops time: 13.3789999485
    


    import numpy as np
    import time
    
    
    x_dim = 5000
    y_dim = 5000
    channels = 3
    
    # base data
    a = np.random.randint(1, 1000, (x_dim, y_dim, channels))
    b = np.random.randint(0, channels, (x_dim, y_dim))
    
    
    # meshgrid method (from here https://stackoverflow.com/a/27281566/377366 )
    start_time = time.time()
    # note: as written this only works when x_dim == y_dim
    i1, i0 = np.meshgrid(range(x_dim), range(y_dim), sparse=True)
    c_by_meshgrid = a[i0, i1, b]
    print('Meshgrid time: {}'.format(time.time() - start_time))
    
    # indices method (this is the vectorized method that does what you want)
    start_time = time.time()
    b_indices = np.indices(b.shape)
    # row grid, column grid, then the per-pixel channel index (b itself)
    c_by_indices = a[b_indices[0], b_indices[1], b]
    print('Indices time: {}'.format(time.time() - start_time))
    
    # loops method
    start_time = time.time()
    c_by_loops = np.zeros((x_dim, y_dim), np.intp)
    for i in range(x_dim):
        for j in range(y_dim):
            c_by_loops[i, j] = a[i, j, b[i, j]]
    print('Loops time: {}'.format(time.time() - start_time))
    
    
    # confirm correctness
    print('Meshgrid method matches loops: {}'.format(np.all(c_by_meshgrid == c_by_loops)))
    print('Indices method matches loops: {}'.format(np.all(c_by_indices == c_by_loops)))
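On the question's "is there already a function": in NumPy 1.15 and later, np.take_along_axis does this lookup directly. This is a swapped-in alternative, not part of the answer above, shown as a minimal sketch with hypothetical sizes (rows and columns deliberately unequal):

```python
import numpy as np

x_dim, y_dim, channels = 50, 60, 100  # hypothetical sizes
a = np.random.randint(1, 1000, (x_dim, y_dim, channels))
b = np.random.randint(0, channels, (x_dim, y_dim))

# take_along_axis needs the index array to have the same ndim as `a`,
# so add a trailing length-1 axis, pick along axis 2, then drop that axis.
c = np.take_along_axis(a, b[..., np.newaxis], axis=2)[..., 0]
```

Because the index array only has to broadcast along the non-selected axes, this works for any row and column sizes without building index grids by hand.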