Search code examples
Tags: python, json, numpy, msgpack

Decoding JSON with msgspec into NumPy arrays


I'm trying to use msgspec to encode NumPy data into JSON and decode it back again. I've found plenty of good resources on encoding, and my encoder works without problems, but I can't get the data decoded back into its original format.

from dataclasses import dataclass
import numpy as np
import msgspec as ms
from traits.api import List, Array, Instance

def NumpyEncoder(obj):
    """Convert NumPy objects into values msgspec can encode natively.

    Passed as ``enc_hook`` to ``ms.json.Encoder``; msgspec calls it for
    every object it cannot encode on its own.

    :param obj: the object msgspec could not encode
    :return: a Python ``bool``/``int``/``float`` for NumPy scalars, or a
        (nested) ``list`` for an ``np.ndarray``
    :raises NotImplementedError: for any other type, so msgspec reports the
        object as unsupported instead of emitting a wrong value
    """
    # The abstract scalar bases cover every width (int8..uint64,
    # float16..longdouble) and do not depend on ``np.float_``, which was
    # removed in NumPy 2.0.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        # Nesting preserves the shape; the dtype is not recorded here.
        return obj.tolist()
    # Do not return ``ms.json.encode(obj)``: the hook must return an encodable
    # object, and a ``bytes`` result would be emitted as a base64 string.
    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")

enc = ms.json.Encoder(enc_hook=NumpyEncoder)

@dataclass
class C:
    # NOTE(review): `Array(...)` comes from traits.api; it is presumably a trait
    # declaration rather than an ndarray, so as a dataclass default it is likely
    # not what is intended -- confirm (a bare `c1: np.ndarray` may be enough).
    c1: np.ndarray = Array(dtype=np.float64)

@dataclass
class A:
    # NOTE(review): same concern for the traits `List(Instance(C))` default.
    a1: list[C] = List(Instance(C))

c = C(np.ones(10))
# `A`'s only field is `a1` (a list of `C`), so `c` is wrapped in a list.
a = A(a1=[c])

enc.encode(a)

This gives the correct serialized value of `a`. But how do I decode it back correctly?

I've tried the following:

def NumpyEncoder(obj):
    """Convert NumPy objects into values msgspec can encode natively.

    Passed as ``enc_hook`` to ``ms.json.Encoder``; msgspec calls it for
    every object it cannot encode on its own.

    :param obj: the object msgspec could not encode
    :return: a Python ``bool``/``int``/``float`` for NumPy scalars, or for an
        ``np.ndarray`` a dict ``{"__ndarray__": nested list, "dtype": str}``
        so the array's shape and dtype can be restored when decoding
    :raises NotImplementedError: for any other type, so msgspec reports the
        object as unsupported instead of emitting a wrong value
    """
    # The abstract scalar bases cover every width (int8..uint64,
    # float16..longdouble) and do not depend on ``np.float_``, which was
    # removed in NumPy 2.0.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return dict(__ndarray__=obj.tolist(), dtype=str(obj.dtype))
    # Do not return ``ms.json.encode(obj)``: the hook must return an encodable
    # object, and a ``bytes`` result would be emitted as a base64 string.
    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")

class NumpyDecoder:
    """Holder for a hook that turns encoded ndarray dicts back into arrays.

    NOTE(review): this class is passed itself as ``dec_hook`` below. msgspec
    presumably calls the hook as ``hook(type, obj)``, which would call the
    class constructor with two arguments; no ``__init__`` is defined here, so
    that call fails. The hook needs to be a plain function (or bound method)
    taking ``(type, obj)``. Confirm against the msgspec "Extending" docs.
    """
    def decoderHook(self, dct):
        """Decodes a previously encoded numpy ndarray with proper shape and 
           dtype.

        :param dct: (dict) json encoded ndarray, expected to look like
            ``{"__ndarray__": <nested list>, "dtype": <dtype string>}``
        :return: (ndarray) if input was an encoded ndarray
        """
        if isinstance(dct, dict) and '__ndarray__' in dct:
            # Nested lists restore the shape; the stored string restores dtype.
            return np.array(dct["__ndarray__"], dct['dtype'])
        # NOTE(review): `ms.json.decode` parses JSON bytes/str, but `dct` here is
        # already-decoded data, so this fallback presumably raises for anything
        # that is not a JSON string -- confirm.
        return ms.json.decode(dct)

# Encoder/decoder pair for the second attempt.
enc = ms.json.Encoder(enc_hook=NumpyEncoder)
# `type=A` tells the decoder to rebuild an `A` instance.
# NOTE(review): `dec_hook` receives the class `NumpyDecoder` itself, not its
# `decoderHook` method; msgspec presumably invokes the hook as `hook(type, obj)`,
# which would call the class constructor instead -- confirm.
dec = ms.json.Decoder(dec_hook=NumpyDecoder, type=A)

b = enc.encode(a)
print(dec.decode(b))

However, this does not decode `b` back into an object of type `A`.

Thanks!


Solution

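  • Since you want to keep using dataclasses (e.g. for inheritance), here is an answer that does so. The key points are:
    - msgspec calls the decode hook as `dec_hook(type, obj)`, with the expected type (here `np.ndarray`) and the already-parsed JSON value. The hook must therefore be a plain function with that signature, not a class.
    - The encoder stores the dtype alongside the data, so the array can be rebuilt exactly.

    Please let me know if this works for you.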

    import numpy as np
    import msgspec
    from dataclasses import dataclass, field
    from typing import List
    
    # Custom encoder function for numpy objects
    def NumpyEncoder(obj):
        """Convert NumPy objects into values msgspec can encode natively.

        Passed as ``enc_hook`` to ``msgspec.json.Encoder``; msgspec calls it
        for every object it cannot encode on its own.

        :param obj: the object msgspec could not encode
        :return: a Python ``bool``/``int``/``float`` for NumPy scalars, or for
            an ``np.ndarray`` a dict ``{"__ndarray__": nested list, "dtype": str}``
            so ``numpy_decoder_hook`` can restore both shape and dtype
        :raises NotImplementedError: for any other type, so msgspec reports
            the object as unsupported instead of emitting a wrong value
        """
        # The abstract scalar bases cover every width (int8..uint64,
        # float16..longdouble) and do not depend on ``np.float_``, which was
        # removed in NumPy 2.0.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return {"__ndarray__": obj.tolist(), "dtype": str(obj.dtype)}
        # Do not return ``msgspec.json.encode(obj)``: the hook must return an
        # encodable object, and a ``bytes`` result would become base64 text.
        raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
    
    # Custom decoder function for numpy objects
    def numpy_decoder_hook(type_, dct):
        """Rebuild NumPy arrays while ``msgspec`` decodes a message.

        Passed as ``dec_hook`` to ``msgspec.json.Decoder``; msgspec calls it
        with the annotated type it expects and the already-parsed JSON value
        for every field it cannot decode natively.

        :param type_: the expected type (``np.ndarray`` or a parametrized
            ``np.ndarray[...]`` alias)
        :param dct: the parsed JSON value -- either the tagged dict written by
            ``NumpyEncoder`` (``{"__ndarray__": [...], "dtype": "..."}``) or a
            plain nested list
        :return: an ``np.ndarray`` with the recorded dtype (inferred when absent)
        :raises ValueError: if an array was expected but the payload is neither
            a tagged dict nor a list
        :raises NotImplementedError: for types this hook does not handle, per
            msgspec's ``dec_hook`` convention
        """
        if isinstance(dct, dict) and "__ndarray__" in dct:
            # ``dtype`` is optional so payloads written without it still load.
            return np.array(dct["__ndarray__"], dtype=dct.get("dtype"))
        # Parametrized aliases such as ``np.ndarray[Any, ...]`` expose the
        # bare class via ``__origin__``.
        if getattr(type_, "__origin__", type_) is np.ndarray:
            if isinstance(dct, list):
                # Plain nested lists (e.g. from an encoder emitting ``tolist()``).
                return np.array(dct)
            raise ValueError(f"Cannot decode {type(dct).__name__} as an ndarray")
        raise NotImplementedError(f"Objects of type {type_} are not supported")
    
    # Define data classes
    @dataclass()
    class C:
        """Leaf record wrapping a single NumPy array in field ``c1``."""

        c1: np.ndarray
    
    @dataclass()
    class A:
        """Top-level record holding a list of ``C`` records in field ``a1``.

        ``default_factory=list`` gives every instance its own fresh list.
        """

        a1: List[C] = field(default_factory=list)
    
    # Creating encoder and decoder with hooks
    # Decoder first: it must know the target type ``A`` so msgspec can route
    # the ``np.ndarray`` field through the custom hook.
    dec = msgspec.json.Decoder(type=A, dec_hook=numpy_decoder_hook)
    enc = msgspec.json.Encoder(enc_hook=NumpyEncoder)

    # Sample payload: one ``C`` holding ten ones, wrapped in an ``A``.
    c = C(c1=np.ones(10))
    a = A(a1=[c])

    # Round trip: serialize to JSON bytes ...
    encoded = enc.encode(a)
    print(f"Encoded: {encoded}")

    # ... then parse back into an ``A`` whose field is a real ndarray again.
    decoded = dec.decode(encoded)
    print(f"Decoded: {decoded}")
    print(f"Decoded a1[0].c1: {decoded.a1[0].c1}")