Search code examples
Tags: java, c++, arrays, java-native-interface, copy

Why is this JNI program not copying float values back to Java side?


I have this code:

#if defined(NOT_STANDALONE)
JNIEXPORT void JNICALL sumTraces
  (JNIEnv* env, jclass caller, jobjectArray jprestackTraces, jint nTracesIn, jobjectArray jsampleShifts,
  jobjectArray jstartIndices, jobjectArray jnSamples, jobjectArray jstackTracesOut,
  jobjectArray jpowerTracesOut, jint nTracesOut, jint samplesPerTrace) {

  jboolean isCopy;

  float* prestackTraces1D = (float*)malloc(nTracesIn * samplesPerTrace * sizeof(float));
  if (prestackTraces1D == NULL) Fatal("Could not malloc prestackTraces1D");
  int* sampleShifts1D = (int*)malloc(nTracesIn * nTracesOut * sizeof(int));
  if (sampleShifts1D == NULL) Fatal("Could not malloc sampleShifts1D");
  int* startIndices1D = (int*)malloc(nTracesIn * nTracesOut * sizeof(int));
  if (startIndices1D == NULL) Fatal("Could not malloc startIndices1D");
  int* nSamples1D = (int*)malloc(nTracesIn * nTracesOut * sizeof(int));
  if (nSamples1D == NULL) Fatal("Could not malloc nSamples1D");

  for (int in = 0; in < nTracesIn; in++) {

    jfloatArray j_prestack = (jfloatArray)env->GetObjectArrayElement(jprestackTraces, in);
    float* prestackTracesJava = (float*)env->GetPrimitiveArrayCritical(j_prestack, &isCopy);

    for (int s = 0; s < samplesPerTrace; s++) {
      int readIndex = s + (in * samplesPerTrace);
      prestackTraces1D[readIndex] = prestackTracesJava[s];
    }

    env->ReleasePrimitiveArrayCritical(j_prestack, prestackTracesJava, JNI_ABORT);
  }

  for (int out = 0; out < nTracesOut; out++) {

    jintArray j_shift = (jintArray)env->GetObjectArrayElement(jsampleShifts, out);
    int* sampleShiftsJava = (int*)env->GetPrimitiveArrayCritical(j_shift, &isCopy);
    jintArray j_start = (jintArray)env->GetObjectArrayElement(jstartIndices, out);
    int* startIndicesJava = (int*)env->GetPrimitiveArrayCritical(j_start, &isCopy);
    jintArray j_nSamps = (jintArray)env->GetObjectArrayElement(jnSamples, out);
    int* nSamplesJava = (int*)env->GetPrimitiveArrayCritical(j_nSamps, &isCopy);

    for (int in = 0; in < nTracesIn; in++) {
      int readIndex = in + (out * nTracesIn);
      sampleShifts1D[readIndex] = sampleShiftsJava[in];
      startIndices1D[readIndex] = startIndicesJava[in];
      nSamples1D[readIndex] = nSamplesJava[in];
    }

    env->ReleasePrimitiveArrayCritical(j_nSamps, nSamplesJava, JNI_ABORT);
    env->ReleasePrimitiveArrayCritical(j_start, startIndicesJava, JNI_ABORT);
    env->ReleasePrimitiveArrayCritical(j_shift, sampleShiftsJava, JNI_ABORT);
  }

  float* stackTracesOut1D = (float*)malloc(nTracesOut * samplesPerTrace * sizeof(float));
  if (stackTracesOut1D == NULL) Fatal("Could not malloc stackTracesOut1D");
  float* powerTracesOut1D = (float*)malloc(nTracesOut * samplesPerTrace * sizeof(float));
  if (powerTracesOut1D == NULL) Fatal("Could not malloc powerTracesOut1D");

  // Run the OpenCL program
  ComputeTraces(prestackTraces1D, stackTracesOut1D, powerTracesOut1D,
    startIndices1D, nSamples1D, sampleShifts1D,
    samplesPerTrace, nTracesIn, nTracesOut,
    0, 0, 1000);

  // Free the arrays that we can
  free(nSamples1D);
  free(startIndices1D);
  free(sampleShifts1D);
  free(prestackTraces1D);

  // Copy back the output for Java
  for (int out = 0; out < nTracesOut; out++) {
    jfloatArray j_stackOut = (jfloatArray)env->GetObjectArrayElement(jstackTracesOut, out);
    jfloatArray j_powerOut = (jfloatArray)env->GetObjectArrayElement(jpowerTracesOut, out);

    float* stackOutCopyBack = (float*)malloc(samplesPerTrace * sizeof(float));
    float* powerOutCopyBack = (float*)malloc(samplesPerTrace * sizeof(float));
    for (int s = 0; s < samplesPerTrace; s++) {
      int readIndex = s + (out * samplesPerTrace);
      stackOutCopyBack[s] = stackTracesOut1D[readIndex];
      powerOutCopyBack[s] = powerTracesOut1D[readIndex];
    }

    for (int s = 0; s < samplesPerTrace; s++) {
      printf("%d    %f/%f\n", s, stackOutCopyBack[s], powerOutCopyBack[s]);
    }

    env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, 0);
    env->ReleasePrimitiveArrayCritical(j_powerOut, powerOutCopyBack, 0);

    free(stackOutCopyBack);
    free(powerOutCopyBack);
  }

  // Free the output arrays
  free(powerTracesOut1D);
  free(stackTracesOut1D);
}

The ComputeTraces(...) method fills the stackTracesOut1D and powerTracesOut1D arrays with values. I know these values are correct because I compared the output of the printf statement inside the for loop near the end with the values I expect, and they match. However, when I check on the Java side, all values are zero. Why is this JNI code not copying the data back?

Keep in mind that, as the code shows, I have to condense the 2D arrays I receive as parameters into 1D arrays in order to pass them into ComputeTraces. So before copying data back, I copy one trace's slice of the larger 1D array into a smaller array and pass that smaller array to ReleasePrimitiveArrayCritical. However, the values are not copied back.

Edit: To be clear, I am talking about the lines roughly 10 lines from the very end, such as env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, 0);, where I pass 0 as the mode.


Solution

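  • The problem was simply that I never called GetPrimitiveArrayCritical(...) on the output arrays. ReleasePrimitiveArrayCritical only copies data back through a pointer returned by the matching GetPrimitiveArrayCritical call. A buffer from malloc has no connection to the Java array, and passing it is undefined behavior. So: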

      // Original (broken) copy-back loop, shown for comparison with the fix below.
      for (int out = 0; out < nTracesOut; out++) {
        jfloatArray j_stackOut = (jfloatArray)env->GetObjectArrayElement(jstackTracesOut, out);
        jfloatArray j_powerOut = (jfloatArray)env->GetObjectArrayElement(jpowerTracesOut, out);
    
        // BUG: these are private C-heap buffers. They are not views of
        // j_stackOut / j_powerOut, so writing into them never touches Java memory.
        float* stackOutCopyBack = (float*)malloc(samplesPerTrace * sizeof(float));
        float* powerOutCopyBack = (float*)malloc(samplesPerTrace * sizeof(float));
    
        for (int s = 0; s < samplesPerTrace; s++) {
          int readIndex = s + (out * samplesPerTrace);
          stackOutCopyBack[s] = stackTracesOut1D[readIndex];
          powerOutCopyBack[s] = powerTracesOut1D[readIndex];
        }
    
        // BUG: ReleasePrimitiveArrayCritical may only be given the pointer that
        // GetPrimitiveArrayCritical returned for the same array. Passing a malloc'ed
        // buffer is undefined behavior, and no data is written back to Java.
        env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, 0);
        env->ReleasePrimitiveArrayCritical(j_powerOut, powerOutCopyBack, 0);
    
        free(stackOutCopyBack);
        free(powerOutCopyBack);
      }
    

    becomes:

      for (int out = 0; out < nTracesOut; out++) {
        jfloatArray j_stackOut = (jfloatArray)env->GetObjectArrayElement(jstackTracesOut, out);
        jfloatArray j_powerOut = (jfloatArray)env->GetObjectArrayElement(jpowerTracesOut, out);

        // Validate both rows before entering the critical region, because no other
        // JNI calls are allowed inside it. A null or short row would otherwise let
        // the copy loop below write past the end of a Java array.
        if (j_stackOut == NULL || j_powerOut == NULL ||
            env->GetArrayLength(j_stackOut) < samplesPerTrace ||
            env->GetArrayLength(j_powerOut) < samplesPerTrace) {
          if (!env->ExceptionCheck()) {
            jclass iae = env->FindClass("java/lang/IllegalArgumentException");
            if (iae != NULL) env->ThrowNew(iae, "output row is null or shorter than samplesPerTrace");
          }
          break;
        }

        // These pointers refer to the Java arrays' storage (or a JVM-managed copy
        // of it), so the writes below are published by the Release calls.
        float* stackOutCopyBack = (float*)env->GetPrimitiveArrayCritical(j_stackOut, &isCopy);
        if (stackOutCopyBack == NULL) break;  // could not obtain the array; nothing to release
        float* powerOutCopyBack = (float*)env->GetPrimitiveArrayCritical(j_powerOut, &isCopy);
        if (powerOutCopyBack == NULL) {
          env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, JNI_ABORT);
          break;
        }

        for (int s = 0; s < samplesPerTrace; s++) {
          int readIndex = s + (out * samplesPerTrace);
          stackOutCopyBack[s] = stackTracesOut1D[readIndex];
          powerOutCopyBack[s] = powerTracesOut1D[readIndex];
        }

        // Mode 0: copy back (if the JVM handed out a copy) and release.
        // Do not free() these pointers; the JVM owns them.
        env->ReleasePrimitiveArrayCritical(j_powerOut, powerOutCopyBack, 0);
        env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, 0);

        // Drop the per-row local references so long loops don't exhaust the table.
        env->DeleteLocalRef(j_powerOut);
        env->DeleteLocalRef(j_stackOut);
      }
    

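    It is also important to remove the free() calls. The pointers now come from GetPrimitiveArrayCritical and belong to the JVM, and ReleasePrimitiveArrayCritical is what releases them. Calling free() on them as well would be undefined behavior (effectively a double free).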