Tags: tensorflow, tfx

Why does ImportExampleGen read TFRecords as SparseTensor instead of Tensor?


I'm converting a CSV file into a TFRecords file like this:

File: ./dataset/csv/file.csv

feature_1, feature_2, output
1, 1, 1
2, 2, 2
3, 3, 3
import tensorflow as tf
import csv
import os

print(tf.__version__)

def create_csv_iterator(csv_file_path, skip_header):
    """Yields the rows of the given CSV file, optionally skipping the header."""
    with tf.io.gfile.GFile(csv_file_path) as csv_file:
        reader = csv.reader(csv_file)
        if skip_header: # Skip the header
            next(reader)
        for row in reader:
            yield row

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def create_example(row):
    """
    Returns a tensorflow.Example Protocol Buffer object.
    """
    features = {}

    for feature_index, feature_name in enumerate(["feature_1", "feature_2", "output"]):
        feature_value = row[feature_index]
        features[feature_name] = _int64_feature(int(feature_value))

    return tf.train.Example(features=tf.train.Features(feature=features))

def create_tfrecords_file(input_csv_file):
    """
    Creates a TFRecords file for the given input data
    """
    # Note: this rewrites both the directory and the extension,
    # e.g. "./dataset/csv/file.csv" -> "./dataset/tfrecords/file.tfrecords"
    output_tfrecord_file = input_csv_file.replace("csv", "tfrecords")
    writer = tf.io.TFRecordWriter(output_tfrecord_file)
    
    print("Creating TFRecords file at", output_tfrecord_file, "...")
    
    for row in create_csv_iterator(input_csv_file, skip_header=True):

        if len(row) == 0:  # Skip empty lines
            continue
            
        example = create_example(row)
        content = example.SerializeToString()
        writer.write(content)
        
    writer.close()
    
    print("Finish Writing", output_tfrecord_file)
create_tfrecords_file("./dataset/csv/file.csv")
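
One way to double-check the written file (my own addition, not part of the original conversion script) is to read the TFRecords back and parse a few records with a fixed-length feature spec; the path and feature names below are taken from the snippet above:

# Sanity check: read the TFRecords file back and parse a few records.
raw_dataset = tf.data.TFRecordDataset("./dataset/tfrecords/file.tfrecords")

feature_spec = {
    "feature_1": tf.io.FixedLenFeature([], tf.int64),
    "feature_2": tf.io.FixedLenFeature([], tf.int64),
    "output": tf.io.FixedLenFeature([], tf.int64),
}

for raw_record in raw_dataset.take(3):
    print(tf.io.parse_single_example(raw_record, feature_spec))

Note that parsing with tf.io.FixedLenFeature yields dense tensors, whereas tf.io.VarLenFeature would yield SparseTensor values; this is the same distinction that comes up in the answer below.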

Then I read the generated TFRecords file using the ImportExampleGen class:

import os

import absl
import tensorflow as tf
import tensorflow_model_analysis as tfma

tf.get_logger().propagate = False

from tfx import v1 as tfx
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

%load_ext tfx.orchestration.experimental.interactive.notebook_extensions.skip
context = InteractiveContext()
example_gen = tfx.components.ImportExampleGen(input_base="./dataset/tfrecords")
context.run(example_gen, enable_cache=True)
statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples'])
context.run(statistics_gen, enable_cache=True)
schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=False)
context.run(schema_gen, enable_cache=True)
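
This is not in the original notebook, but it can help to display the inferred schema right here to see whether the features ended up with a fixed shape (with infer_feature_shape=False they will not):

# Optional: visualize the schema artifact in the notebook.
context.show(schema_gen.outputs['schema'])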

File: ./transform.py

def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.
  Args:
    inputs: map from feature keys to raw not-yet-transformed features.
  Returns:
    Map from string feature key to transformed feature operations.
  """

  print(inputs)

  return inputs
transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath("./transform.py"))
context.run(transform, enable_cache=True)

Printing inputs inside preprocessing_fn shows that the features are SparseTensor objects. My question is: why? As far as I can tell, my dataset's samples are dense, so I would expect them to be Tensor objects instead. Am I doing something wrong?


Solution

  • For anyone else who might be struggling with the same issue: I found the culprit. It's the SchemaGen class. This is how I was instantiating it:

    schema_gen = tfx.components.SchemaGen(
        statistics=statistics_gen.outputs['statistics'],
        infer_feature_shape=False)
    

    I don't know what the use case is for asking the SchemaGen class not to infer the shape of the features, but the tutorial I was following had it set to False and I had just copied and pasted the same thing. Comparing it with some other tutorials, I realized that this could be the reason I was getting SparseTensor.

    So, if you let SchemaGen infer the shape of your features, or if you load a hand-crafted schema in which you've set the shapes yourself, you'll get a Tensor in your preprocessing_fn. But if the shapes are not set, the features will be instances of SparseTensor (see the sketch after the fixed snippet below for a closer look at why).

    For the sake of completeness, this is the fixed snippet:

    schema_gen = tfx.components.SchemaGen(
        statistics=statistics_gen.outputs['statistics'],
        infer_feature_shape=True)
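
    To make the mechanism a bit more concrete, here is a rough sketch (my own addition, not from the original post; the artifact-access pattern and file name may differ across TFX versions) that converts the generated schema into the feature spec Transform uses for parsing. Features with a fixed shape map to tf.io.FixedLenFeature and are parsed as dense Tensor values; features without a shape map to tf.io.VarLenFeature and come out as SparseTensor:

    import os
    from google.protobuf import text_format
    from tensorflow_metadata.proto.v0 import schema_pb2
    from tensorflow_transform.tf_metadata import schema_utils

    # Locate the schema.pbtxt produced by SchemaGen in the interactive run
    # (assumed layout; adjust if your artifact directory looks different).
    schema_uri = schema_gen.outputs['schema'].get()[0].uri
    schema_path = os.path.join(schema_uri, 'schema.pbtxt')

    with open(schema_path) as f:
        schema = text_format.Parse(f.read(), schema_pb2.Schema())

    # Convert the schema to the feature spec used when parsing examples.
    feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
    print(feature_spec)
    # With infer_feature_shape=True: FixedLenFeature entries (dense Tensor).
    # With infer_feature_shape=False: VarLenFeature entries (SparseTensor).

    If you cannot change the schema, a common workaround is to densify the sparse inputs yourself inside preprocessing_fn (for example with tf.sparse.to_dense), but setting infer_feature_shape=True as shown above is the cleaner fix.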