python tensorflow machine-learning tensorflow-datasets

How can I remove a tensor from a FilterDataset/MapDataset?


I have a TFRecord dataset of serialized examples with video_id, user_id and score features. I want to keep only the positive examples (those with a score above a threshold) and then remove the score tensor.

import tensorflow as tf

def decode_retrieval_positive(record_bytes):
    return tf.io.parse_single_example(
        # Data
        record_bytes,
        # Schema
        {"video_id": tf.io.FixedLenFeature([], dtype=tf.int64),
         "user_id": tf.io.FixedLenFeature([], dtype=tf.int64),
         "score": tf.io.FixedLenFeature([], dtype=tf.float32)}
    )

ratings_positive = (
    ratings.map(decode_retrieval_positive)
    .filter(lambda x: x["score"] > 0.2)
    .map(lambda x: {"video_id": x["video_id"], "user_id": x["user_id"]})
)

The resulting dataset has the expected element spec, with score removed:

<MapDataset element_spec={'video_id': TensorSpec(shape=(), dtype=tf.int64, name=None), 'user_id': TensorSpec(shape=(), dtype=tf.int64, name=None)}>

Iterating over it, however, gives this error:

2022-02-07 08:18:52.825318: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at example_parsing_ops.cc:94 : INVALID_ARGUMENT: Feature: score (data type: float) is required but could not be found.

One workaround would be to write a new positive_ratings.tfrecord containing only the filtered records, but that would take up more disk space, and it's frustrating that I can't do this in the input pipeline.


Solution

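  • The filter and the final map are fine; the error is raised while parsing (note example_parsing_ops.cc in the log). Because "score" is declared as a tf.io.FixedLenFeature with dtype tf.float32 and no default_value, it is a required feature, so every record must contain a float-valued score. You just need to make sure x['score'] has float values. Here is a working example that uses dummy data in place of the TFRecord file: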

    import tensorflow as tf
    tf.random.set_seed(111)
    
    # Create dummy data: one dict per element, with a float32 "score"
    dataset = tf.data.Dataset.range(10)
    dataset = dataset.map(lambda x: {
        'video_id': x,
        'user_id': 2555 + x,
        'score': tf.cast(x, dtype=tf.float32) * tf.random.normal(()),
    })
    
    # Keep only positive examples, then drop the "score" key
    dataset = dataset.filter(lambda x: x["score"] > 0.2)
    dataset = dataset.map(lambda x: {"video_id": x["video_id"], "user_id": x["user_id"]})
    for d in dataset:
      print(d)

  Output:

    {'video_id': <tf.Tensor: shape=(), dtype=int64, numpy=5>, 'user_id': <tf.Tensor: shape=(), dtype=int64, numpy=2560>}
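The same idea applied to the parsing pipeline from the question, as a minimal sketch: it assumes the TFRecord holds serialized tf.train.Example protos. The helper make_record and the three dummy records are hypothetical stand-ins for the real file; decode_retrieval_positive and the filter/map chain are taken from the question. Because score is written as a FloatList, the required tf.float32 feature is found and parsing succeeds.

```python
import tensorflow as tf

# Hypothetical helper: build one serialized tf.train.Example whose features
# match the schema used by decode_retrieval_positive.
def make_record(video_id, user_id, score):
    feature = {
        "video_id": tf.train.Feature(int64_list=tf.train.Int64List(value=[video_id])),
        "user_id": tf.train.Feature(int64_list=tf.train.Int64List(value=[user_id])),
        # "score" must be stored as a FloatList to match tf.float32 in the schema.
        "score": tf.train.Feature(float_list=tf.train.FloatList(value=[score])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

def decode_retrieval_positive(record_bytes):
    return tf.io.parse_single_example(
        record_bytes,
        {"video_id": tf.io.FixedLenFeature([], dtype=tf.int64),
         "user_id": tf.io.FixedLenFeature([], dtype=tf.int64),
         "score": tf.io.FixedLenFeature([], dtype=tf.float32)},
    )

# Dummy serialized records standing in for the real TFRecord file.
records = [make_record(1, 10, 0.9), make_record(2, 20, 0.1), make_record(3, 30, 0.5)]
ratings = tf.data.Dataset.from_tensor_slices(records)

ratings_positive = (
    ratings.map(decode_retrieval_positive)
    .filter(lambda x: x["score"] > 0.2)  # keep positive examples only
    .map(lambda x: {"video_id": x["video_id"], "user_id": x["user_id"]})  # drop "score"
)

for d in ratings_positive:
    print({k: int(v) for k, v in d.items()})
```

If some records in the real file genuinely lack a score, one alternative (an assumption about your data, not part of the answer above) is to declare the feature as tf.io.FixedLenFeature([], dtype=tf.float32, default_value=0.0). Those records would then parse with a score of 0.0 and be dropped by the > 0.2 filter instead of raising the error.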