I am interested in training a neural network using JAX. I had a look at tf.data.Dataset, but it yields only TF tensors. I looked for a way to turn the dataset into JAX NumPy arrays and found many implementations that use Dataset.as_numpy_iterator() to convert the TF tensors to NumPy arrays. However, I wonder whether this is good practice, since NumPy arrays live in CPU memory, which is not what I want for training (I use the GPU). The last idea I found is to manually recast the arrays by calling jnp.array, but it is not really elegant (I am worried about the copy into GPU memory). Does anyone have a better idea?
Quick code to illustrate:
import os
import jax.numpy as jnp
import tensorflow as tf
def generator():
    for _ in range(2):
        yield tf.random.uniform((1, ))
ds = tf.data.Dataset.from_generator(generator, output_types=tf.float32)
ds1 = ds.take(1).as_numpy_iterator()  # yields NumPy arrays
ds2 = ds.skip(1)  # yields TF tensors
for i, batch in enumerate(ds1):
    print(type(batch))
for i, batch in enumerate(ds2):
    print(type(jnp.array(batch)))  # manual recast to a JAX array
# returns:
# <class 'numpy.ndarray'> # not good
# <class 'jaxlib.xla_extension.DeviceArray'> # good but not elegant
Both TensorFlow and JAX can convert arrays to DLPack tensors without copying memory, so one way to create a JAX array from a TensorFlow tensor without copying the underlying data buffer is to go through DLPack:
import numpy as np
import tensorflow as tf
import jax.dlpack
tf_arr = tf.random.uniform((10,))
dl_arr = tf.experimental.dlpack.to_dlpack(tf_arr)
jax_arr = jax.dlpack.from_dlpack(dl_arr)
np.testing.assert_array_equal(tf_arr, jax_arr)
By doing a round trip back to JAX, you can compare the unsafe_buffer_pointer() values to check that the arrays point at the same buffer, rather than the buffer being copied along the way:
import jax.numpy as jnp
def tf_to_jax(arr):
    return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(arr))
def jax_to_tf(arr):
    return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(arr))
jax_arr = jnp.arange(20.)
tf_arr = jax_to_tf(jax_arr)
jax_arr2 = tf_to_jax(tf_arr)
print(jnp.all(jax_arr == jax_arr2))
# True
print(jax_arr.unsafe_buffer_pointer() == jax_arr2.unsafe_buffer_pointer())
# True
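For the simpler case where batches already arrive as NumPy arrays (for example from as_numpy_iterator()), here is a minimal sketch of moving each batch onto JAX's default device with jax.device_put. Unlike the DLPack route above, this copies from host memory to the device, so it is an alternative rather than a zero-copy solution. The host_batches generator and the device_batches wrapper are hypothetical names introduced for illustration; host_batches stands in for the tf.data iterator so that the example runs without TensorFlow.

```python
import jax
import numpy as np


def host_batches(num_batches=2, batch_size=4):
    """Hypothetical stand-in for ds.as_numpy_iterator(): yields NumPy arrays."""
    rng = np.random.default_rng(0)
    for _ in range(num_batches):
        yield rng.uniform(size=(batch_size,)).astype(np.float32)


def device_batches(batches):
    """Copy each host batch to JAX's default device (a GPU if one is available)."""
    for batch in batches:
        yield jax.device_put(batch)


results = list(device_batches(host_batches()))
for batch in results:
    # Each batch is now a JAX array placed on the default device.
    print(type(batch).__name__, batch.shape, batch.dtype)
```

In a real pipeline, device_batches(ds.as_numpy_iterator()) would presumably play the same role, with the host-to-device copy happening once per batch.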