Tags: numpy, pytorch, numpy-ndarray, numpy-slicing

Take multiple slices in numpy/pytorch


I have a big one-dimensional array X with X.shape = (10000,), and a vector of indices y = [0, 7, 9995].

I would like to get a matrix with rows

[
 X[0 : 100],
 X[7 : 107],
 concat(X[9995:], X[:95]),
]

That is, slices of length 100, starting at each index, with wrap-around.

I can do that with a Python loop, but I'm wondering if there's a smarter, batched way of doing it in PyTorch or NumPy, since my arrays can be quite large.
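For reference, the loop version could look roughly like the sketch below. The function name `wrapped_slices_loop`, the variable `width`, and the placeholder contents of X are assumptions made for illustration; it also assumes every start index lies in [0, len(X)) and that the slice length does not exceed len(X).

```python
import numpy as np

# Assumed example data: the values in X are placeholders.
X = np.arange(10000)
y = [0, 7, 9995]
width = 100  # slice length from the question


def wrapped_slices_loop(X, starts, width):
    """Build one row per start index, wrapping past the end of X.

    Assumes 0 <= start < len(X) and width <= len(X).
    """
    n = X.shape[0]
    rows = []
    for s in starts:
        end = s + width
        if end <= n:
            # The slice fits without wrapping.
            rows.append(X[s:end])
        else:
            # Tail of X followed by its head, as in concat(X[9995:], X[:95]).
            rows.append(np.concatenate([X[s:], X[:end - n]]))
    return np.stack(rows)


M_loop = wrapped_slices_loop(X, y, width)  # shape (len(y), width)
```

This does one Python-level iteration per start index, which is the part a batched approach would avoid.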


Solution

  • Quite simple, actually.

    1. For each element E in y, create a range from E up to (but not including) E + 100
    2. Concatenate all the ranges horizontally
    3. Take the resulting array modulo the length of X, so that indices past the end wrap around to the start

    import numpy as np
    indexes = np.hstack([np.arange(v, v + 100) for v in y]) % X.shape[0]
    

    Output:

    >>> indexes
    array([   0,    1,    2,    3,    4,    5,    6,    7,    8,    9,   10,
             11,   12,   13,   14,   15,   16,   17,   18,   19,   20,   21,
             22,   23,   24,   25,   26,   27,   28,   29,   30,   31,   32,
             33,   34,   35,   36,   37,   38,   39,   40,   41,   42,   43,
             44,   45,   46,   47,   48,   49,   50,   51,   52,   53,   54,
             55,   56,   57,   58,   59,   60,   61,   62,   63,   64,   65,
             66,   67,   68,   69,   70,   71,   72,   73,   74,   75,   76,
             77,   78,   79,   80,   81,   82,   83,   84,   85,   86,   87,
             88,   89,   90,   91,   92,   93,   94,   95,   96,   97,   98,
             99,    7,    8,    9,   10,   11,   12,   13,   14,   15,   16,
             17,   18,   19,   20,   21,   22,   23,   24,   25,   26,   27,
             28,   29,   30,   31,   32,   33,   34,   35,   36,   37,   38,
             39,   40,   41,   42,   43,   44,   45,   46,   47,   48,   49,
             50,   51,   52,   53,   54,   55,   56,   57,   58,   59,   60,
             61,   62,   63,   64,   65,   66,   67,   68,   69,   70,   71,
             72,   73,   74,   75,   76,   77,   78,   79,   80,   81,   82,
             83,   84,   85,   86,   87,   88,   89,   90,   91,   92,   93,
             94,   95,   96,   97,   98,   99,  100,  101,  102,  103,  104,
            105,  106, 9995, 9996, 9997, 9998, 9999,    0,    1,    2,    3,
              4,    5,    6,    7,    8,    9,   10,   11,   12,   13,   14,
             15,   16,   17,   18,   19,   20,   21,   22,   23,   24,   25,
             26,   27,   28,   29,   30,   31,   32,   33,   34,   35,   36,
             37,   38,   39,   40,   41,   42,   43,   44,   45,   46,   47,
             48,   49,   50,   51,   52,   53,   54,   55,   56,   57,   58,
             59,   60,   61,   62,   63,   64,   65,   66,   67,   68,   69,
             70,   71,   72,   73,   74,   75,   76,   77,   78,   79,   80,
             81,   82,   83,   84,   85,   86,   87,   88,   89,   90,   91,
             92,   93,   94])
    

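    Now just index X with that. The result is flat (300 elements here), so reshape it to get one row per start index:

    X[indexes].reshape(len(y), 100)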
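A loop-free variant of the same idea is to let broadcasting build all the ranges at once, which yields the 2-D index array directly. This is a minimal sketch, not the only way to do it; `width`, `idx`, `M` and the placeholder contents of X are names chosen for illustration, and y is assumed to be (or be converted to) a NumPy integer array.

```python
import numpy as np

# Assumed example data: the values in X are placeholders.
X = np.arange(10000)
y = np.array([0, 7, 9995])
width = 100  # slice length from the question

# y[:, None] has shape (len(y), 1); np.arange(width) has shape (width,).
# Their sum broadcasts to (len(y), width): row i holds y[i], y[i] + 1, ...
# The modulo wraps any index past the end of X back to the start.
idx = (y[:, None] + np.arange(width)) % X.shape[0]

# Integer-array indexing returns an array with the same shape as idx.
M = X[idx]  # shape (len(y), width)
```

Here row i of M is the wrapped slice starting at y[i], so no reshape is needed. The same expression should carry over to PyTorch tensors by swapping np.arange for torch.arange, since tensors support the same broadcasting and integer indexing, though that variant is not exercised here.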