python tensorflow keras loss-function

Tensorflow tensor loses dimension for some reason


I have a custom loss function that is reporting an error before any real processing happens.

I have a y_train of shape (2717, 5, 5, 6), a batch size of 25, and constants S1 = S2 = 5. All I do is call tf.reshape to make sure I get the desired shape of (25, 5, 5, 6); then I want to extract one slice along the last axis, but somehow it's not working properly.

@tf.function
def yolo_loss(y_true,y_pred):
    #mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.SUM)
    lambda_noobj = 0.5
    lambda_coord = 5
    y_pred = tf.reshape(y_pred,[batch_size,S1,S2,C+B*5])
    y_true = tf.reshape(y_true,[batch_size,S1,S2,6])
    exists_box = tf.reshape(y_true[...,0],[batch_size,S1,S2,1])
    ........

While the first reshape of y_true works perfectly fine, I get an error on the exists_box line; to be precise:

      exists_box = tf.reshape(y_true[...,0],[batch_size,S1,S2,1])
Node: 'Reshape_2'
Input to reshape is a tensor with 425 values, but the requested shape has 625
     [[{{node Reshape_2}}]] [Op:__inference_train_function_44379]

The ellipsis in [...,0] should give me a tensor of shape (25, 5, 5), i.e. 25 * 5 * 5 = 625 values, so I am confused why it says the tensor has 425 values. I also made sure that all arrays in y_train have the same shape.
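A quick way to check this reasoning is to compare a full batch of 25 with a shorter one. The sketch below uses zero tensors as hypothetical stand-ins for y_true (only the shapes matter); the batch of 17 is an assumption used to reproduce the 425 in the error message:

```python
import tensorflow as tf

# Hypothetical stand-ins for y_true: a full batch of 25 and a shorter batch of 17.
full_batch = tf.zeros([25, 5, 5, 6])
last_batch = tf.zeros([17, 5, 5, 6])

# [..., 0] picks index 0 on the last axis and removes that axis.
print(full_batch[..., 0].shape)                  # (25, 5, 5)
print(int(tf.size(full_batch[..., 0]).numpy()))  # 625
print(int(tf.size(last_batch[..., 0]).numpy()))  # 425

# Forcing the 425 values of the short batch into [25, 5, 5, 1] (625 slots) fails.
reshape_failed = False
try:
    tf.reshape(last_batch[..., 0], [25, 5, 5, 1])
except tf.errors.InvalidArgumentError:
    reshape_failed = True
print(reshape_failed)  # True
```

So the slicing itself behaves as expected; the 425 can only come from a batch that has fewer than 25 elements.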


Solution

  • It seems that the error is caused by the last batch of your y_train dataset, which has shape (17, 5, 5, 6) (17 * 5 * 5 * 1 = 425). This happens because when TensorFlow batches your data, the last batch contains all the remaining elements, and their number does not have to equal your specified batch_size (25 in your case); note that 2717 % 25 = 17.
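    The batching behaviour can be checked directly. This is a minimal sketch, assuming the labels are fed through a tf.data.Dataset built with from_tensor_slices; the zero array is a hypothetical stand-in with the same shape as y_train:

```python
import numpy as np
import tensorflow as tf

# Zeros stand in for the real labels; only the shape (2717, 5, 5, 6) matters here.
y_train = np.zeros((2717, 5, 5, 6), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices(y_train)

# Default batching keeps the leftover elements as a final, smaller batch.
sizes = [int(batch.shape[0]) for batch in dataset.batch(25)]
print(2717 % 25, len(sizes), sizes[-1])  # 17 109 17

# With drop_remainder=True the leftover 17 elements are discarded instead.
sizes_dropped = [int(batch.shape[0]) for batch in dataset.batch(25, drop_remainder=True)]
print(len(sizes_dropped), set(sizes_dropped))  # 108 {25}
```

    108 full batches cover 2700 elements, and the 109th batch holds the remaining 17.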

    There are two things you can do:

    1. drop the remaining elements from the dataset; use this option if you are okay with losing a few examples from your training data. If you are using a tf.data.Dataset object, this can be done by passing drop_remainder=True to the batch method:
    dataset = dataset.batch(25, drop_remainder=True)
    
    2. change your loss function so that it can process input whose first dimension differs from 25; from your description it's not clear what your loss function does, so you'll have to work out the details yourself.
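    For option 2, one common approach is to stop hard-coding batch_size and let TensorFlow infer the first dimension. The sketch below only covers the first few lines of the loss from the question; loss_prefix is a hypothetical name, and C = B = 1 are hypothetical values (the question does not give them) chosen so that C + B * 5 == 6:

```python
import tensorflow as tf

# S1 = S2 = 5 come from the question; C and B are hypothetical placeholders.
S1 = S2 = 5
C, B = 1, 1

@tf.function
def loss_prefix(y_true, y_pred):
    """Hypothetical first steps of a batch-size-agnostic yolo_loss."""
    # -1 lets tf.reshape infer the batch dimension, so batches of 25 and 17 both work.
    y_pred = tf.reshape(y_pred, [-1, S1, S2, C + B * 5])
    y_true = tf.reshape(y_true, [-1, S1, S2, 6])
    # Slicing with 0:1 keeps the last axis, giving (batch, S1, S2, 1) without a reshape.
    exists_box = y_true[..., 0:1]
    # If the actual batch size is needed (e.g. to average the loss), read it from the data.
    n = tf.shape(y_true)[0]
    return exists_box, n

for size in (25, 17):
    mask, n = loss_prefix(tf.zeros([size, S1, S2, 6]), tf.zeros([size, S1, S2, C + B * 5]))
    print(mask.shape, int(n.numpy()))  # (25, 5, 5, 1) 25, then (17, 5, 5, 1) 17
```

    Using tf.shape(y_true)[0] rather than the Python constant matters inside a tf.function, because the batch size is only known at run time for the last, shorter batch.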