
Find the max value across all features in tensorflow


Consider the following code:

import tensorflow as tf

input_slice = 3
labels_slice = 2

def split_window(x):
    inputs = tf.slice(x, [0], [input_slice])
    labels = tf.slice(x, [input_slice], [labels_slice])
    return inputs, labels

dataset = tf.data.Dataset.range(1, 25 + 1).batch(5).map(split_window)

for i, j in dataset:
    print(i.numpy(), end="->")
    print(j.numpy())

This code gives the following output:

[1 2 3]->[4 5]
[6 7 8]->[ 9 10]
[11 12 13]->[14 15]
[16 17 18]->[19 20]
[21 22 23]->[24 25]

Every row j (the labels of one window) represents a feature. I want to find the maximum value across all features; in this example it is 25. How can I compute it?
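For reference, `tf.slice(x, begin, size)` takes `size` elements starting at index `begin`, so `split_window` is a fixed split of each 5-element batch into 3 inputs and 2 labels. The sketch below applies the same calls to one hypothetical window (`window` is illustrative, not part of the original code) and shows that plain Python slicing gives the same result.

```python
import tensorflow as tf

input_slice = 3
labels_slice = 2

# Hypothetical single window, standing in for one batch of the dataset.
window = tf.range(1, 6, dtype=tf.int64)  # [1 2 3 4 5]

# tf.slice(x, begin, size): take `size` elements starting at index `begin`.
inputs = tf.slice(window, [0], [input_slice])
labels = tf.slice(window, [input_slice], [labels_slice])

# The same split written with ordinary tensor slicing.
inputs_alt = window[:input_slice]
labels_alt = window[input_slice:input_slice + labels_slice]

print(inputs.numpy(), labels.numpy())  # [1 2 3] [4 5]
```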


Solution

  • One solution is to collect the label windows in a tf.TensorArray, stack them, and apply tf.reduce_max:

    import tensorflow as tf
    
    input_slice = 3
    labels_slice = 2
    
    def split_window(x):
        inputs = tf.slice(x, [0], [input_slice])
        labels = tf.slice(x, [input_slice], [labels_slice])
        return inputs, labels
    
    dataset = tf.data.Dataset.range(1, 25 + 1).batch(5).map(split_window)
    
    # Dynamically sized array that collects each labels window.
    ta = tf.TensorArray(tf.int64, size=0, dynamic_size=True)
    
    for i, j in dataset:
        print(i.numpy(), end="->")
        print(j.numpy())
        # write() returns the updated TensorArray, so reassign it.
        ta = ta.write(ta.size(), j)
    
    # ta.stack() has shape (5, 2); reducing over both axes gives one scalar.
    max_value = tf.reduce_max(ta.stack(), axis=(0, 1)).numpy()
    
    print(max_value)
    # 25
    

    tf.reduce_max with axis=(0, 1) reduces the stacked tensor, whose shape is (5, 2), over both dimensions, leaving a single scalar: the largest label value. Feel free to give feedback if I did not understand the question correctly.
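If only the maximum is needed, one alternative (a different technique from the TensorArray answer above) is to fold over the dataset with `tf.data.Dataset.reduce`, keeping a running maximum instead of storing every window. This is a sketch under the assumption that the labels are `int64`, as produced by `Dataset.range`; the starting value `tf.int64.min` and the names `labels_ds`, `lowest`, `overall_max`, and `per_position_max` are illustrative choices. The second reduction shows the per-position variant (the equivalent of `axis=0` on the stacked tensor), in case "max across all features" was meant column-wise.

```python
import tensorflow as tf

input_slice = 3
labels_slice = 2

def split_window(x):
    inputs = tf.slice(x, [0], [input_slice])
    labels = tf.slice(x, [input_slice], [labels_slice])
    return inputs, labels

dataset = tf.data.Dataset.range(1, 25 + 1).batch(5).map(split_window)

# Keep only the labels part of each (inputs, labels) pair.
labels_ds = dataset.map(lambda inputs, labels: labels)

# Smallest int64 value, so any real label replaces it.
lowest = tf.constant(tf.int64.min, dtype=tf.int64)

# Running maximum over every label value in every window.
overall_max = labels_ds.reduce(
    lowest, lambda acc, labels: tf.maximum(acc, tf.reduce_max(labels))
)

# Running element-wise maximum per label position (like axis=0).
per_position_max = labels_ds.reduce(
    tf.fill([labels_slice], lowest), lambda acc, labels: tf.maximum(acc, labels)
)

print(overall_max.numpy())       # 25
print(per_position_max.numpy())  # [24 25]
```

Because `reduce` only carries a small running state, memory use does not grow with the number of windows, whereas the TensorArray approach keeps every window until `stack()` is called.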