Tags: python, tensorflow, dataset, tensorflow-datasets

How to create a dataset for tensorflow from a txt file containing paths and labels?


I'm trying to load the DomainNet dataset into a TensorFlow dataset. Each domain comes with two .txt files, one for the training data and one for the test data, structured as follows:

painting/aircraft_carrier/painting_001_000106.jpg 0
painting/aircraft_carrier/painting_001_000060.jpg 0
painting/aircraft_carrier/painting_001_000130.jpg 0
painting/aircraft_carrier/painting_001_000058.jpg 0
painting/aircraft_carrier/painting_001_000093.jpg 0
painting/aircraft_carrier/painting_001_000107.jpg 0
painting/aircraft_carrier/painting_001_000088.jpg 0
painting/aircraft_carrier/painting_001_000014.jpg 0
painting/aircraft_carrier/painting_001_000013.jpg 0
...

That is, one line per image, containing a relative path and a label. Is there a built-in way in TensorFlow/Keras to load this kind of structure, or do I have to parse and load the data manually? So far my google-fu has let me down...


Solution

  • You can use tf.data.TextLineDataset to load and process multiple txt files at a time:

    import tensorflow as tf
    import matplotlib.pyplot as plt
    
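    # Create two small example list files, one "<image path> <label>" pair per line.
    # The /content/ paths assume a Colab-style environment in which the images exist.
    with open('/content/data.txt', 'w') as f:
      f.write('/content/result_image1.png 0\n')
      f.write('/content/result_image2.png 1\n')
    
    with open('/content/more_data.txt', 'w') as f:
      f.write('/content/result_image1.png 1\n')
      f.write('/content/result_image2.png 0\n')
    
    # TextLineDataset accepts a list of files and yields their lines in order.
    dataset = tf.data.TextLineDataset(['/content/data.txt', '/content/more_data.txt'])
    for element in dataset.as_numpy_iterator():
      print(element)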
    
    Output:

    b'/content/result_image1.png 0'
    b'/content/result_image2.png 1'
    b'/content/result_image1.png 1'
    b'/content/result_image2.png 0'
    

    Process data:

    def process(x):
      # Each line has the form "<image path> <label>".
      splits = tf.strings.split(x, sep=' ')
      image_path, label = splits[0], splits[1]
      img = tf.io.read_file(image_path)
      # The example images are PNGs; use tf.io.decode_jpeg for .jpg files.
      img = tf.io.decode_png(img, channels=3)
      return img, tf.strings.to_number(label, out_type=tf.int32)
    
    dataset = dataset.map(process)
    for x, y in dataset.take(1):
      print('Label -->', y)
      plt.imshow(x.numpy())
      plt.show()
    
    Output:

    Label --> tf.Tensor(0, shape=(), dtype=int32)
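The question's list files use relative .jpg paths, so the same idea needs two small changes: prepend the dataset root to each path, and decode JPEG instead of PNG. The following is only a sketch under stated assumptions. `root_dir`, the list file name `painting_train.txt`, `IMAGE_SIZE`, and the helper `load_example` are hypothetical names chosen for illustration, and the placeholder black images are generated only so that the example runs on its own. Resizing to a fixed size is added so that the images can be batched.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: a throwaway directory stands in for the real DomainNet root.
root_dir = tempfile.mkdtemp()
image_dir = os.path.join(root_dir, "painting", "aircraft_carrier")
os.makedirs(image_dir)

# Write two placeholder JPEGs (black images) so the sketch is self-contained.
for name in ["painting_001_000106.jpg", "painting_001_000060.jpg"]:
    pixels = tf.zeros([32, 48, 3], dtype=tf.uint8)
    tf.io.write_file(os.path.join(image_dir, name), tf.io.encode_jpeg(pixels))

# Hypothetical list file in the "<relative path> <label>" format from the question.
list_file = os.path.join(root_dir, "painting_train.txt")
with open(list_file, "w") as f:
    f.write("painting/aircraft_carrier/painting_001_000106.jpg 0\n")
    f.write("painting/aircraft_carrier/painting_001_000060.jpg 0\n")

IMAGE_SIZE = (224, 224)  # assumed target size; pick whatever the model expects


def load_example(line):
    # Split "<relative path> <label>"; this assumes paths contain no spaces.
    parts = tf.strings.split(line, sep=" ")
    rel_path, label = parts[0], parts[1]
    # Turn the relative path into a full path under the dataset root.
    full_path = tf.strings.join([root_dir, rel_path], separator="/")
    img = tf.io.decode_jpeg(tf.io.read_file(full_path), channels=3)
    # Resize to a common shape so examples can be stacked into batches.
    img = tf.image.resize(img, IMAGE_SIZE)
    return img, tf.strings.to_number(label, out_type=tf.int32)


train_ds = (
    tf.data.TextLineDataset(list_file)
    .map(load_example, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(2)
    .prefetch(tf.data.AUTOTUNE)
)

images, labels = next(iter(train_ds))
print(images.shape, labels.numpy())
```

For real training data, a `.shuffle(...)` call before `.batch(...)` would probably be needed as well, since the list files group images by class.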
    
