Search code examples
python, csv, tensorflow, tensorflow2.0, tensorflow-datasets

csv table row as label for previous several rows


I have a question about TensorFlow. I have CSV data like the attached image, and I want to map it so that each green row is the label for the previous 5 rows. Is it possible to do this inside a map function (dataset.map())? If so, how?

enter image description here


Solution

  • Try tf.data.Dataset.window:

    import tensorflow as tf
    import pandas as pd
    
    # Build an 8-column toy table (columns 'A'..'H'); every column holds the
    # integers 1..12, so each row simply repeats its 1-based row number.
    d = {column: list(range(1, 13)) for column in 'ABCDEFGH'}
    
    df = pd.DataFrame(d)
    
    def redefine_data(windowed_ds, num_data_rows=5):
      """Split each window of rows into a (data, label) pair.

      The first `num_data_rows` rows of every window become the data block and
      the row immediately following them becomes the label.

      Args:
        windowed_ds: Dataset produced by `tf.data.Dataset.window`, i.e. a
          dataset whose elements are themselves datasets of rows.
        num_data_rows: Number of leading rows per window used as data; the
          row right after them is the label. Defaults to 5.

      Returns:
        A `tf.data.Dataset` yielding `(data, label)` tuples, where `data`
        stacks `num_data_rows` rows and `label` is a single row.

      Raises:
        ValueError: If `num_data_rows` is negative.
      """
      if num_data_rows < 0:
        raise ValueError('num_data_rows must be non-negative')

      rows_per_example = num_data_rows + 1

      # Stay inside the tf.data pipeline instead of looping over the windows in
      # Python: this keeps evaluation lazy and avoids holding every example in
      # memory. `take` limits longer windows to the rows that are actually used;
      # `drop_remainder=True` skips windows too short to contain a label row.
      stacked = windowed_ds.flat_map(
          lambda window: window.take(rows_per_example).batch(
              rows_per_example, drop_remainder=True))

      # Leading rows are the data block, the last row of the stack is the label.
      return stacked.map(
          lambda rows: (rows[:num_data_rows], rows[num_data_rows]))
    
    row_ds = tf.data.Dataset.from_tensor_slices(df.values)
    # Windows of 6 consecutive rows (5 data rows + 1 label row), with each
    # window starting 3 rows after the previous one; incomplete trailing
    # windows are dropped.
    ds = redefine_data(
        row_ds.window(6, shift=3, stride=1, drop_remainder=True))
    for data, label in ds:
      print(data, label)
    
    tf.Tensor(
    [[1 1 1 1 1 1 1 1]
     [2 2 2 2 2 2 2 2]
     [3 3 3 3 3 3 3 3]
     [4 4 4 4 4 4 4 4]
     [5 5 5 5 5 5 5 5]], shape=(5, 8), dtype=int64) tf.Tensor([6 6 6 6 6 6 6 6], shape=(8,), dtype=int64)
    tf.Tensor(
    [[4 4 4 4 4 4 4 4]
     [5 5 5 5 5 5 5 5]
     [6 6 6 6 6 6 6 6]
     [7 7 7 7 7 7 7 7]
     [8 8 8 8 8 8 8 8]], shape=(5, 8), dtype=int64) tf.Tensor([9 9 9 9 9 9 9 9], shape=(8,), dtype=int64)
    tf.Tensor(
    [[ 7  7  7  7  7  7  7  7]
     [ 8  8  8  8  8  8  8  8]
     [ 9  9  9  9  9  9  9  9]
     [10 10 10 10 10 10 10 10]
     [11 11 11 11 11 11 11 11]], shape=(5, 8), dtype=int64) tf.Tensor([12 12 12 12 12 12 12 12], shape=(8,), dtype=int64)