python, tensorflow, tf.data.dataset

How to iterate inside a tf.data.Dataset to stack the features


I'm starting with this code, which works properly:

import tensorflow as tf

# Load the CSV (filename and keep_columns are defined elsewhere)
dataset = tf.data.experimental.make_csv_dataset(
    file_pattern=filename,
    num_parallel_reads=2,
    batch_size=128,
    num_epochs=1,
    label_name='streamflow',
    select_columns=keep_columns,
    shuffle_buffer_size=10000,
    header=True,
    field_delim=','
)

def preprocess_fn(features, label):
    # Normalize the features (example: scaling to [0, 1])
    features['total_precipitation_sum'] /= 100.0
    features['temperature_2m_min'] /= 100.0
    features['temperature_2m_max'] /= 100.0
    features['snow_depth_water_equivalent_max'] /= 100.0

    # Create a 'main_inputs' feature by stacking the selected columns
    features['main_inputs'] = tf.stack([
        features['total_precipitation_sum'],
        features['temperature_2m_min'],
        features['temperature_2m_max'],
        features['snow_depth_water_equivalent_max']
    ], axis=-1)

    return {'main_inputs': features['main_inputs']}, label

dataset = dataset.map(preprocess_fn)
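To run the pipeline above end to end, here is a minimal sketch that writes a tiny hypothetical CSV and feeds it through the same hard-coded scaling and stacking. The sample values, the temporary file path, the FEATURE_NAMES list, batch_size=2 and shuffle=False are all assumptions for illustration, not taken from the question. With four selected feature columns, each batch of main_inputs comes out with shape (batch_size, 4).

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical sample CSV: the four feature columns from the question plus
# the 'streamflow' label. Every value has a decimal point, so each column
# should be inferred as float32.
csv_text = (
    "total_precipitation_sum,temperature_2m_min,temperature_2m_max,"
    "snow_depth_water_equivalent_max,streamflow\n"
    "12.0,-3.5,4.2,30.0,1.1\n"
    "0.0,-1.0,6.8,25.0,0.9\n"
    "5.5,0.5,9.1,10.0,1.4\n"
)
filename = os.path.join(tempfile.mkdtemp(), "sample.csv")
with open(filename, "w") as f:
    f.write(csv_text)

# Hypothetical list holding the same hard-coded names as the question.
FEATURE_NAMES = [
    "total_precipitation_sum",
    "temperature_2m_min",
    "temperature_2m_max",
    "snow_depth_water_equivalent_max",
]
# The label column has to be among the selected columns.
keep_columns = FEATURE_NAMES + ["streamflow"]

dataset = tf.data.experimental.make_csv_dataset(
    file_pattern=filename,
    batch_size=2,
    num_epochs=1,
    label_name="streamflow",
    select_columns=keep_columns,
    shuffle=False,  # keep file order so the first batch is predictable
    header=True,
)


def preprocess_fn(features, label):
    # Same scaling and stacking as the question, names still hard-coded.
    main_inputs = tf.stack([features[n] / 100.0 for n in FEATURE_NAMES], axis=-1)
    return {"main_inputs": main_inputs}, label


dataset = dataset.map(preprocess_fn)

first_inputs, first_label = next(iter(dataset))
print(first_inputs["main_inputs"].shape)  # (2, 4)
```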

In the "Create a 'main_inputs' feature by stacking the selected columns" section:

the feature names are hard-coded, but they will change in some cases. How can I automate detecting all the features and stacking them?


Solution

  • Instead of hard-coding the feature names, we can iterate through them with a list comprehension and pass the result to tf.stack:

    features['main_inputs'] = tf.stack([features[_feature] for _feature, _ in features.items()], axis=-1)
    

    or iterate over the feature values directly:

    features['main_inputs'] = tf.stack([_feature for _, _feature in features.items()], axis=-1)
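Combining the idea above with the normalization step, here is a sketch of a fully generic preprocess_fn that also replaces the hard-coded division lines. It goes beyond the answer in two labeled ways. First, it casts every column to float32, because make_csv_dataset infers a dtype per column and tf.stack fails when, for example, one column comes back as int32 and the rest as float32; this assumes every selected column except the label is numeric (string columns such as dates cannot be cast this way, so leave them out of select_columns). Second, it sorts the names so the position of each feature inside main_inputs is explicit; drop sorted() to keep the dataset's own order instead. The SCALE constant and the in-memory sample dict standing in for the CSV are hypothetical.

```python
import tensorflow as tf

SCALE = 100.0  # hypothetical constant: the divisor the question applies to every feature


def preprocess_fn(features, label):
    # Detect the feature names instead of hard-coding them; sorting makes the
    # column order of 'main_inputs' explicit and independent of the CSV layout.
    names = sorted(features.keys())
    # tf.stack needs a single dtype, so cast every column to float32 before
    # scaling (whole-number columns may otherwise arrive as int32).
    columns = [tf.cast(features[name], tf.float32) / SCALE for name in names]
    return {"main_inputs": tf.stack(columns, axis=-1)}, label


# Hypothetical in-memory stand-in for the CSV; note the integer column.
features = {
    "total_precipitation_sum": [12.0, 0.0, 5.5],
    "temperature_2m_min": [-3.5, -1.0, 0.5],
    "temperature_2m_max": [4.2, 6.8, 9.1],
    "snow_depth_water_equivalent_max": [30, 25, 10],
}
labels = [1.1, 0.9, 1.4]

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(2)
dataset = dataset.map(preprocess_fn)

inputs, label = next(iter(dataset))
print(inputs["main_inputs"].shape)  # (2, 4)
```

With sorted names, the columns of main_inputs appear in the order snow_depth_water_equivalent_max, temperature_2m_max, temperature_2m_min, total_precipitation_sum. Returning a fresh dict also avoids writing main_inputs back into features while its keys are being read.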