Tags: python, serialization, pickle, dill, huggingface-datasets

Can pickle/dill `foo` but not `lambda x: foo(x)`


I am applying some preprocessing to the CIFAR-100 dataset:

from datasets import load_dataset, Features, Array3D
from transformers import ViTFeatureExtractor


# Resampling & Normalization
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
dataset = load_dataset('cifar100', split='train[:100]')

features = Features({
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
    **dataset.features,
})


# Run the feature extractor on the 'img' column, one batch at a time
dataset = dataset.map(lambda batch, col_name: feature_extractor(batch[col_name]),
                      features=features, fn_kwargs={'col_name': 'img'}, batched=True)

I got the following warning, which means that `datasets` cannot cache the transformed dataset:

Reusing dataset cifar100 (/home/qys/.cache/huggingface/datasets/cifar100/cifar100/1.0.0/f365c8b725c23e8f0f8d725c3641234d9331cd2f62919d1381d1baa5b3ba3142)
Parameter 'function'=<function <lambda> at 0x7f3279f3eef0> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.

Curiously, I can pickle/dill `foo` but not `lambda x: foo(x)`, even though they have exactly the same effect. I suspect this is related to the problem:

>>> from datasets.fingerprint import Hasher
>>> def foo(x): return x + 1
...
>>> Hasher.hash(foo)
'ff7fae499aa1d820'
>>> Hasher.hash(lambda x: foo(x))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/datasets/fingerprint.py", line 237, in hash
    return cls.hash_default(value)
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/datasets/fingerprint.py", line 230, in hash_default
    return cls.hash_bytes(dumps(value))
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 564, in dumps
    dump(obj, file)
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 539, in dump
    Pickler(file, recurse=True).dump(obj)
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/dill/_dill.py", line 620, in dump
    StockPickler.dump(self, obj)
  File "/home/qys/.pyenv/versions/3.10.4/lib/python3.10/pickle.py", line 487, in dump
    self.save(obj)
  File "/home/qys/.pyenv/versions/3.10.4/lib/python3.10/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 862, in save_function
    dill._dill._save_with_postproc(
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/dill/_dill.py", line 1153, in _save_with_postproc
    pickler.write(pickler.get(pickler.memo[id(dest)][0]))
KeyError: 139847629663936
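To see why `foo` and `lambda x: foo(x)` are not interchangeable from a serializer's point of view, here is a minimal sketch using only the standard-library `pickle` module (not dill, and not the `datasets` Hasher). It assumes the code runs as a script, so `foo` lives in `__main__`; `can_pickle` is a hypothetical helper introduced for illustration.

```python
import pickle


def foo(x):
    return x + 1


def can_pickle(obj) -> bool:
    """Hypothetical helper: True if stdlib pickle can serialize obj."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


# Stdlib pickle stores a module-level def by reference: only its module
# and qualified name ("__main__", "foo") are written, not its body.
print("foo:", can_pickle(foo))

# A lambda's qualified name is "<lambda>", which cannot be looked up in its
# module, so stdlib pickle refuses it. dill instead serializes it by value,
# and with recurse=True it also serializes the globals it references (foo).
print("lambda:", can_pickle(lambda x: foo(x)))
```

One visible difference: `foo` references no globals, while both failing callables here reference one (`foo`, and later `feature_extractor`), so dill's recursive mode must serialize those globals as well. A plausible reading, consistent with the traceback ending in `_save_with_postproc`, is that the failure lies in that globals-handling step rather than in lambdas as such.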

I have also tried making the function accessible from the top level of a module:

preprocessor = lambda batch: feature_extractor(batch['img'])

dataset = dataset.map(preprocessor, features=features, batched=True)

However, it still doesn't work:

>>> from datasets.fingerprint import Hasher
>>> preprocessor = lambda batch: feature_extractor(batch['img'])
>>> Hasher.hash(preprocessor)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/datasets/fingerprint.py", line 237, in hash
    return cls.hash_default(value)
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/datasets/fingerprint.py", line 230, in hash_default
    return cls.hash_bytes(dumps(value))
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 564, in dumps
    dump(obj, file)
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 539, in dump
    Pickler(file, recurse=True).dump(obj)
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/dill/_dill.py", line 620, in dump
    StockPickler.dump(self, obj)
  File "/home/qys/.pyenv/versions/3.10.4/lib/python3.10/pickle.py", line 487, in dump
    self.save(obj)
  File "/home/qys/.pyenv/versions/3.10.4/lib/python3.10/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 862, in save_function
    dill._dill._save_with_postproc(
  File "/home/qys/Research/embedder/.venv/lib/python3.10/site-packages/dill/_dill.py", line 1153, in _save_with_postproc
    pickler.write(pickler.get(pickler.memo[id(dest)][0]))
KeyError: 140408024252096
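A side note on the attempt above: binding a lambda to a module-level name does not give it that name. Its `__name__` and `__qualname__` stay `"<lambda>"`, which is what stdlib pickle relies on for by-reference lookup, whereas a `def` gets its own name. In this minimal sketch `feature_extractor` is a hypothetical stand-in, not the real `ViTFeatureExtractor`, and `preprocess` is a hypothetical name. It only concerns stdlib pickle's name lookup: the `datasets` hashing path goes through dill with `recurse=True` (see the traceback), and the `def` version still references the global `feature_extractor`, so this is not offered as a fix for the `KeyError`.

```python
# Hypothetical stand-in for the real ViTFeatureExtractor (assumption).
def feature_extractor(images):
    return {"pixel_values": images}


# Assigning a lambda to a module-level name does not rename it.
preprocessor = lambda batch: feature_extractor(batch["img"])


# A def statement gives the function its own name.
def preprocess(batch):
    return feature_extractor(batch["img"])


print(preprocessor.__name__, preprocessor.__qualname__)  # <lambda> <lambda>
print(preprocess.__name__, preprocess.__qualname__)      # preprocess preprocess

# Both compute the same result; only their names (and therefore how
# stdlib pickle can locate them) differ.
batch = {"img": [1, 2, 3]}
print(preprocessor(batch) == preprocess(batch))  # True
```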

Solution

  • In Python 3.9, pickle hashes the glob_ids dictionary in addition to the globs of a function. To make hashing deterministic when the globals are not in the same order, the order of glob_ids needs to be made deterministic. PR to fix: https://github.com/huggingface/datasets/pull/4516

    (Until the PR is merged, a temporary workaround is to install an older version of dill:

    pip install "dill<0.3.5"

    See https://github.com/huggingface/datasets/issues/4506#issuecomment-1157417219 for details.)
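The idea behind the fix, making serialization independent of dictionary insertion order, can be illustrated with a small standard-library sketch. This is not the code from the PR: `hashlib.sha256` over `pickle.dumps` stands in for the fingerprint hash, and `globs_a`/`globs_b` are hypothetical globals dictionaries.

```python
import hashlib
import pickle


def digest(obj) -> str:
    """Stand-in fingerprint: hash of the pickled bytes of obj."""
    return hashlib.sha256(pickle.dumps(obj)).hexdigest()[:16]


def sorted_dict(d: dict) -> dict:
    """Rebuild d with keys in sorted order so its pickle is order-independent."""
    return dict(sorted(d.items()))


# The same (hypothetical) globals, inserted in a different order.
globs_a = {"feature_extractor": 1, "foo": 2}
globs_b = {"foo": 2, "feature_extractor": 1}

print(globs_a == globs_b)                  # True: dict equality ignores order
print(digest(globs_a) == digest(globs_b))  # False: pickle follows insertion order
print(digest(sorted_dict(globs_a)) == digest(sorted_dict(globs_b)))  # True
```

Equal dictionaries can therefore produce different hashes unless their order is normalized before serialization, which is why a deterministic ordering matters for caching.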