
TFX - How to inspect records from CsvExampleGen


Question

How to inspect the data loaded into TFX CsvExampleGen?

CSV

The top 3 rows of california_housing_train.csv look like this:

longitude latitude housing_median_age total_rooms total_bedrooms population households median_income median_house_value
-122.05 37.37 27 3885 661 1537 606 6.6085 344700
-118.3 34.26 43 1510 310 809 277 3.599 176500
-117.81 33.78 27 3589 507 1484 495 5.7934 270500
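For reference, CsvExampleGen turns each CSV row into one tf.train.Example protocol buffer and writes these records to GZIP-compressed TFRecord files. The sketch below builds by hand, with TensorFlow only, what the first row above might look like as such a record. It assumes integer-looking values are stored as int64 and everything else as float. That rule is for illustration only, and CsvExampleGen's own type inference may assign different types.

```python
import tensorflow as tf

# First data row of california_housing_train.csv, as shown above.
header = ["longitude", "latitude", "housing_median_age", "total_rooms",
          "total_bedrooms", "population", "households", "median_income",
          "median_house_value"]
row = ["-122.05", "37.37", "27", "3885", "661", "1537", "606", "6.6085", "344700"]


def to_feature(text):
    """Assumed rule: integer-looking values become int64, everything else float."""
    try:
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(text)]))
    except ValueError:
        return tf.train.Feature(float_list=tf.train.FloatList(value=[float(text)]))


example = tf.train.Example(
    features=tf.train.Features(
        feature={name: to_feature(value) for name, value in zip(header, row)}))

# This serialized byte string is what one record inside a TFRecord file holds.
serialized = example.SerializeToString()

# Parsing the bytes back yields an identical message.
parsed = tf.train.Example()
parsed.ParseFromString(serialized)
print(parsed)
```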

CsvExampleGen

The CSV is loaded into CsvExampleGen. In my understanding, the XxxExampleGen components generate TFRecord files of tf.train.Example records, so I wonder if there is a way to iterate through the records produced by CsvExampleGen.

from tfx.components import (
    CsvExampleGen
)
housing = CsvExampleGen("sample_data/california_housing_train.csv")
housing
----------
CsvExampleGen(
    spec: <tfx.types.standard_component_specs.FileBasedExampleGenSpec object at 0x7fcd90435450>,
    executor_spec: <tfx.dsl.components.base.executor_spec.BeamExecutorSpec object at 0x7fcd90435850>,
    driver_class: <class 'tfx.components.example_gen.driver.FileBasedDriver'>,
    component_id: CsvExampleGen,
    inputs: {},
    outputs: {
        'examples': OutputChannel(artifact_type=Examples,
        producer_component_id=CsvExampleGen,
        output_key=examples,
        additional_properties={},
        additional_custom_properties={})
    }
)

Experiment

for record in housing.outputs['examples']:
    print(record)

TypeError                                 Traceback (most recent call last)
----> 1 for record in housing.outputs['examples']:
      2     print(record)

TypeError: 'OutputChannel' object is not iterable


Solution

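  • Have you had a chance to look at the section of the TFX tutorial that explains how to display the artifacts of an ExampleGen component? You can adapt the code below (source: TFX tutorial) to do the same. The TypeError occurs because housing.outputs['examples'] is an OutputChannel, which only describes the component's output rather than containing the records. The component has to be executed first (in a notebook, for example, by calling run(housing) on an InteractiveContext), after which .get() returns the Examples artifacts. Each artifact's uri is a directory with one subdirectory per split, such as Split-train, holding GZIP-compressed TFRecord files. Also note that CsvExampleGen expects input_base to be a directory containing the CSV files rather than the path of a single file, so put california_housing_train.csv in a directory of its own and pass that directory.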

    import os
    import pprint
    
    import tensorflow as tf
    
    pp = pprint.PrettyPrinter()
    
    # The tutorial calls the component example_gen; here it is the question's
    # CsvExampleGen (housing), which must already have been run, for example with
    # context = InteractiveContext(); context.run(housing)
    example_gen = housing
    
    # Get the URI of the output artifact representing the training examples, which is a directory
    train_uri = os.path.join(example_gen.outputs['examples'].get()[0].uri, 'Split-train')
    
    # Get the list of files in this directory (all compressed TFRecord files)
    tfrecord_filenames = [os.path.join(train_uri, name)
                          for name in os.listdir(train_uri)]
    
    # Create a `TFRecordDataset` to read these files
    dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
    
    # Iterate over the first 3 records and decode them.
    for tfrecord in dataset.take(3):
      serialized_example = tfrecord.numpy()
      example = tf.train.Example()
      example.ParseFromString(serialized_example)
      pp.pprint(example)
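The decoding loop from the tutorial can also be tried without running a pipeline. The sketch below assumes only TensorFlow. It writes three records into a GZIP-compressed TFRecord file inside a temporary Split-train directory, then reads them back the same way. The file name data_tfrecord-00000-of-00001.gz is a hypothetical stand-in for ExampleGen's real output files. The helper example_to_dict is also hypothetical and not part of TFX. It flattens each tf.train.Example into a plain dict, which is easier to scan than the full protobuf printout.

```python
import os
import tempfile

import tensorflow as tf


def example_to_dict(example):
    """Hypothetical helper: flatten a tf.train.Example into {name: list_of_values}."""
    result = {}
    for name, feature in example.features.feature.items():
        kind = feature.WhichOneof("kind")  # 'bytes_list', 'float_list', 'int64_list' or None
        result[name] = list(getattr(feature, kind).value) if kind else []
    return result


def read_examples(tfrecord_filenames, limit):
    """Yield up to `limit` decoded examples from GZIP-compressed TFRecord files."""
    dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
    for tfrecord in dataset.take(limit):
        example = tf.train.Example()
        example.ParseFromString(tfrecord.numpy())
        yield example_to_dict(example)


# Stand-in for an ExampleGen output: a Split-train directory with one GZIP TFRecord file.
train_uri = os.path.join(tempfile.mkdtemp(), "Split-train")
os.makedirs(train_uri)
path = os.path.join(train_uri, "data_tfrecord-00000-of-00001.gz")
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter(path, options) as writer:
    # median_house_value of the three sample rows from the question
    for value in [344700, 176500, 270500]:
        record = tf.train.Example(features=tf.train.Features(feature={
            "median_house_value": tf.train.Feature(
                int64_list=tf.train.Int64List(value=[value]))}))
        writer.write(record.SerializeToString())

# Same file discovery as the tutorial code, then decode into plain dicts.
filenames = [os.path.join(train_uri, name) for name in os.listdir(train_uri)]
rows = list(read_examples(filenames, limit=3))
for row in rows:
    print(row)
```

Passing the tfrecord_filenames list from the tutorial code to read_examples should print the first rows of the real training split in the same compact form.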