
Multi-dimensional indexing of numpy arrays along inner axis


  • I have a numpy array x with shape [4, 5, 3]
  • I have a 2D array of indices i with shape [4, 3], referring to indices along dimension 1 (of length 5) in x
  • I'd like to extract a sub-array y from x, with shape [4, 3], such that y[j, k] == x[j, i[j, k], k]
  • How do I do this?
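To pin down exactly what y should contain, here is a slow but unambiguous reference version written with explicit loops. The function name gather_inner_loop is hypothetical and only for illustration. A vectorized answer should reproduce its output.

```python
import numpy as np

def gather_inner_loop(x, i):
    """Reference version (hypothetical helper): y[j, k] = x[j, i[j, k], k] via explicit loops."""
    n0, _, n2 = x.shape
    y = np.empty((n0, n2), dtype=x.dtype)
    for j in range(n0):
        for k in range(n2):
            # pick position i[j, k] along axis 1 for this (j, k) pair
            y[j, k] = x[j, i[j, k], k]
    return y

x = np.arange(4 * 5 * 3).reshape(4, 5, 3)
i = np.zeros((4, 3), dtype=int)  # select index 0 along axis 1 everywhere
print(gather_inner_loop(x, i))
```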

Solution

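  • Index all three axes with integer arrays that broadcast together to the shape of i. np.arange(4).reshape(4, 1) has shape [4, 1], i has shape [4, 3] and np.arange(3).reshape(1, 3) has shape [1, 3], so the three index arrays broadcast to [4, 3] and element [j, k] of the result is x[j, i[j, k], k]: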

    y = x[np.arange(4).reshape(4, 1), i, np.arange(3).reshape(1, 3)]
    

    Example:

    import numpy as np
    
    rng = np.random.default_rng(0)
    
    x = np.arange(4 * 5 * 3)
    rng.shuffle(x)
    x = x.reshape(4, 5, 3)
    i = rng.integers(5, size=[4, 3])
    
    y = x[np.arange(4).reshape(4, 1), i, np.arange(3).reshape(1, 3)]
    
    print("x:", x, "i:", i, "y:", y, sep="\n")
    

    Output:

    x:
    [[[16 27 20]
      [ 8 42 34]
      [51  4 52]
      [57 10  2]
      [44 23 24]]
    
     [[43 11 35]
      [30 18 54]
      [ 3  1 55]
      [17 21 36]
      [ 0 28  6]]
    
     [[19 48 22]
      [26 37 46]
      [58 32 25]
      [53  9 38]
      [47 50 40]]
    
     [[13 12  7]
      [45 39 59]
      [ 5 49 14]
      [29 41 56]
      [33 15 31]]]
    i:
    [[1 3 4]
     [0 0 3]
     [1 2 0]
     [4 2 4]]
    y:
    [[ 8 10 24]
     [43 11 36]
     [26 32 22]
     [33 49 31]]
    

    (Rubber-duck debugging FTW)
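A slightly more general sketch follows. It assumes x is 3-D and i is an integer array of shape (x.shape[0], x.shape[2]). Instead of hard-coding 4 and 3, it derives the broadcastable helper index arrays from i's own shape using np.ogrid. It then cross-checks the result against np.take_along_axis, which requires the index array to have the same number of dimensions as x. The function names gather_inner and gather_inner_take are hypothetical.

```python
import numpy as np

def gather_inner(x, i):
    """Return y with y[j, k] == x[j, i[j, k], k] (hypothetical helper name)."""
    n0, n2 = i.shape
    # np.ogrid builds open grids of shapes (n0, 1) and (1, n2), i.e. the same
    # arrays as np.arange(n0).reshape(n0, 1) and np.arange(n2).reshape(1, n2)
    j, k = np.ogrid[:n0, :n2]
    return x[j, i, k]

def gather_inner_take(x, i):
    """Same result via np.take_along_axis: add a length-1 axis 1 to i, gather, then drop it."""
    return np.take_along_axis(x, i[:, np.newaxis, :], axis=1)[:, 0, :]

rng = np.random.default_rng(0)
x = rng.integers(100, size=(4, 5, 3))
i = rng.integers(5, size=(4, 3))

y1 = gather_inner(x, i)
y2 = gather_inner_take(x, i)
print(y1.shape, np.array_equal(y1, y2))  # prints: (4, 3) True
```

Both functions perform the same advanced-indexing gather. np.ogrid only saves writing the reshaped aranges by hand, while take_along_axis hides them entirely but needs the extra length-1 axis on i.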