Search code examples
pythonarraystensorflowtensorflow2.0loss-function

Tensorflow 2.x – Tensor with Average of Surrounding Cells


I'm trying to write a custom loss function in Tensorflow 2.x that encourages a gradual gradient in the output space (a 2D matrix). So, as one component of the loss function, I want to take in a Tensor and return a Tensor in which each cell represents the average of the corresponding neighbor cells in the original tensor.

a harmonic matrix transformation

For example, take the upper left cell, which has three neighbors: 6.3 ≈ (7 + 9 + 3)/3. Or take the middle cell, which has eight: 4.5 = (1 + 3 + 5 + 7 + 8 + 6 + 4 + 2)/8.

Consider the following code:

def gradient_encouraging_loss(y_true: Tensor, y_pred: Tensor) -> Tensor:
    """Penalize predictions that differ from the average of their neighbors.

    The loss is the mean absolute difference between ``y_pred`` and
    ``tensor_harmonic(y_pred)``, taken over every element of ``y_pred``.

    Args:
        y_true: Not used by this loss term; presumably kept so the function
            matches the ``(y_true, y_pred)`` custom-loss signature.
        y_pred: Predicted output; described alongside this code as having
            shape ``(None, X, Y)``.

    Returns:
        A scalar tensor with the same dtype as the difference
        ``y_pred - tensor_harmonic(y_pred)``.

    Note:
        ``tensor_harmonic`` is not defined in this block; it must be
        supplied elsewhere and return a tensor compatible with ``y_pred``.
    """
    deviation = tf.abs(y_pred - tensor_harmonic(y_pred))
    # reduce_mean keeps the computation in y_pred's own dtype, whereas a
    # divisor cast to float32 raises a dtype mismatch for e.g. float64 input.
    return tf.reduce_mean(deviation)

How would I implement tensor_harmonic()? y_pred has a shape of (None, X, Y), where X and Y are the output matrix dimensions.


Solution

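  • You can compute the sum of each cell's neighbors with a single 2D convolution, but the border cells need extra care: with zero padding, a corner cell has only 3 neighbors and an edge cell only 5, so each sum must be divided by the actual neighbor count rather than by 8. Here is how you could do it: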

    import tensorflow as tf
    
    def surround_average(x):
        """Return, for each cell of a 2-D tensor, the mean of its neighbors.

        A cell's neighbors are the up-to-eight cells adjacent to it
        (horizontally, vertically and diagonally) that lie inside the matrix;
        the cell itself is excluded.

        Args:
            x: A 2-D tensor, or anything ``tf.convert_to_tensor`` accepts that
                yields one. Its dtype must be supported by ``tf.nn.conv2d``.

        Returns:
            A tensor with the same shape as ``x`` holding the neighbor
            averages. A cell with no neighbors (a 1x1 input) yields 0.
        """
        x = tf.convert_to_tensor(x)
        dt = x.dtype
        # 3x3 kernel of ones with a zero centre: sums the neighbors and skips
        # the cell itself. Reshaped to conv2d's (height, width, in, out).
        kernel = tf.constant([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=dt)
        kernel = kernel[:, :, tf.newaxis, tf.newaxis]
        # conv2d wants (batch, height, width, channels).
        x4 = x[tf.newaxis, :, :, tf.newaxis]
        neighbor_sum = tf.nn.conv2d(x4, kernel, strides=1, padding='SAME')
        # Convolving a tensor of ones with the same kernel counts how many
        # neighbors fall inside the matrix (3 at corners, 5 on edges, 8 in
        # the interior), and stays correct when a dimension is 1 or 2.
        neighbor_count = tf.nn.conv2d(
            tf.ones_like(x4), kernel, strides=1, padding='SAME')
        # Clamp the count to 1 so a neighborless cell gives 0 instead of 0/0.
        neighbor_count = tf.maximum(neighbor_count, tf.constant(1, dtype=dt))
        return (neighbor_sum / neighbor_count)[0, :, :, 0]
    
    # Test
    # Demo input: a 4x6 float32 matrix holding 0..23 in row-major order.
    x = tf.reshape(tf.range(24, dtype=tf.float32), [4, 6])
    print(x.numpy())
    # [[ 0.  1.  2.  3.  4.  5.]
    #  [ 6.  7.  8.  9. 10. 11.]
    #  [12. 13. 14. 15. 16. 17.]
    #  [18. 19. 20. 21. 22. 23.]]
    # Each output cell is the mean of that cell's neighbors in x.
    print(surround_average(x).numpy())
    # [[ 4.6666665  4.6        5.6        6.6        7.6        8.333333 ]
    #  [ 6.6        7.         8.         9.        10.        10.4      ]
    #  [12.6       13.        14.        15.        16.        16.4      ]
    #  [14.666667  15.4       16.4       17.4       18.4       18.333334 ]]
    

    EDIT: The code above can be adapted to work with batches of matrices with a few minor changes:

    import tensorflow as tf
    
    def surround_average_batch(x):
        """Return neighbor averages for every matrix in a batch.

        For each matrix in the batch, every output cell is the mean of the
        up-to-eight adjacent cells (horizontal, vertical and diagonal) that
        lie inside that matrix; the cell itself is excluded.

        Args:
            x: A 3-D tensor of shape (batch, rows, cols), or anything
                ``tf.convert_to_tensor`` accepts that yields one. Its dtype
                must be supported by ``tf.nn.conv2d``.

        Returns:
            A tensor with the same shape as ``x`` holding the neighbor
            averages. Cells with no neighbors (1x1 matrices) yield 0.
        """
        x = tf.convert_to_tensor(x)
        dt = x.dtype
        # 3x3 kernel of ones with a zero centre: sums the neighbors and skips
        # the cell itself. Reshaped to conv2d's (height, width, in, out).
        kernel = tf.constant([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=dt)
        kernel = kernel[:, :, tf.newaxis, tf.newaxis]
        # Add the trailing channel axis conv2d expects.
        x4 = tf.expand_dims(x, axis=-1)
        neighbor_sum = tf.nn.conv2d(x4, kernel, strides=1, padding='SAME')
        # The neighbor count depends only on position, so compute it for a
        # single example and let broadcasting apply it to the whole batch.
        # Convolving ones with the same kernel gives 3 at corners, 5 on edges
        # and 8 inside, and stays correct when a dimension is 1 or 2.
        neighbor_count = tf.nn.conv2d(
            tf.ones_like(x4[:1]), kernel, strides=1, padding='SAME')
        # Clamp the count to 1 so a neighborless cell gives 0 instead of 0/0.
        neighbor_count = tf.maximum(neighbor_count, tf.constant(1, dtype=dt))
        return tf.squeeze(neighbor_sum / neighbor_count, axis=-1)
    
    # Test
    # Demo input: a batch of two 4x3 float32 matrices holding 0..23.
    x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 4, 3])
    print(x.numpy())
    # [[[ 0.  1.  2.]
    #   [ 3.  4.  5.]
    #   [ 6.  7.  8.]
    #   [ 9. 10. 11.]]
    #
    #  [[12. 13. 14.]
    #   [15. 16. 17.]
    #   [18. 19. 20.]
    #   [21. 22. 23.]]]
    # Neighbor averages are computed independently for each matrix.
    print(surround_average_batch(x).numpy())
    # [[[ 2.6666667  2.8        3.3333333]
    #   [ 3.6        4.         4.4      ]
    #   [ 6.6        7.         7.4      ]
    #   [ 7.6666665  8.2        8.333333 ]]
    #
    #  [[14.666667  14.8       15.333333 ]
    #   [15.6       16.        16.4      ]
    #   [18.6       19.        19.4      ]
    #   [19.666666  20.2       20.333334 ]]]