I am trying to run code written with JAX. In one part of the code, the key for the training set is defined as
key_train = random.PRNGKey(0)
The type of this key is jaxlib.xla_extension.DeviceArray. In a later part, the keys are defined as
keys = random.split(key_train, N)
where N is an integer equal to 10000. At that line, the code fails with this error:
DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.
Could you please help me with this error?
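For reference, on an installation where this works, `random.split` returns all subkeys stacked as rows of a single array. This is a minimal sketch, assuming a working JAX install with the default (legacy uint32) key representation. The `vmap` line and the `(3,)` sample shape are purely illustrative and are not taken from the linked notebook.

```python
from jax import random, vmap

N = 10000  # number of keys, as in the question

key_train = random.PRNGKey(0)      # a single key: uint32 array of shape (2,)
keys = random.split(key_train, N)  # N independent subkeys, one per row

# Illustrative only: draw one 3-vector per subkey by mapping over the rows.
samples = vmap(lambda k: random.uniform(k, (3,)))(keys)

print(keys.shape, keys.dtype)  # expected: (10000, 2) uint32
print(samples.shape)           # expected: (10000, 3)
```

If the `random.split` line raises the DenseElementsAttr error instead, the problem lies in the installation rather than in the calling code.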
Edit: I am running the code on Windows 10. The code I am trying to run is here (https://github.com/PredictiveIntelligenceLab/Physics-informed-DeepONets/blob/main/Antiderivative/DeepONet_antideriv.ipynb). For simplicity, you can also run the code below; you will get exactly the same error.
from jax import random
N = 10000
key_train = random.PRNGKey(0)
keys = random.split(key_train, N)
The JAX and jaxlib versions are 0.3.5, with CUDA 11.
I had the same error. Deleting c:\python37\lib\site-packages\jaxlib\cuda_prng.py fixed it (replace the prefix with your own Python path). It may be that cuda_prng.py was an old, leftover file.
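If you are not sure where jaxlib is installed, a small sketch like the following can locate the file instead of guessing the prefix. It uses only the standard library (`importlib.util.find_spec`); the helper name `find_package_file` is hypothetical. It only prints the path, so you can inspect the file before deleting it.

```python
import importlib.util
import os
from typing import Optional


def find_package_file(package: str, filename: str) -> Optional[str]:
    """Return the full path of `filename` inside an installed package, or None."""
    spec = importlib.util.find_spec(package)
    # None: package not installed; no search locations: a plain module, not a package.
    if spec is None or not spec.submodule_search_locations:
        return None
    for location in spec.submodule_search_locations:
        candidate = os.path.join(location, filename)
        if os.path.isfile(candidate):
            return candidate
    return None


if __name__ == "__main__":
    path = find_package_file("jaxlib", "cuda_prng.py")
    print(path if path else "no cuda_prng.py found in jaxlib")
```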