Tags: python, numpy, tensorflow, tensorflow-datasets, tf.data.dataset

How to change the values of a tf.data.Dataset object at a specific index


The structure of my tf.data.Dataset object is as follows: each element is a pair of tensors with shapes ((3, 400, 1), (3, 400, 1)).

I would like to divide the values in the 3rd row (index 2) of both tensors in each element by 10. My code is below, but it fails because the NumPy arrays yielded by as_numpy_iterator() are immutable. I'd prefer to do this with map instead.

def alternate_row(dataset):
  xx, yy = [], []
  for x, y in dataset.as_numpy_iterator():
    x[2] /= 10  # fails here: the yielded arrays cannot be modified in place
    y[2] /= 10
    xx.append(x)
    yy.append(y)

  return xx, yy
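Assuming the error comes from the yielded arrays being read-only, the loop itself can be made to run by copying each array with np.array() (which copies by default) before editing it. This is only a sketch: the random stand-in data and the names x_data and y_data are hypothetical, and the result is two Python lists rather than a new tf.data.Dataset, which is why a map-based approach is still preferable.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: each element is a pair of (3, 400, 1) tensors.
x_data = tf.random.normal((5, 3, 400, 1))
y_data = tf.random.normal((5, 3, 400, 1))
dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))

def alternate_row(dataset):
    xx, yy = [], []
    for x, y in dataset.as_numpy_iterator():
        # np.array() returns a writable copy, so in-place division works.
        x = np.array(x)
        y = np.array(y)
        x[2] /= 10
        y[2] /= 10
        xx.append(x)
        yy.append(y)
    return xx, yy

xx, yy = alternate_row(dataset)
```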

Solution

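  • Try using tf.data.Dataset.map together with tf.concat. Inside map each element is a tensor that cannot be assigned to in place, so rebuild it from slices instead: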

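    import tensorflow as tf
    
    samples = 5
    # Stand-in data: each dataset element is a pair of (3, 400, 1) tensors.
    x1 = tf.random.normal((samples, 3, 400, 1))
    x2 = tf.random.normal((samples, 3, 400, 1))
    
    dataset = tf.data.Dataset.from_tensor_slices((x1, x2))
    
    def divide(x1, x2):
      # Keep rows 0-1 unchanged and append row 2 divided by 10,
      # producing new tensors with the same (3, 400, 1) shape.
      x1 = tf.concat([x1[:2], x1[2:] / 10], axis=0)
      x2 = tf.concat([x2[:2], x2[2:] / 10], axis=0)
      return x1, x2
    
    dataset = dataset.map(divide)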
    

    Note that this assumes the "3rd row" means index 2 along the first axis of each dataset element (the axis of size 3, which is the second dimension of the original (samples, 3, 400, 1) tensors). Adjust the slices if you meant a different axis.
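As an alternative to slicing and concatenating, the same per-row scaling can be written as one broadcast division. This is a sketch under the same assumption (row index 2 along the size-3 axis); the names divisor, divide_broadcast and divide_concat are introduced here for illustration, and the loop at the end simply checks that both versions agree.

```python
import tensorflow as tf

samples = 5
x1 = tf.random.normal((samples, 3, 400, 1))
x2 = tf.random.normal((samples, 3, 400, 1))
dataset = tf.data.Dataset.from_tensor_slices((x1, x2))

# Hypothetical per-row divisors for the size-3 axis: rows 0 and 1 are
# divided by 1 (unchanged), row 2 by 10. Shape (3, 1, 1) broadcasts
# against each (3, 400, 1) element.
divisor = tf.constant([1.0, 1.0, 10.0])[:, tf.newaxis, tf.newaxis]

def divide_broadcast(a, b):
    return a / divisor, b / divisor

def divide_concat(a, b):
    # The slice-and-concat version, for comparison.
    a = tf.concat([a[:2], a[2:] / 10], axis=0)
    b = tf.concat([b[:2], b[2:] / 10], axis=0)
    return a, b

scaled = dataset.map(divide_broadcast)
reference = dataset.map(divide_concat)

# Both pipelines should produce the same values element by element.
for (a, b), (ra, rb) in zip(scaled, reference):
    tf.debugging.assert_near(a, ra)
    tf.debugging.assert_near(b, rb)
```

Editing the divisor vector selects which rows change, so no slicing logic needs to be rewritten when the target row moves.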