Tags: tensorflow2.0, tfx, tensorflow-data-validation

Why isn't SchemaGen supported in tfdv.display_schema()?


Regarding TFX's tensorflow-data-validation (TFDV), I'm trying to understand when I should use the *Gen components and when I should use the methods TFDV provides.

Specifically, what's confusing me is that I have this as my ExampleGen:

import os

from tfx.components import CsvExampleGen
from tfx.proto import example_gen_pb2

# base_dir, data_dir and the InteractiveContext `context` are defined earlier in my notebook
output = example_gen_pb2.Output(
    split_config=example_gen_pb2.SplitConfig(splits=[
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=7),
        example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=2),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
    ]))
example_gen = CsvExampleGen(input_base=os.path.join(base_dir, data_dir),
                            output_config=output)
context.run(example_gen)
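The hash_buckets values act as relative weights: each input record is hashed into one of sum(hash_buckets) buckets (10 here), and consecutive bucket ranges are mapped to splits, so this config should give roughly a 70/20/10 train/test/eval split. The stdlib-only sketch below illustrates that idea. MD5 is a stand-in chosen for the example, not the hash TFX actually uses, and assign_split/expected_fractions are hypothetical helper names.

```python
import hashlib
from collections import Counter

# Mirrors the split_config above: (split name, number of hash buckets).
SPLITS = [("train", 7), ("test", 2), ("eval", 1)]


def assign_split(record: bytes, splits=SPLITS) -> str:
    """Assign a record to a split by hashing it into sum(buckets) buckets.

    Hypothetical illustration: MD5 stands in for TFX's real hash function.
    """
    total = sum(n for _, n in splits)
    bucket = int.from_bytes(hashlib.md5(record).digest()[:8], "little") % total
    # Walk the splits in order, each one claiming its share of buckets.
    for name, n in splits:
        if bucket < n:
            return name
        bucket -= n
    raise AssertionError("unreachable: bucket index exceeds total buckets")


def expected_fractions(splits=SPLITS) -> dict:
    """Expected share of records per split: buckets / total buckets."""
    total = sum(n for _, n in splits)
    return {name: n / total for name, n in splits}


if __name__ == "__main__":
    print(expected_fractions())  # {'train': 0.7, 'test': 0.2, 'eval': 0.1}
    counts = Counter(assign_split(f"row-{i}".encode()) for i in range(10_000))
    print(counts)  # roughly 7000 / 2000 / 1000 records
```

Because the assignment depends only on the record's bytes, the same record always lands in the same split across reruns.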

So I figured I'd generate my statistics from the train split rather than from the original file, and tried:

from tfx.components import StatisticsGen

statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples'],
    exclude_splits=['eval']  # compute statistics for train and test only
)
context.run(statistics_gen)

and that runs fine. But then, I tried inferring my schema (insert buzzer sound):

import tensorflow_data_validation as tfdv

schema = tfdv.infer_schema(statistics=statistics_gen)

As expected, this raises the error below. I knew it wasn't the correct type, but I can't figure out how to extract the proper output from the StatisticsGen object to feed to infer_schema().

Alternatively, if I use only *Gen components, the pipeline builds, but I don't see how to visualize the schema, statistics, etc. Finally, the only reason I'm calling tfdv.infer_schema() here is that tfdv.display_schema() similarly fails if you pass it a SchemaGen.

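Error (this traceback comes from a cell where I passed an ExampleValidator, validate_stats; passing statistics_gen raises the same TypeError, just naming StatisticsGen instead):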

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-11-93ceafbcb04a> in <module>
----> 1 schema = tfdv.infer_schema(statistics=validate_stats)
      2 tfdv.write_schema_text(schema, schema_location)
      3 
      4 tfdv.display(infer_schema)

/usr/local/lib/python3.6/dist-packages/tensorflow_data_validation/api/validation_api.py in infer_schema(statistics, infer_feature_shape, max_string_domain_size, schema_transformations)
     95     raise TypeError(
     96         'statistics is of type %s, should be '
---> 97         'a DatasetFeatureStatisticsList proto.' % type(statistics).__name__)
     98 
     99   # This will raise an exception if there are multiple datasets, none of which

TypeError: statistics is of type ExampleValidator, should be a DatasetFeatureStatisticsList proto.

What I'm really trying to understand is why we have components such as SchemaGen and StatisticsGen, only for TFDV to require that we use its own functions to get any value out of them. I'm assuming this is to support interactive vs. non-interactive pipelines, but my Googling has left me unclear.

If there is a way to generate and view stats based on a split of my data rather than relying on the file reader, I'd love to know that also. (In case it's not obvious, yes, I'm new to TFX).

TIA


Solution

  • I'm also new to TFX. Your post about the ExampleValidator helped me out; hopefully this answers your question.

    Using only components to visualize the schema

    from tfx.components import SchemaGen, StatisticsGen

    statistics_gen = StatisticsGen(
        examples=example_gen.outputs['examples'],
        exclude_splits=['eval']
    )
    context.run(statistics_gen)
    context.show(statistics_gen.outputs['statistics'])  # visualize statistics for each computed split

    schema_gen = SchemaGen(
        statistics=statistics_gen.outputs['statistics'],
        infer_feature_shape=True
    )
    context.run(schema_gen)

    # Render the inferred schema in the notebook
    context.show(schema_gen.outputs['schema'])
    

    Using components + TFDV to visualize the schema

    It looks like we can't pass the StatisticsGen component directly. Instead, we need to find where the StatisticsGen output artifact is saved and then load it with tfdv.load_statistics().

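    import os

    import tensorflow_data_validation as tfdv

    # Get the statistics artifact produced by StatisticsGen
    # (Channel.get() avoids reaching into the private _artifacts list)
    stats_artifact = statistics_gen.outputs['statistics'].get()[0]

    # Base directory of the artifact, with one subdirectory per split
    base_path = stats_artifact.uri

    # Path to the train split's statistics (only training shown as an example).
    # This layout matches the TFX version used here; other releases may name it differently.
    train_stats_file = os.path.join(base_path, 'train', 'stats_tfrecord')

    # Load the statistics as a DatasetFeatureStatisticsList proto
    loaded_stats = tfdv.load_statistics(train_stats_file)

    # Infer and display the schema
    schema = tfdv.infer_schema(loaded_stats)
    tfdv.display_schema(schema)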
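    Because the directory and file names inside the statistics artifact vary between TFX versions, it can help to inspect the artifact directory before hard-coding a path. The stdlib-only sketch below lists each split subdirectory with its files and picks the one for a given split. list_split_files and stats_path_for_split are hypothetical helpers, and the match on a "-<split>" suffix is an assumption meant to cover prefixed directory names that some TFX versions may use.

```python
import os


def list_split_files(artifact_uri: str) -> dict:
    """Map each subdirectory of an artifact URI to the sorted files inside it.

    Assumption: the artifact directory holds one subdirectory per split,
    as in the 'train/stats_tfrecord' layout above. Plain files are skipped.
    """
    layout = {}
    for entry in sorted(os.listdir(artifact_uri)):
        split_dir = os.path.join(artifact_uri, entry)
        if os.path.isdir(split_dir):
            layout[entry] = sorted(os.listdir(split_dir))
    return layout


def stats_path_for_split(artifact_uri: str, split: str) -> str:
    """Return the single file in the subdirectory that belongs to `split`."""
    for dir_name, files in list_split_files(artifact_uri).items():
        # Accept an exact name ('train') or a prefixed one ending in '-train'.
        if dir_name == split or dir_name.endswith("-" + split):
            if len(files) != 1:
                raise ValueError(f"expected one file in {dir_name!r}, found {files}")
            return os.path.join(artifact_uri, dir_name, files[0])
    raise FileNotFoundError(f"no subdirectory for split {split!r} in {artifact_uri}")
```

    In the notebook, print(list_split_files(stats_artifact.uri)) shows what your TFX version actually wrote (with exclude_splits=['eval'], expect train and test only). Passing stats_path_for_split(stats_artifact.uri, 'test') to tfdv.load_statistics then gives statistics for another split, provided the file is in the same TFRecord format as the train file above.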