InvalidArgumentError: Inner dimensions of output shape must match inner dimensions of updates shape

I'm trying to implement an SPL loss in Keras. What I need is simple; here it is in NumPy-style pseudocode:

def spl_loss(y_true, y_pred, lmda):
    # compute any arbitrary loss function
    L = categorical_cross_entropy(y_true, y_pred)
    # set to zero those values with an error greater than lambda
    L[L>lmda] = 0
    return L

I'm following a tutorial to implement it, but I'm having trouble with the step that sets values to zero.

Currently I have the following code:

def spl_loss(lmda, loss_fn):
    def loss(y_true, y_pred):
        # compute an arbitrary loss function, L
        loss_value = loss_fn(y_true, y_pred) # tensor of shape (64,)
        # get the mask of L greater than lmda
        mask = tf.greater( loss_value, tf.constant( float(lmda) ) )    # tensor of shape (64,)
        # compute indexes for the mask
        indexes = tf.reshape(tf.where(mask), [-1])  # tensor of shape (n,); where n<=64
        # set to zero values on indexes
        spl_loss_value = tf.tensor_scatter_nd_update(loss_value, indexes, tf.zeros_like(loss_value, dtype=loss_value.dtype) )  # this line gives the error
        return spl_loss_value
    return loss

According to the docs, the tensor_scatter_nd_update operation should perform the assignment, but it fails with the following error:

    spl_loss_value = tf.tensor_scatter_nd_update(loss_value, indexes, tf.zeros_like(loss_value, dtype=loss_value.dtype) )
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/ wrapper  **
        return target(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/ tensor_scatter_nd_update
        tensor=tensor, indices=indices, updates=updates, name=name)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/ tensor_scatter_update
        _ops.raise_from_not_ok_status(e, name)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ raise_from_not_ok_status
        six.raise_from(core._status_to_exception(e.code, message), None)
    <string>:3 raise_from

    InvalidArgumentError: Inner dimensions of output shape must match inner dimensions of updates shape. Output: [64] updates: [64] [Op:TensorScatterUpdate]

I'm running it in Colab; you can try it there.

I tried several reshapes, since I understand this is a mismatch between expected and actual shapes, but I can't find a way to make it work. What's going on here?

Thanks in advance


  • You're getting this error because the indices argument of tf.tensor_scatter_nd_update must have at least two axes, i.e. tf.rank(indices) >= 2 must hold. Even for scalar updates, indices is 2-D because its shape carries two pieces of information: the number of updates (num_updates) and the length of each index vector (index_depth). Flattening the result of tf.where with tf.reshape(..., [-1]) discards the second axis, so the check fails. For a detailed overview, see the answer to "Tensorflow 2 - what is 'index depth' in tensor_scatter_nd_update?".
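    To make the shape requirement concrete, here is a small pure-Python sketch of the documented constraint that updates.shape must equal indices.shape[:-1] + tensor.shape[index_depth:]. The helper name expected_updates_shape is hypothetical (it is not a TensorFlow function) and it only reasons about shapes; it does not reproduce TensorFlow's own error messages.

```python
def expected_updates_shape(tensor_shape, indices_shape):
    """Return the shape `updates` must have for the given tensor and indices shapes.

    Hypothetical helper that mirrors the documented rule:
    updates.shape == indices.shape[:-1] + tensor.shape[index_depth:]
    """
    if len(indices_shape) < 2:
        # Documented requirement: indices has at least two axes.
        raise ValueError("indices must have at least two axes")
    index_depth = indices_shape[-1]  # length of each index vector
    if index_depth > len(tensor_shape):
        raise ValueError("index depth cannot exceed the tensor rank")
    return tuple(indices_shape[:-1]) + tuple(tensor_shape[index_depth:])


# Rank-1 loss with 64 entries; tf.where(mask) selecting 10 of them has shape (10, 1).
print(expected_updates_shape((64,), (10, 1)))    # (10,)
# Rank-2 tensor updated row by row: each update is a full row of length 5.
print(expected_updates_shape((64, 5), (10, 1)))  # (10, 5)
```

    For the loss in the question, this means indices of shape (n, 1) paired with n scalar updates, rather than flattened indices paired with a zeros tensor as long as the whole batch.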

    Here is a corrected implementation of the SPL loss:

    import tensorflow as tf
    from tensorflow import keras

    def spl_loss(lmda):
        def loss(y_true, y_pred):
            # compute an arbitrary per-sample loss function, L
            loss_value = keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
            # get the mask of L greater than lmda
            mask = tf.greater(loss_value, tf.constant(float(lmda)))
            # compute indexes for the mask; keep the 2-D result of tf.where
            indexes = tf.where(mask)  # tensor of shape (n, 1); where n<=64
            # one zero per selected index, in the same dtype as the loss
            updates = tf.reshape(tf.zeros_like(indexes, dtype=loss_value.dtype), [-1])  # shape (n,)
            # scalar update check (relies on eager execution)
            num_updates, index_depth = indexes.shape.as_list()
            assert updates.shape == [num_updates]
            assert index_depth == tf.rank(loss_value)
            # print('Shape: ', loss_value.shape, indexes.shape, updates.shape)
            # set to zero values on indexes
            spl_loss_value = tf.tensor_scatter_nd_update(loss_value, indexes, updates)
            return spl_loss_value
        return loss

    # `model` is assumed to be an already-built keras.Model
    model.compile(optimizer="adam", loss=spl_loss(lmda=2.), run_eagerly=True)

    Ref: tf.tensor_scatter_nd_update
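
    If the goal is only to zero out losses above the threshold, one possible alternative is the three-argument form of tf.where, which selects element-wise and needs no indices at all. The sketch below swaps tf.tensor_scatter_nd_update for tf.where; the names spl_loss_where and squared_error are hypothetical and used only for illustration, and squared_error stands in for whatever per-sample loss is actually used.

```python
import tensorflow as tf


def spl_loss_where(lmda, loss_fn):
    """Wrap `loss_fn` so that per-sample losses greater than `lmda` become zero."""
    def loss(y_true, y_pred):
        loss_value = loss_fn(y_true, y_pred)              # per-sample loss, shape (batch,)
        threshold = tf.cast(lmda, loss_value.dtype)
        mask = loss_value > threshold                     # True where the loss is too large
        # Element-wise select: zero where mask is True, original loss elsewhere.
        return tf.where(mask, tf.zeros_like(loss_value), loss_value)
    return loss


def squared_error(y_true, y_pred):
    """Hypothetical stand-in for a per-sample loss function."""
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


y_true = tf.constant([[0.0], [0.0], [0.0]])
y_pred = tf.constant([[1.0], [2.0], [0.5]])
result = spl_loss_where(2.0, squared_error)(y_true, y_pred)
print(result.numpy())  # per-sample losses 1.0, 4.0, 0.25 become 1.0, 0.0, 0.25
```

    Because this version has no Python-level asserts on concrete shapes, it should not need run_eagerly=True, though that has not been checked against the original notebook.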