java scala deeplearning4j nd4j

ND4J arrays & their shapes: getting data into a list


Consider the following code, which uses the ND4J library to create a simpler version of the "moons" test data set:

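import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.impl.transforms.{Cos, Sin} // package path differs between ND4J versions
import org.nd4j.linalg.factory.Nd4j

val n = 100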
val n1: Int = n/2
val n2: Int = n-n1
val outerX = Nd4j.getExecutioner.execAndReturn(new Cos(Nd4j.linspace(0, Math.PI, n1)))
val outerY = Nd4j.getExecutioner.execAndReturn(new Sin(Nd4j.linspace(0, Math.PI, n1)))
val innerX = Nd4j.getExecutioner.execAndReturn(new Cos(Nd4j.linspace(0, Math.PI, n2))).mul(-1).add(1)
val innerY = Nd4j.getExecutioner.execAndReturn(new Sin(Nd4j.linspace(0, Math.PI, n2))).mul(-1).add(1)
val X: INDArray = Nd4j.vstack(
  Nd4j.concat(1, outerX, innerX), // 1 x n
  Nd4j.concat(1, outerY, innerY)  // 1 x n
) // 2 x n
val y: INDArray = Nd4j.hstack(
  Nd4j.zeros(n1), // 1 x n1
  Nd4j.ones(n2)   // 1 x n2
) // 1 x n
println(s"# y shape: ${y.shape().toList}")                        // 1x100
println(s"# y data length: ${y.data().length()}")                 // 100
println(s"# X shape: ${X.shape().toList}")                        // 2x100
println(s"# X row 0 shape: ${X.getRow(0).shape().toList}")        // 1x100
println(s"# X row 1 shape: ${X.getRow(1).shape().toList}")        // 1x100
println(s"# X row 0 data length: ${X.getRow(0).data().length()}") // 200    <- !
println(s"# X row 1 data length: ${X.getRow(1).data().length()}") // 100

On the second-to-last line, X.getRow(0).data().length() is, surprisingly, 200 rather than 100. On inspection, this is because the structure returned by data() contains the entire matrix, i.e. both rows concatenated.
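
One way to picture this is with a toy model in plain Scala. This is only a sketch: RowView and its members are hypothetical names invented for illustration, not ND4J classes, and ND4J's real buffers handle offsets in more detail than this. The point it shows is that a row can be a window onto a shared flat buffer, so asking for the buffer returns more than the row.

```scala
// Toy model only: RowView is a hypothetical class, not part of ND4J.
final case class RowView(buffer: Array[Double], offset: Int, length: Int) {
  // Analogous to data(): returns the shared backing buffer, ignoring the view's shape.
  def data: Array[Double] = buffer

  // Copies out only the elements that this view actually covers.
  def toList: List[Double] = buffer.slice(offset, offset + length).toList
}

object ViewDemo extends App {
  val cols = 3
  // A 2 x 3 matrix stored row by row in one flat array.
  val backing = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)

  val row0 = RowView(backing, 0 * cols, cols)

  println(row0.data.length) // 6: the whole matrix, not just the row
  println(row0.toList)      // List(1.0, 2.0, 3.0)
}
```

In this model the row's shape (1 x 3) and the length of its backing data (6) disagree in the same way as the 1x100 shape and 200-element data above.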

How do I get just the actual first row of the X matrix into a Java (or Scala) List? I could take just the first 100 items of the 200-element "first row", but that doesn't seem very elegant.


Solution

  • .data() gives you the flat underlying buffer, not the row itself. See: http://nd4j.org/tensor

    The shape of an array is just a view over the underlying data buffer. I typically don't recommend doing what you're trying to do without good reason: all of the data is stored off heap, so copying it into an on-heap list is expensive.

    Keeping the data on heap is bad for any kind of math. The only real use case for this is integration with other code. I would suggest operating on the arrays directly as much as possible; everything from serialization to indexing is handled for you.

    If you really need it for an integration of some kind, use Guava and you can do it in one line: Doubles.asList(arr.dup().data().asDouble());

    where arr is your ndarray to operate on.
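
As a sketch of how this might look in Scala for the first row of a matrix like X in the question: the example below duplicates the row view before reading its buffer, on the assumption that dup() on a view allocates a fresh buffer holding only that view's elements. The 2 x 3 matrix is a small stand-in for the 2 x 100 one, and the object name RowToList is made up for the example. It needs ND4J and Guava on the classpath.

```scala
import com.google.common.primitives.Doubles
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j

object RowToList extends App {
  // Small stand-in for the X matrix in the question (2 x 3 instead of 2 x 100).
  val X: INDArray = Nd4j.create(Array(Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0)))

  // Assumption: dup() on the row view copies just that row into its own buffer,
  // so data() no longer exposes the rest of the matrix.
  val row0: INDArray = X.getRow(0).dup()

  // asDouble() copies the off-heap values into a double[];
  // Doubles.asList wraps that array as a java.util.List[java.lang.Double].
  // The ": _*" is needed because Doubles.asList is a Java varargs method.
  val asList: java.util.List[java.lang.Double] =
    Doubles.asList(row0.data().asDouble(): _*)

  println(asList.size())
}
```

Each call copies the row out of off-heap memory, so this is best kept for handing data to other libraries rather than used inside numeric code.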