
How to select a given set of indexes from an NDArray in ND4J, similarly to NumPy's arrayData[:, arrayIndex]?


I am developing a scientific application that relies heavily on array manipulation in Java using ND4J (currently version 1.0.0-beta5). Throughout my pipeline, I need to dynamically select a non-contiguous subset of a [2,195102] matrix (a few tens to hundreds of columns, to be more precise). Any idea how to achieve this in this framework?

In short, I am trying to achieve this Python/NumPy operation:

import numpy as np
arrayData = np.array([[1, 5, 0, 6, 2, 0, 9, 0, 5, 2],
       [3, 6, 1, 0, 4, 3, 1, 4, 8, 1]])
arrayIndex = np.array((1,5,6))
res  = arrayData[:, arrayIndex]
# res value is
# array([[5, 0, 9],
#        [6, 3, 1]])
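For reference, the NumPy expression above is a plain column gather: for every row, copy the entries at the requested column positions, in the requested order. A minimal sketch of that semantics in plain Java (no ND4J involved) over a row-major flat buffer; `ColumnGather` and `selectColumns` are hypothetical names used only for illustration, and unlike NumPy the helper does not accept negative indices.

```java
import java.util.Arrays;

// Hypothetical plain-Java helper (not part of ND4J) mirroring what
// NumPy's arrayData[:, arrayIndex] computes for a 2-D row-major array.
final class ColumnGather {

    // data:   row-major buffer of shape [rows, cols]
    // colIdx: column indices to keep, in the order they should appear
    // returns a new row-major buffer of shape [rows, colIdx.length]
    static int[] selectColumns(int[] data, int rows, int cols, int[] colIdx) {
        if (data.length != rows * cols) {
            throw new IllegalArgumentException("buffer length does not match shape");
        }
        int[] out = new int[rows * colIdx.length];
        for (int r = 0; r < rows; r++) {
            int srcRow = r * cols;          // offset of row r in the source
            int dstRow = r * colIdx.length; // offset of row r in the result
            for (int j = 0; j < colIdx.length; j++) {
                int c = colIdx[j];
                if (c < 0 || c >= cols) {
                    throw new IndexOutOfBoundsException("column " + c);
                }
                out[dstRow + j] = data[srcRow + c];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[] arrayData = {
            1, 5, 0, 6, 2, 0, 9, 0, 5, 2,
            3, 6, 1, 0, 4, 3, 1, 4, 8, 1
        };
        int[] res = selectColumns(arrayData, 2, 10, new int[]{1, 5, 6});
        // prints [5, 0, 9, 6, 3, 1], i.e. [[5, 0, 9], [6, 3, 1]] read row by row
        System.out.println(Arrays.toString(res));
    }
}
```

The whole operation is a single pass over the selected positions, which is why its cost should scale with the number of selected columns rather than with the 195102-column width of the source.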

So far, I have managed to select the desired columns using the NDArray.getColumns function (along with NDArray.data().asInt() on the indexArray to provide the index values). The problem is that the documentation explicitly states, regarding the retrieval of data during a computation, "Note that THIS SHOULD NOT BE USED FOR SPEED" (see the documentation of NDArray.toIntMatrix() for the full message: a different method, but the same operation).

I had a look at the different signatures of NDArray.get() and none seem to fit the bill. I suppose that NDArray.getWhere() might work, if it only returns the elements that fulfil the condition as I assume, but so far I have been unsuccessful in using it. The documentation is relatively light when it comes to explaining the required arguments and usage.

Thank you all for your time and help :)

EDIT (04/11/2019): Some clarification regarding what I have tried. I played around with NDArray.get() and made use of the indexes:

INDArray arrayData = Nd4j.create(new int[]
                    {1, 5, 0, 6, 2, 0, 9, 0, 5, 2,
                     3, 6, 1, 0, 4, 3, 1, 4, 8, 1},   new long[]{2, 10}, DataType.INT);
INDArray arrayIndex = Nd4j.create(new int[]{1, 5, 6}, new long[]{1,  3}, DataType.INT);

INDArray colSelection = null;

//index free version
colSelection = arrayData.getColumns(arrayIndex.toIntVector());
/*
* colSelection value is
* [[5, 0, 9],
*  [6, 3, 1]]
* but the toIntVector() call pulls the data from the back-end storage
* and re-inject them. That is presumed to be slow.
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 4001 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 5339 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 7016 ms for 100000 iterations
*/

//index version
colSelection = arrayData.get(NDArrayIndex.all(), NDArrayIndex.indices(arrayIndex.toLongVector()));
/*
* Same result, but same problem regarding toLongVector() this time around.
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 3200 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 4269 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 5252 ms for 100000 iterations
*/

//weird but functional version (that I just discovered)
colSelection = arrayData.transpose().get(arrayIndex); // the transpose operation is necessary to not hit an IllegalArgumentException: Illegal slice 5
// note that transposing the arrayIndex leads to an IllegalArgumentException: Illegal slice 6 (as it is trying to select the element at the line idx 1, column 5, depth 6, which does not exist)
/*
* colSelection value is
* [5, 6, 0, 3, 9, 1]
* The array is flattened... calling a reshape(arrayData.shape()[0],arrayIndex.shape()[1]) yields
* [[5, 6, 0],
*  [3, 9, 1]]
* which is wrong.
*/
colSelection = colSelection.reshape(arrayIndex.shape()[1],arrayData.shape()[0]).transpose();
/* yields the right result
* [[5, 0, 9],
*  [6, 3, 1]]
* While this seems to be the correct way to handle the memory, the performance is low:
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 8225 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 8980 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 9453 ms for 100000 iterations
* Plus, this is a very roundabout method for such a "simple" operation.
* If the repacking of the data is commented out, the timings become:
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 6987 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 7976 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 8336 ms for 100000 iterations
*/
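To see why reshape(2, 3) interleaves the columns while reshape(3, 2).transpose() restores them, here is the same index arithmetic in plain Java, with no ND4J involved. It assumes, as the comments above report, that the flat result lists the selected columns one after another, and that reshape fills values in row-major (C) order; `ReshapeOrder`, `reshapeRowMajor` and `transpose` are hypothetical helpers written only for this illustration.

```java
import java.util.Arrays;

// Plain-Java illustration (no ND4J) of the index arithmetic behind
// reshape(2, 3) versus reshape(3, 2).transpose() on the flat result.
final class ReshapeOrder {

    // Interpret a flat buffer as a row-major [rows, cols] matrix.
    static int[][] reshapeRowMajor(int[] flat, int rows, int cols) {
        if (flat.length != rows * cols) {
            throw new IllegalArgumentException("size mismatch");
        }
        int[][] m = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                m[i][j] = flat[i * cols + j];
            }
        }
        return m;
    }

    // Swap rows and columns of a rectangular matrix.
    static int[][] transpose(int[][] m) {
        int[][] t = new int[m[0].length][m.length];
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[0].length; j++) {
                t[j][i] = m[i][j];
            }
        }
        return t;
    }

    public static void main(String[] args) {
        // Flat output reported for arrayData.transpose().get(arrayIndex):
        // column 1 = (5, 6), then column 5 = (0, 3), then column 6 = (9, 1).
        int[] flat = {5, 6, 0, 3, 9, 1};

        // Wrong: each row mixes values from different columns.
        System.out.println(Arrays.deepToString(reshapeRowMajor(flat, 2, 3)));
        // Right: one row per selected column, then transpose back.
        System.out.println(Arrays.deepToString(transpose(reshapeRowMajor(flat, 3, 2))));
    }
}
```

Running it prints `[[5, 6, 0], [3, 9, 1]]` followed by `[[5, 0, 9], [6, 3, 1]]`, matching the wrong and correct results described in the comments above.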

These speeds may seem acceptable without knowing what machine I am running on, but the equivalent Python code yields:

  • 2 columns selected (arrayIndex = {1, 5}), ==> 171 ms for 100000 iterations
  • 3 columns selected (arrayIndex = {1, 5, 6}), ==> 173 ms for 100000 iterations
  • 4 columns selected (arrayIndex = {1, 5, 6 ,2}), ==> 173 ms for 100000 iterations

These Java implementations are, at best, roughly 20 times slower than the Python/NumPy one.
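When comparing such numbers, it can help to know the cost floor of the raw gather itself on the JVM. A minimal timing sketch under stated assumptions: it uses plain Java arrays rather than ND4J, a random [2, 195102] buffer, a reused output array, and a warm-up pass so the JIT has compiled the loop before measuring. `GatherTiming` and `gather` are hypothetical names; this is not the harness that produced the numbers above, and it makes no claim about the timings you would observe.

```java
import java.util.Random;

// Hypothetical timing sketch (not the original harness): repeats a plain-array
// column gather on a [2, 195102] row-major buffer, in the same
// "N ms for 100000 iterations" style as the measurements above.
final class GatherTiming {

    // Copies the requested columns of a row-major [rows, cols] buffer into out,
    // which must have length rows * colIdx.length.
    static int[] gather(int[] data, int rows, int cols, int[] colIdx, int[] out) {
        for (int r = 0; r < rows; r++) {
            int src = r * cols;
            int dst = r * colIdx.length;
            for (int j = 0; j < colIdx.length; j++) {
                out[dst + j] = data[src + colIdx[j]];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int rows = 2, cols = 195102, iterations = 100_000;
        int[] data = new Random(42).ints(rows * cols, 0, 10).toArray();
        int[] colIdx = {1, 5, 6, 2};
        int[] out = new int[rows * colIdx.length]; // reused to avoid per-call allocation

        // Warm-up so the JIT compiles gather() before timing; sink keeps the work observable.
        long sink = 0;
        for (int i = 0; i < iterations; i++) {
            sink += gather(data, rows, cols, colIdx, out)[0];
        }

        long start = System.nanoTime();
        for (int i = 0; i < iterations; i++) {
            sink += gather(data, rows, cols, colIdx, out)[0];
        }
        long elapsedMs = (System.nanoTime() - start) / 1_000_000;

        System.out.println(colIdx.length + " columns selected ==> " + elapsedMs
                + " ms for " + iterations + " iterations (sink=" + sink + ")");
    }
}
```

A hand-rolled loop like this is only a rough guide; a dedicated benchmarking tool such as JMH gives more trustworthy JVM measurements.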


Solution

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.indexing.INDArrayIndex;
    import org.nd4j.linalg.indexing.NDArrayIndex;
    import org.nd4j.linalg.indexing.SpecifiedIndex;

    public class SelectColumns {
        public static void main(String[] args) {
            INDArray arr = Nd4j.create(new double[][]{
                    {1, 5, 0, 6, 2, 0, 9, 0, 5, 2},
                    {3, 6, 1, 0, 4, 3, 1, 4, 8, 1}
            });

            // all rows, and only columns 1, 5 and 6 (in that order)
            INDArrayIndex[] indices = {
                    NDArrayIndex.all(),
                    new SpecifiedIndex(1, 5, 6)
            };

            INDArray selected = arr.get(indices);
            System.out.println(selected);
        }
    }
    

    This should work for you. This prints:

    SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
    SLF4J: Defaulting to no-operation (NOP) logger implementation
    SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.

    [[    5.0000,         0,    9.0000], 
     [    6.0000,    3.0000,    1.0000]]
    

    Process finished with exit code 0