cereal + armadillo + json serialization


Does anyone have an example of cereal-based Armadillo matrix serialization to JSON? The binary serialization below seems to work.

Inside mat_extra_meat.hpp:

template<class Archive, class eT>
typename std::enable_if<cereal::traits::is_output_serializable<cereal::BinaryData<eT>, Archive>::value, void>::type
save( Archive & ar, const Mat<eT>& m ) {
    uword n_rows = m.n_rows;
    uword n_cols = m.n_cols;
    ar( n_rows );
    ar( n_cols );
    ar( cereal::binary_data(
        reinterpret_cast< void * const >( const_cast< eT* >( m.memptr() ) ),
        static_cast< std::size_t >( n_rows * n_cols * sizeof( eT ) ) ) );
}

template<class Archive, class eT>
typename std::enable_if<cereal::traits::is_input_serializable<cereal::BinaryData<eT>, Archive>::value, void>::type
load( Archive & ar, Mat<eT>& m ) {
    uword n_rows;
    uword n_cols;
    ar( n_rows );
    ar( n_cols );

    m.resize( n_rows, n_cols );

    ar( cereal::binary_data(
        reinterpret_cast< void * const >( const_cast< eT* >( m.memptr() ) ),
        static_cast< std::size_t >( n_rows * n_cols * sizeof( eT ) ) ) );
}

Test with this:

int main( int argc, char** argv ) {

    arma::mat xx1 = arma::randn( 10, 20 );
    std::ofstream ofs( "test", std::ios::binary );
    cereal::BinaryOutputArchive o( ofs );
    o( xx1 );
    ofs.close();
    // Now load it.
    arma::mat xx2;
    std::ifstream ifs( "test", std::ios::binary );
    cereal::BinaryInputArchive i( ifs );
    i( xx2 );

}

Solution

  • You have two options for JSON serialization: a quick and dirty approach that is not really human readable, or a human-readable one that costs more in serialized size and time.


    For the quick version, you can modify your existing code to use saveBinaryValue and loadBinaryValue, which exist within the text archives of cereal (JSON and XML).

    e.g.:

    ar.saveBinaryValue( reinterpret_cast<void * const>( const_cast< eT* >( m.memptr() ) ), 
                        static_cast<std::size_t>( n_rows * n_cols * sizeof( eT ) ) );
    

    and similarly for the load.

    This will base64 encode your data and write it as a string. You would of course need to specialize the function to only apply to text archives (or just JSON) within cereal.
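
To make the trade-off of the quick route concrete, here is a small standard-library-only sketch of what base64-encoding a raw buffer of doubles does to the payload. `base64_encode` and `encode_doubles` are hypothetical helpers written for this illustration; they are not cereal's internal encoder and are not part of cereal's or Armadillo's API. The sketch assumes 8-byte doubles.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper for illustration only; cereal ships its own base64 code.
// Encodes `size` raw bytes as standard base64 with '=' padding.
inline std::string base64_encode(const void* data, std::size_t size)
{
  assert(data != nullptr || size == 0);
  static const char table[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
  const unsigned char* bytes = static_cast<const unsigned char*>(data);

  std::string out;
  out.reserve(((size + 2) / 3) * 4);

  for (std::size_t i = 0; i < size; i += 3)
  {
    // Pack up to three bytes into a 24-bit group, zero-filling missing bytes.
    const std::size_t remaining = size - i;
    unsigned long group = static_cast<unsigned long>(bytes[i]) << 16;
    if (remaining > 1) group |= static_cast<unsigned long>(bytes[i + 1]) << 8;
    if (remaining > 2) group |= static_cast<unsigned long>(bytes[i + 2]);

    // Emit four 6-bit symbols, padding with '=' where input bytes were missing.
    out += table[(group >> 18) & 0x3F];
    out += table[(group >> 12) & 0x3F];
    out += (remaining > 1) ? table[(group >> 6) & 0x3F] : '=';
    out += (remaining > 2) ? table[group & 0x3F] : '=';
  }
  return out;
}

// Encodes a flat buffer of doubles, laid out the way a matrix's memptr() exposes it.
inline std::string encode_doubles(const std::vector<double>& values)
{
  return base64_encode(values.data(), values.size() * sizeof(double));
}
```

Every 3 input bytes become 4 output characters, so the string is about a third larger than the raw bytes, and none of the individual matrix values can be read or edited by hand in the resulting JSON.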


    The alternative is to serialize each element individually. Again there are two choices: serialize as a JSON array (e.g. myarray: [1, 2, 3, 4, 5, ...]) or as a series of individual name-value pairs: "array1" : "1", "array2": "2", ...

    The convention in cereal is to use JSON arrays for dynamically resizable containers (e.g. vector). Because this example is purely about readability, I'll use arrays anyway, even though an Armadillo matrix is not something you would want users to add elements to or remove elements from by editing the JSON:

    namespace arma
    {
      // Wraps a particular column in a class with its own serialization function.
      // This is necessary because cereal expects actual data to follow a size_tag, and can't
      // serialize two size_tags back to back without creating a new node (entering a new serialization function).
      //
      // This wrapper serves the purpose of creating a new node in the JSON serializer and allows us to
      // then serialize the size_tag, followed by the actual data
      template <class T>
      struct ColWrapper
      {
        ColWrapper(T && m, int c, int nc) : mat(std::forward<T>(m)), col(c), n_cols(nc) {}
        T & mat;
        int col;
        int n_cols;
    
        template <class Archive>
        void save( Archive & ar ) const
        {
          ar( cereal::make_size_tag( mat.n_rows ) );
          for( auto iter = mat.begin_col(col), end = mat.end_col(col); iter != end; ++iter )
            ar( *iter );
        }
    
        template <class Archive>
        void load( Archive & ar )
        {
          cereal::size_type n_rows;
    
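          // Test to see if we need to resize the data. Both dimensions are checked so that
          // a matrix that already has the right row count but the wrong column count is resized too
          ar( cereal::make_size_tag( n_rows ) );
          if( mat.n_rows != n_rows || mat.n_cols != static_cast<uword>( n_cols ) )
            mat.resize( n_rows, n_cols );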
    
          for( auto iter = mat.begin_col(col), end = mat.end_col(col); iter != end; ++iter )
            ar( *iter );
        }
      };
    
      // Convenience function to make a ColWrapper
      template<class T> inline
      ColWrapper<T> make_col_wrapper(T && t, int c, int nc)
      {
        return {std::forward<T>(t), c, nc};
      }
    
      template<class Archive, class eT, cereal::traits::EnableIf<cereal::traits::is_text_archive<Archive>::value> = cereal::traits::sfinae>
      inline void save( Archive & ar, const Mat<eT>& m )
      {
        // armadillo stored in column major order
        uword n_rows = m.n_rows;
        uword n_cols = m.n_cols;
    
        // First serialize a size_tag for the number of columns. This makes cereal expect a dynamically
        // sized container, which it will output as a JSON array. In reality our container is not dynamic,
        // but we're going for readability here.
        ar( cereal::make_size_tag( n_cols ) );
        for( auto i = 0; i < n_cols; ++i )
          // a size_tag must be followed up with actual serializations that create nodes within the JSON serializer
          // so we cannot immediately make a size_tag for the number of rows. See ColWrapper for more details
          ar( make_col_wrapper(m, i, n_cols) );
      }
    
      template<class Archive, class eT, cereal::traits::EnableIf<cereal::traits::is_text_archive<Archive>::value> = cereal::traits::sfinae>
      inline void load( Archive & ar, Mat<eT>& m )
      {
        // We're doing essentially the same thing here, but loading the sizes and performing the resize for the matrix
        // within ColWrapper
        // Only the column count is read here; each ColWrapper reads its own row count
        cereal::size_type n_cols;
    
        ar( cereal::make_size_tag( n_cols ) );
        for( auto i = 0; i < n_cols; ++i )
          ar( make_col_wrapper(m, i, n_cols) );
      }
    } // end namespace arma
    

    Example program to run the above:

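    #include <armadillo>
    #include <cereal/archives/json.hpp>
    #include <iostream>
    #include <sstream>
    // ... followed by the arma::save / arma::load overloads shown above

    int main(int argc, char* argv[])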
    {
      std::stringstream ss;
      std::stringstream ss2;
    
      {
        arma::mat A = arma::randu<arma::mat>(4, 5);
        cereal::JSONOutputArchive ar(ss);
        ar( A );
      }
    
      std::cout << ss.str() << std::endl;
    
      {
        arma::mat A;
        cereal::JSONInputArchive ar(ss);
        ar( A );
        {
          cereal::JSONOutputArchive ar2(ss2);
          ar2( A );
        }
      }
    
      std::cout << ss2.str() << std::endl;
    
      return 0;
    }
    

    and its output (printed twice: first the matrix as saved, then the result of loading it and saving it again):

    {
        "value0": [
            [
                0.786820954867802,
                0.2504803406880287,
                0.7106712289786555,
                0.9466678009609704
            ],
            [
                0.019271058195813773,
                0.40490214481616768,
                0.25131781792803756,
                0.02271243862792676
            ],
            [
                0.5206431525734917,
                0.34467030607918777,
                0.27419560360286257,
                0.561032100176393
            ],
            [
                0.14003945653337478,
                0.5438560675050177,
                0.5219157100717673,
                0.8570772835528213
            ],
            [
                0.49977436000503835,
                0.4193700240544483,
                0.7442805199715539,
                0.24916812957858262
            ]
        ]
    }
    {
        "value0": [
            [
                0.786820954867802,
                0.2504803406880287,
                0.7106712289786555,
                0.9466678009609704
            ],
            [
                0.019271058195813773,
                0.40490214481616768,
                0.25131781792803756,
                0.02271243862792676
            ],
            [
                0.5206431525734917,
                0.34467030607918777,
                0.27419560360286257,
                0.561032100176393
            ],
            [
                0.14003945653337478,
                0.5438560675050177,
                0.5219157100717673,
                0.8570772835528213
            ],
            [
                0.49977436000503835,
                0.4193700240544483,
                0.7442805199715539,
                0.24916812957858262
            ]
        ]
    }
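
One thing worth noting when reading this output: the matrix was created as 4 x 5, and the JSON holds 5 inner arrays of 4 values, so each inner array is a column, not a row. The standard-library-only sketch below shows that mapping for a column-major buffer. `element_at` and `columns_as_json` are hypothetical names used only for illustration; the sketch mirrors the nesting of the output above, not cereal's exact formatting.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Element (row, col) of a column-major buffer with n_rows rows.
inline double element_at(const std::vector<double>& data, std::size_t n_rows,
                         std::size_t row, std::size_t col)
{
  return data[col * n_rows + row];
}

// Renders the buffer as nested arrays, one inner array per column.
inline std::string columns_as_json(const std::vector<double>& data,
                                   std::size_t n_rows, std::size_t n_cols)
{
  assert(data.size() == n_rows * n_cols);
  std::ostringstream out;
  out << '[';
  for (std::size_t c = 0; c < n_cols; ++c)
  {
    out << (c ? "," : "") << '[';
    for (std::size_t r = 0; r < n_rows; ++r)
      out << (r ? "," : "") << element_at(data, n_rows, r, c);
    out << ']';
  }
  out << ']';
  return out.str();
}
```

For a 2 x 3 matrix with rows (1, 3, 5) and (2, 4, 6), the column-major buffer is {1, 2, 3, 4, 5, 6} and `columns_as_json(buffer, 2, 3)` returns `[[1,2],[3,4],[5,6]]`. A consumer that treats the inner arrays as rows will therefore see the transpose of the original matrix.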