Tags: python, tensorflow, loss-function

What's the correct way to use tf.while_loop in a custom loss function?


I intended to use the following functions as the loss for my training:

import tensorflow as tf

def wrap(dist): 
    return tf.while_loop(
        cond=lambda X: tf.math.abs(X) > 0.5,
        body=lambda X: tf.math.subtract(X, 1.0),
        loop_vars=(dist))


# PBC-aware MSE, period = 1.0 ([0, 1.0])
def custom_loss(y_true, y_pred):
    diff = tf.math.abs(y_true - y_pred)
    diff = tf.nest.flatten(diff)
    diff = tf.vectorized_map(wrap, diff)
    return tf.math.reduce_mean(tf.math.square(diff))

# ...other code for loading data and defining the model

model.compile(optimizer=tf.keras.optimizers.SGD(momentum=0.1),
              loss=custom_loss)

but I encountered a long series of error messages. Because the log is too long to include here, I put it in a gist: https://gist.github.com/HanatoK/f75fddd82372f499c37279f1128cad7a

The equivalent numpy version of the code above should be

import numpy as np

def wrap_diff2(x, y, period=1.0):
    diff = np.abs(x - y)
    while diff > 0.5 * period:
        diff -= period
    return diff * diff

def custom_loss_numpy(y_true, y_pred):
    diff2 = np.vectorize(wrap_diff2)(y_true, y_pred)
    return np.mean(diff2)
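
For reference, the per-element loop above can also be written without a Python loop, which makes the intended math easier to see. This is a minimal sketch, not part of the original question: the function name custom_loss_numpy_vectorized is hypothetical, and it assumes that subtracting the nearest whole number of periods is an acceptable replacement for the while loop (the squared result is the same, since both leave a value whose magnitude is at most half a period).

```python
import numpy as np

def custom_loss_numpy_vectorized(y_true, y_pred, period=1.0):
    # Hypothetical name; assumes the same PBC-aware MSE as custom_loss_numpy.
    diff = np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
    # Subtract the nearest whole number of periods instead of looping.
    diff = diff - period * np.round(diff / period)
    return np.mean(diff * diff)

# Example: 0.1 and 0.9 are only 0.2 apart once the period of 1.0 is taken into account.
print(custom_loss_numpy_vectorized([0.1], [0.9]))
```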

Any ideas? The full code example is shared on Google Colab: https://colab.research.google.com/drive/1ExVHgyKHQfGcpXvo5ZsuBBmzmHzxUekC?usp=sharing


Solution

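  • tf.nest.flatten does not flatten a tensor. It flattens a nested Python structure, so here it only returns the one-element list [diff] with the tensor's shape unchanged. That is likely why the loop condition fails: tf.while_loop needs cond to return a scalar boolean, but wrap receives a whole row. Reshape the tensor to 1-D with tf.reshape(diff, [-1]) and pass it to tf.vectorized_map as a one-element list, so that wrap sees one scalar per call. Try this: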

    import tensorflow as tf
    import numpy as np
    
    def wrap(dist): 
        return tf.while_loop(
            cond=lambda X: tf.math.abs(X) > 0.5,
            body=lambda X: tf.math.subtract(X, 1.0),
            loop_vars=(dist))
    
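    # PBC-aware MSE, period = 1.0
    def custom_loss(y_true, y_pred):
        diff = tf.math.abs(y_true - y_pred)
        # Flatten to 1-D so vectorized_map hands wrap one scalar at a time
        diff = tf.reshape(diff, [-1])
        # elems is a one-element list, so wrap receives [scalar] as loop_vars
        diff = tf.vectorized_map(wrap, [diff])
        return tf.math.reduce_mean(tf.math.square(diff))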
    
    y_true = np.array([[0., 1., 1.0], [0., 0., 0.]])
    y_pred = np.array([[1., 1., 1.0], [1., 0., 1.]])
    custom_loss(y_true, y_pred).numpy()
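
As a possible alternative to the loop, the wrapping can be done in closed form by subtracting the nearest whole number of periods, which avoids tf.while_loop and tf.vectorized_map altogether. The following is a minimal sketch under that assumption, not the answer's method: the name periodic_mse and the period parameter are hypothetical. The squared result should match the loop version, because both leave a difference whose magnitude is at most half a period.

```python
import tensorflow as tf

def periodic_mse(y_true, y_pred, period=1.0):
    # Hypothetical loop-free variant of custom_loss; period defaults to 1.0.
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    diff = y_true - y_pred
    # Subtract the nearest whole number of periods, leaving a value
    # in [-period / 2, period / 2].
    diff = diff - period * tf.round(diff / period)
    return tf.reduce_mean(tf.square(diff))

# Same inputs as in the answer above.
y_true = [[0., 1., 1.0], [0., 0., 0.]]
y_pred = [[1., 1., 1.0], [1., 0., 1.]]
print(float(periodic_mse(y_true, y_pred)))
```

Because it takes (y_true, y_pred) like any Keras loss, it could be passed to model.compile(loss=periodic_mse) in place of custom_loss, if the closed-form wrapping suits the data.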