
Unexpected dataset structure after set_transform or with_transform

I am using the ViT feature extractor as explained here, and I noticed a weird behaviour that I cannot fully understand.

After loading the dataset as in that colab notebook, I see:


{'image_file_path': Value(dtype='string', id=None),  'image':
Image(mode=None, decode=True, id=None),  'labels':
ClassLabel(names=['angular_leaf_spot', 'bean_rust', 'healthy'],

And we can access the features in both ways:


[0, 0, 0, 0, 0]


'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB
size=500x500>,   <PIL.JpegImagePlugin.JpegImageFile image mode=RGB
size=500x500>],  'labels': [0, 0]}

But after

from datasets import load_dataset
from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
ds = load_dataset('beans')

def transform(example_batch):
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
    inputs['labels'] = example_batch['labels']
    return inputs

prepared_ds = ds.with_transform(transform)

We see the features are kept:


{'image_file_path': Value(dtype='string', id=None),  'image':
Image(mode=None, decode=True, id=None),  'labels':
ClassLabel(names=['angular_leaf_spot', 'bean_rust', 'healthy'],


{'pixel_values': tensor([[[[-0.5686, -0.5686, -0.5608,  ..., -0.0275,  0.1843, -0.2471],
          [-0.5843, -0.5922, -0.6078,  ...,  0.2627,  0.1608,  0.2000]],

         [[-0.7098, -0.7098, -0.7490,  ..., -0.3725, -0.1608, -0.6000],
          [-0.8824, -0.9059, -0.9216,  ..., -0.2549, -0.2000, -0.1216]]],

        [[[-0.5137, -0.4902, -0.4196,  ..., -0.0275, -0.0039, -0.2157],
          [-0.5216, -0.5373, -0.5451,  ..., -0.1294, -0.1529, -0.2627]],

         [[-0.1843, -0.2000, -0.1529,  ...,  0.2157,  0.2078, -0.0902],
          [-0.7725, -0.7961, -0.8039,  ..., -0.3725, -0.4196, -0.5451]],

         [[-0.7569, -0.8510, -0.8353,  ..., -0.3255, -0.2706, -0.5608],
          [-0.5294, -0.5529, -0.5608,  ..., -0.1686, -0.1922, -0.3333]]]]), 'labels': [0, 0]}

But when I try to access the labels directly with prepared_ds['train']['labels'], I get a KeyError:

KeyError                                  Traceback (most recent call last)
Cell In[32], line 1
----> 1 prepared_ds['train']['labels']

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/, in Dataset.__getitem__(self, key)
   2870 def __getitem__(self, key):  # noqa: F811
   2871     """Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools)."""
-> 2872     return self._getitem(key)

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/, in Dataset._getitem(self, key, **kwargs)
   2855 formatter = get_formatter(format_type, features=self._info.features, **format_kwargs)
   2856 pa_subtable = query_table(self._data, key, indices=self._indices)
-> 2857 formatted_output = format_table(
   2858     pa_subtable, key, formatter=formatter, format_columns=format_columns, output_all_columns=output_all_columns
   2859 )
   2860 return formatted_output

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/formatting/, in format_table(table, key, formatter, format_columns, output_all_columns)
    637 python_formatter = PythonFormatter(features=formatter.features)
    638 if format_columns is None:
--> 639     return formatter(pa_table, query_type=query_type)
    640 elif query_type == "column":
    641     if key in format_columns:

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/formatting/, in Formatter.__call__(self, pa_table, query_type)
    403     return self.format_row(pa_table)
    404 elif query_type == "column":
--> 405     return self.format_column(pa_table)
    406 elif query_type == "batch":
    407     return self.format_batch(pa_table)

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/formatting/, in CustomFormatter.format_column(self, pa_table)
    500 def format_column(self, pa_table: pa.Table) -> ColumnFormat:
--> 501     formatted_batch = self.format_batch(pa_table)
    502     if hasattr(formatted_batch, "keys"):
    503         if len(formatted_batch.keys()) > 1:

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/formatting/, in CustomFormatter.format_batch(self, pa_table)
    520 batch = self.python_arrow_extractor().extract_batch(pa_table)
    521 batch = self.python_features_decoder.decode_batch(batch)
--> 522 return self.transform(batch)

Cell In[12], line 5, in transform(example_batch)
      3 def transform(example_batch):
      4     # Take a list of PIL images and turn them to pixel values
----> 5     inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
      7     # Don't forget to include the labels!
      8     inputs['labels'] = example_batch['labels']

KeyError: 'image'

It sounds like the error occurs because the feature extractor added 'pixel_values' while the feature is still listed as 'image'. But it also appears to imply an attempt to re-apply the transform...

Also, it is not possible to save the dataset to disk:

TypeError                                 Traceback (most recent call last)
Cell In[21], line 1
----> 1 dataset.save_to_disk(img_path)

File ~/anaconda3/envs/LLM/lib/python3.13/site-packages/datasets/, in Dataset.save_to_disk(self, dataset_path, max_shard_size, num_shards, num_proc, storage_options)
   1501         json.dumps(state["_format_kwargs"][k])
   1502     except TypeError as e:
-> 1503         raise TypeError(
   1504             str(e) + f"\nThe format kwargs must be JSON serializable, but key '{k}' isn't."
   1505         ) from None
   1506 # Get json serializable dataset info
   1507 dataset_info = asdict(self._info)

TypeError: Object of type function is not JSON serializable
The format kwargs must be JSON serializable, but key 'transform' isn't.
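
The TypeError can be reproduced with the standard library alone. The sketch below assumes, based on the traceback, that save_to_disk passes each format kwarg through json.dumps. The format_kwargs dict and the helper first_unserializable_key are hypothetical stand-ins for illustration, not objects from the datasets library.

```python
import json


def transform(example_batch):
    # Placeholder for the real transform; its body is irrelevant here.
    return example_batch


# Hypothetical stand-in for the format kwargs stored by with_transform.
format_kwargs = {"transform": transform}


def first_unserializable_key(kwargs):
    """Return the first key whose value json.dumps rejects, or None."""
    for key, value in kwargs.items():
        try:
            json.dumps(value)
        except TypeError:
            return key
    return None


bad_key = first_unserializable_key(format_kwargs)
print(bad_key)
```

A Python function has no JSON representation, so any dataset state that holds one cannot be written out this way.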

Note that the original code in that notebook works perfectly (training, evaluation, etc.). I only got these errors because I tried to inspect the dataset and save the generated dataset in order to explore the dataset object...

Shouldn't the dataset structure be accessible in a similar way after with_transform() or set_transform()? Why does it call the transform function again if we just attempt to access one of the features?

I’m hoping you can shed some light on this behaviour...


  • This is not how you access the dataset items. First you need to select a slice by indexing:

    prepared_ds_batch = prepared_ds['train'][0:10]

    Then you can use the key 'labels' on that batch, i.e. prepared_ds_batch['labels']:

    [out]: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

    Regarding the second issue, saving the data: you cannot save it because the transform function is stored in the dataset's format kwargs, and functions are not JSON serializable. This is a known issue with transform functions.

    You can, however, save the dataset with prepared_ds.with_format(None).save_to_disk('test_path'). After loading it from disk again, you need to re-apply the transform function.

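    Edited: You cannot use prepared_ds['train']['labels'] because, as the traceback shows, the transform is applied on every access, and selecting a column by name passes only that column to it. example_batch then has no 'image' key, hence the KeyError.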
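
To illustrate why column access fails while slicing works, here is a toy model in plain Python. It is not the datasets implementation: ToyDataset and its behaviour are assumptions reconstructed from the traceback in the question, where __getitem__ queries only the requested key and then hands the result to the transform.

```python
class ToyDataset:
    """Toy stand-in for a dataset with an on-the-fly transform (hypothetical)."""

    def __init__(self, columns, transform=None):
        self.columns = columns        # dict: column name -> list of values
        self.transform = transform    # applied on every access, never cached

    def __getitem__(self, key):
        if isinstance(key, str):
            # Column access: only the requested column reaches the transform.
            batch = {key: self.columns[key]}
        elif isinstance(key, slice):
            # Row access: all columns for the selected rows.
            batch = {name: values[key] for name, values in self.columns.items()}
        else:
            raise TypeError("this toy supports only str and slice keys")
        return self.transform(batch) if self.transform else batch


def transform(example_batch):
    # Mirrors the question's transform: needs both 'image' and 'labels'.
    return {
        "pixel_values": [f"tensor({img})" for img in example_batch["image"]],
        "labels": example_batch["labels"],
    }


ds = ToyDataset({"image": ["img0", "img1", "img2"], "labels": [0, 0, 1]}, transform)

rows = ds[0:2]          # full rows: the transform sees 'image' and 'labels'
print(rows["labels"])

try:
    ds["labels"]        # only 'labels' is fetched: the transform cannot find 'image'
except KeyError as err:
    missing = err.args[0]
    print("KeyError:", missing)
```

The same ordering applies to the real dataset: slice first and then pick the key, as in prepared_ds['train'][0:10]['labels'], or read the column from the untransformed ds['train'].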