
Using tf.data.experimental.make_csv_dataset for time series data


How do I use tf.data.experimental.make_csv_dataset with CSV files containing time series data?

building_dataset = tf.data.experimental.make_csv_dataset(file_pattern=csv_file, batch_size=5, num_epochs=1,
                                                         shuffle=False, select_columns=feature_columns)

Solution

  • This assumes the CSV file is already sorted by time. First, read the CSV file, keeping shuffle=False so that rows stay in time order:

    import tensorflow as tf

    building_dataset = tf.data.experimental.make_csv_dataset(file_pattern=csv_file, batch_size=5, num_epochs=1,
                                                             shuffle=False, select_columns=feature_columns)
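A quick way to see what this step produces is to run it on a tiny file. The sketch below is self-contained and assumes a hypothetical CSV with columns `time`, `temp` and `load` (names and values invented for illustration), written to a temporary directory. Each element of the resulting dataset is an OrderedDict that maps every selected column name to a tensor of shape (batch_size,), and with shuffle=False the rows come out in file order.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical CSV: one row per time step, already sorted by time.
rows = ["time,temp,load"] + [f"{t},{20.0 + t},{0.5 * t}" for t in range(8)]
csv_file = os.path.join(tempfile.mkdtemp(), "building.csv")
with open(csv_file, "w") as f:
    f.write("\n".join(rows) + "\n")

feature_columns = ["temp", "load"]  # hypothetical column selection

building_dataset = tf.data.experimental.make_csv_dataset(
    file_pattern=csv_file,
    batch_size=5,
    num_epochs=1,
    shuffle=False,  # keep rows in time order
    select_columns=feature_columns,
)

# Each element is a dict: column name -> tensor holding batch_size consecutive rows.
first_batch = next(iter(building_dataset))
for name, column in first_batch.items():
    print(name, column.numpy())
```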
    

    Then define a pack_features_vector function that casts every column to float32 and stacks the columns into one feature vector per row, and use flat_map() to unbatch the result so that each dataset element is a single row (time step).

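    def pack_features_vector(features):
        """Stack the per-column tensors (each of shape [batch]) into one [batch, num_columns] float32 tensor."""
        # The order of features.values() decides which column ends up last (used as the label later).
        return tf.stack([tf.cast(x, tf.float32) for x in features.values()], axis=1)
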
    building_dataset = building_dataset.map(pack_features_vector)
    building_dataset = building_dataset.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x))
    for feature in building_dataset.take(1):
        print('Stacked tensor:',feature)
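To make the packing step concrete, the sketch below applies the same pack_features_vector to a hand-built batch shaped like one make_csv_dataset element (the `load` and `temp` columns are hypothetical). It also shows that Dataset.unbatch() yields the same rows as flat_map() with from_tensor_slices, in case you prefer the shorter form.

```python
import collections

import tensorflow as tf


def pack_features_vector(features):
    """Stack each column tensor (shape [batch]) into one [batch, num_columns] float32 tensor."""
    return tf.stack([tf.cast(x, tf.float32) for x in features.values()], axis=1)


# Hypothetical batch shaped like one make_csv_dataset element: column name -> [batch] tensor.
batch = collections.OrderedDict(
    load=tf.constant([0, 1, 2]),           # int32 column, cast to float32 when packed
    temp=tf.constant([20.0, 21.0, 22.0]),
)
packed = pack_features_vector(batch)  # one row per time step, one column per feature
ds = tf.data.Dataset.from_tensors(packed)

# Two equivalent ways to turn a batch of rows back into one row per element.
rows_flat_map = list(ds.flat_map(tf.data.Dataset.from_tensor_slices).as_numpy_iterator())
rows_unbatch = list(ds.unbatch().as_numpy_iterator())
print(rows_unbatch)
```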
    

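    Then use window() and flat_map() to build overlapping windows of window_size consecutive rows, where window_size is the number of time steps per sample: shift=1 advances each window by one time step, drop_remainder=True discards incomplete windows at the end, and flat_map() turns each window (itself a small dataset) into a single [window_size, num_columns] tensor.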

    building_dataset = building_dataset.window(window_size, shift=1, drop_remainder=True)
    building_dataset = building_dataset.flat_map(lambda window: window.batch(window_size))
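The effect of window() plus flat_map() is easiest to see on a small stand-in dataset. In this sketch the rows [t, 10*t] and window_size = 3 are hypothetical; with 6 rows, shift=1 and drop_remainder=True you get 6 - 3 + 1 = 4 overlapping windows, the first being [[0, 0], [1, 10], [2, 20]] and the last [[3, 30], [4, 40], [5, 50]].

```python
import tensorflow as tf

window_size = 3  # hypothetical window length (number of consecutive time steps)

# Stand-in for the per-row dataset: each element is a 2-feature row [t, 10 * t].
rows = tf.data.Dataset.range(6).map(lambda t: tf.stack([t, 10 * t]))

windows = rows.window(window_size, shift=1, drop_remainder=True)
# Each element of `windows` is itself a small Dataset; batch(window_size) turns it
# into one [window_size, num_features] tensor, and flat_map flattens the result.
windows = windows.flat_map(lambda window: window.batch(window_size))

for w in windows.as_numpy_iterator():
    print(w.tolist())
```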
    

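    Then use map() to split each window into features and a label: the features are every column except the last, for all time steps in the window, and the label is the last column of the final time step.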

    building_dataset = building_dataset.map(lambda window: (window[:,:-1], window[-1:,-1]))
    for feature, label in building_dataset.take(5):
        print(feature.shape)
        print('feature:',feature[:,0:4])
        print('label:',label)
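Here is a minimal, self-contained check of the split on toy data (a hypothetical 10 x 3 array whose last column stands in for the label, with window_size = 4). Each feature tensor has shape (window_size, num_columns - 1) and each label has shape (1,); for the first window the label is 11.0, the last column of row 3.

```python
import tensorflow as tf

window_size = 4  # hypothetical
# Toy rows with 3 columns; row i is [3i, 3i + 1, 3i + 2] and the last column plays the label.
data = tf.reshape(tf.range(30, dtype=tf.float32), (10, 3))

ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.window(window_size, shift=1, drop_remainder=True)
ds = ds.flat_map(lambda window: window.batch(window_size))
# Features: every time step of the window, all columns except the last.
# Label: the last column of the final time step, kept as shape (1,).
ds = ds.map(lambda window: (window[:, :-1], window[-1:, -1]))

feature, label = next(iter(ds))
print(feature.shape, label.shape)
```

Note that window[:, :-1] includes the non-label columns of the final time step, i.e. the same time step as the label. If the goal is to forecast the label from earlier steps only, slicing with window[:-1, :-1] would exclude that row.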
    

    Finally, group the windows into batches with batch() and pass the resulting dataset to model training (for example, model.fit()).

    building_dataset = building_dataset.batch(32)
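An end-to-end sketch of the last step: the random in-memory data stands in for the CSV-derived rows, and the small LSTM model, window_size = 4 and the three columns are all hypothetical choices, not part of the original answer. Any Keras model whose input shape is (window_size, num_columns - 1) and whose output has one unit can consume the batched dataset directly through model.fit().

```python
import tensorflow as tf

window_size = 4   # hypothetical
num_columns = 3   # hypothetical: 2 feature columns + 1 label column

# Toy stand-in for the per-row dataset produced by the earlier steps.
data = tf.random.uniform((50, num_columns))
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.window(window_size, shift=1, drop_remainder=True)
ds = ds.flat_map(lambda window: window.batch(window_size))
ds = ds.map(lambda window: (window[:, :-1], window[-1:, -1]))
ds = ds.batch(32)  # elements: ([batch, window_size, 2], [batch, 1])

# A deliberately small, hypothetical model for illustration only.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(window_size, num_columns - 1)),
    tf.keras.layers.LSTM(8),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
history = model.fit(ds, epochs=2, verbose=0)
print(history.history["loss"])
```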