
How to create a jax array and store dictionaries inside it later?


I want to create a JAX array and later store dictionaries in it. Is this possible?


Solution

  • JAX arrays cannot store dictionaries as items. They can only store items of simple numerical types, including:

    • integers: int8, int16, int32, int64
    • unsigned integers: uint8, uint16, uint32, uint64
    • floating point: bfloat16, float16, float32, float64
    • complex: complex64, complex128

    Additionally, several experimental narrow-width float and integer types from the ml_dtypes library are supported on some hardware.

    Depending on your use case, you may be able to use a struct-of-arrays pattern (a single dictionary whose values are arrays) rather than an array-of-structs pattern (an array whose items are dictionaries), but it's hard to say whether this is applicable without more information about what you're trying to do.
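As a rough illustration of the struct-of-arrays idea, here is a minimal sketch. The `records` list, its field names (`x`, `y`, `id`), and the `shift_x` function are all hypothetical, invented for this example; it also assumes every record has the same keys and that every value is numeric.

```python
import jax
import jax.numpy as jnp

# Hypothetical records that one might have wanted to store as dictionaries.
records = [
    {"x": 1.0, "y": 2.0, "id": 0},
    {"x": 3.0, "y": 4.0, "id": 1},
    {"x": 5.0, "y": 6.0, "id": 2},
]

# Struct-of-arrays: one dictionary whose values are JAX arrays,
# instead of one array whose items are dictionaries.
soa = {key: jnp.array([r[key] for r in records]) for key in records[0]}

# A dict of arrays is a pytree, so it can be passed to jit-compiled functions.
@jax.jit
def shift_x(data, offset):
    # Return a new dict with the "x" field shifted; other fields are unchanged.
    return {**data, "x": data["x"] + offset}

shifted = shift_x(soa, 10.0)

# Rebuild the "dictionary" for one record by indexing every field at position 1.
second = jax.tree_util.tree_map(lambda a: a[1], shifted)
```

Here `second` is again a dictionary, but its values are scalar JAX arrays rather than plain Python numbers. Whether this layout fits depends on the data: it does not help if the dictionaries have differing keys or hold non-numeric values such as strings.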