
Unable to pickle keras.layers.StringLookup


Minimal code to reproduce:

import tensorflow
import pickle
print(tensorflow.__version__)
pickle.dumps(tensorflow.keras.layers.StringLookup())

Output:

2.8.0
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-2-caa8edb2bfc5> in <module>()
      2 import pickle
      3 print(tensorflow.__version__)
----> 4 pickle.dumps(tensorflow.keras.layers.StringLookup())

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in _numpy(self)
   1189       return self._numpy_internal()
   1190     except core._NotOkStatusException as e:  # pylint: disable=protected-access
-> 1191       raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   1192 
   1193   @property

InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array.

This is with TensorFlow 2.8, but I have run into the same issue with other versions as well.
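The traceback hints at the cause: while walking the layer's attributes, pickle reaches a tensor of dtype `resource` (presumably the handle of the layer's internal lookup table), and such a tensor cannot be converted to a NumPy array. The same kind of failure, and the usual fix, can be reproduced without TensorFlow. In this stdlib-only sketch, `threading.Lock` stands in for the resource handle and the class names are hypothetical: the object becomes picklable once `__getstate__`/`__setstate__` save plain data and rebuild the handle on load.

```python
import pickle
import threading


class Cache:
    """Toy object holding a runtime handle (stand-in for a layer's resource tensor)."""

    def __init__(self, entries=None):
        self.entries = dict(entries or {})
        self.lock = threading.Lock()  # pickle cannot serialize lock objects


class PicklableCache(Cache):
    """Same data, but pickles only what is needed to rebuild the object."""

    def __getstate__(self):
        # Save plain data only; leave the handle out.
        return {"entries": self.entries}

    def __setstate__(self, state):
        # Rebuild the handle on load, analogous to recreating a layer
        # from its config and weights.
        self.entries = state["entries"]
        self.lock = threading.Lock()


def can_pickle(obj):
    """Return True if pickle.dumps(obj) succeeds, False on a TypeError."""
    try:
        pickle.dumps(obj)
        return True
    except TypeError:
        return False


print(can_pickle(Cache({"a": 1})))           # False
print(can_pickle(PicklableCache({"a": 1})))  # True
```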


Solution

  • A workaround mentioned in a few places is to save the object's config and weights separately.

    Using this wrapper class makes it possible to do this without much boilerplate. The wrapper itself is picklable, so any object containing it can be serialized seamlessly.

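    class KerasPickleWrapper:
        """Picklable wrapper: stores a Keras object's class, config and weights."""

        def __init__(self, obj=None):
            self.obj = obj

        def __getstate__(self):
            # Pickle the recipe for rebuilding the object instead of the object
            # itself, whose internal resource tensors cannot be pickled.
            if self.obj is None:
                # Return a non-None value: if __getstate__ returns None, pickle
                # never calls __setstate__ and self.obj would be missing on load.
                return (None,)
            return self.obj.__class__, self.obj.get_config(), self.obj.get_weights()

        def __setstate__(self, state):
            if state[0] is None:
                self.obj = None
                return
            cls, config, weights = state
            self.obj = cls.from_config(config)
            # set_weights() needs the weights to exist already; a layer whose
            # weights are created in build() may have to be built first.
            self.obj.set_weights(weights)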
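A subtlety in the `obj is None` case: pickle only calls `__setstate__` when `__getstate__` returns something other than `None`, and unpickling does not run `__init__` either. An empty wrapper whose `__getstate__` returns `None` therefore comes back without an `obj` attribute at all. The stdlib-only toy classes below (hypothetical names) show the difference.

```python
import pickle


class FalsyState:
    """Returns None from __getstate__ when empty, like a wrapper holding no object."""

    def __init__(self, obj=None):
        self.obj = obj

    def __getstate__(self):
        return None if self.obj is None else (self.obj,)

    def __setstate__(self, state):
        (self.obj,) = state


class TruthyState(FalsyState):
    """Always returns a non-empty tuple, so __setstate__ always runs."""

    def __getstate__(self):
        return (self.obj,)


falsy = pickle.loads(pickle.dumps(FalsyState()))
truthy = pickle.loads(pickle.dumps(TruthyState()))
print(hasattr(falsy, "obj"))  # False: __setstate__ was skipped
print(truthy.obj)             # None
```

Returning a small non-None marker such as `(None,)` for the empty case keeps `__setstate__` involved, so the restored object always has its attribute.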