
DenseElementsAttr could not be constructed from the given buffer


I am trying to run some code written with JAX. In one part of the code, the key for the training set is defined as

key_train = random.PRNGKey(0)

Here the type of the key is jaxlib.xla_extension.DeviceArray. Later, the keys are created with keys = random.split(key_train, N), where N is an integer equal to 10000. At that point the code raises the following error:

DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.

Could you please help me with this error?

Edit: I am running the code on Windows 10. The code I am trying to run is here (https://github.com/PredictiveIntelligenceLab/Physics-informed-DeepONets/blob/main/Antiderivative/DeepONet_antideriv.ipynb). For simplicity, you can also run the snippet below; it produces exactly the same error.

from jax import random
N = 10000
key_train = random.PRNGKey(0)
keys = random.split(key_train, N)

The JAX and jaxlib versions are both 0.3.5, with CUDA 11.
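The repro above can be turned into a small diagnostic. This is a sketch and not part of the original question: split_keys is a hypothetical helper name, and the shape comments assume JAX's default (non-custom) PRNG, where random.PRNGKey returns a uint32 array of shape (2,) and random.split(key, n) returns one subkey per row. Printing the JAX version and active backend, and trying a small n before N = 10000, may help show whether the failure is tied to the installation or to the array size.

```python
import jax
from jax import random


def split_keys(seed: int, n: int):
    """Hypothetical helper: split a root PRNG key into n subkeys."""
    key = random.PRNGKey(seed)  # default PRNG: uint32 array of shape (2,)
    return random.split(key, n)  # one subkey per row: shape (n, 2)


if __name__ == "__main__":
    # Report which JAX build and backend (cpu/gpu) are actually in use.
    print("jax", jax.__version__, "on backend", jax.default_backend())
    # Try a small n first; if n=2 works but n=10000 fails, size is a factor.
    for n in (2, 10000):
        keys = split_keys(0, n)
        print(n, keys.shape, keys.dtype)
```

Because the same seed always yields the same subkeys, rerunning the loop is safe and the printed shapes can be compared directly across machines or jaxlib builds.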


Solution

  • I had the same error. Deleting c:\python37\lib\site-packages\jaxlib\cuda_prng.py fixed the issue (replace the prefix with your own Python path). cuda_prng.py may have been a stale file left over from an older jaxlib installation.
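Rather than hard-coding the c:\python37\... prefix, the jaxlib install directory can be located from Python itself. The following is a minimal sketch using only the standard library; find_package_file is a hypothetical helper name, the file name cuda_prng.py comes from the answer above, and the helper only reports where the file is, leaving the decision to delete it to you.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def find_package_file(package: str, filename: str) -> Optional[Path]:
    """Hypothetical helper: return the path of filename inside an installed
    package's directory, or None if the package or the file is missing."""
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        return None  # package not installed, or it is a plain module
    for location in spec.submodule_search_locations:
        candidate = Path(location) / filename
        if candidate.is_file():
            return candidate
    return None


if __name__ == "__main__":
    path = find_package_file("jaxlib", "cuda_prng.py")
    if path is None:
        print("No cuda_prng.py found in the installed jaxlib package.")
    else:
        # Only report it; whether to delete it is left to the user.
        print(f"Found {path}; it may be left over from an older jaxlib.")
```

find_spec is used instead of importing jaxlib so the check does not need to load the possibly broken package.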