I have a Flax struct dataclass containing a JAX NumPy array.
When I pickle this object and load it again, the array is no longer a JAX array; it has been converted to a NumPy array. Here is code to reproduce it:
```python
import flax
import jax.numpy as jnp
import pickle

@flax.struct.dataclass
class A:
    data: jnp.ndarray

a = A(data=jnp.zeros((2, 2)))
print(a, type(a.data))

with open('file.pickle', 'wb') as handle:
    pickle.dump(a, handle)

with open('file.pickle', 'rb') as handle:
    loaded_a = pickle.load(handle)

print(loaded_a, type(loaded_a.data))
```
I don't want this behavior; I'd like the array to keep its original type. Is that possible?
Update: this bug has been fixed in https://github.com/google/jax/pull/10659. Starting with the next release of JAX (v0.3.14), `pickle` and `deepcopy` should no longer convert JAX arrays to NumPy arrays.
This is a known behavior in JAX; see https://github.com/google/jax/issues/2632
The library developers recognize this as unfortunate behavior, but a fix has not yet been prioritized. If you're interested, you might weigh in on that issue.