
How to iterate through tensors in custom loss function?


I'm using Keras with the TensorFlow backend. My goal is to query the size of the current batch inside a custom loss function. I need it to compute loss values that depend on the index of particular observations. I'd like to make this clearer with the minimal reproducible examples below.

(BTW: of course I could use the batch size defined for the training procedure and plug its value in when defining the custom loss function, but there are reasons why it can vary. In particular, if epochsize % batchsize (epoch size modulo batch size) is nonzero, the last batch of an epoch is smaller.)

I didn't find a suitable approach on Stack Overflow, e.g. in Tensor indexing in custom loss function, Tensorflow custom loss function in Keras - loop over tensor, and Looping over a tensor. The problem is that the batch dimension of a tensor can't be inferred when the graph is built, which is when the loss function is called; the concrete shape is only known when the graph is evaluated on data. Hence I need to tell the custom loss function to do something with particular elements along a certain dimension without knowing the length of that dimension.

Setup (the same in all examples):

from keras.models import Sequential
from keras.layers import Dense, Activation
from keras import backend as K  # used by the custom loss functions below

# Generate dummy data
import numpy as np
data = np.random.random((1000, 100))
labels = np.random.randint(2, size=(1000, 1))

model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))

Example 1: nothing special, no custom loss (no issue)

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])    

# Train the model, iterating on the data in batches of 32 samples
model.fit(data, labels, epochs=10, batch_size=32)

(Output omitted; this runs perfectly fine.)

Example 2: nothing special, with a fairly simple custom loss (no issue)

def custom_loss(yTrue, yPred):
    loss = np.abs(yTrue-yPred)
    return loss

model.compile(optimizer='rmsprop',
              loss=custom_loss,
              metrics=['accuracy'])

# Train the model, iterating on the data in batches of 32 samples
model.fit(data, labels, epochs=10, batch_size=32)

(Output omitted; this runs perfectly fine.)

Example 3: the issue

def custom_loss(yTrue, yPred):
    print(yPred) # Output: Tensor("dense_2/Sigmoid:0", shape=(?, 1), dtype=float32)
    n = yPred.shape[0]
    for i in range(n): # TypeError: __index__ returned non-int (type NoneType)
        loss = np.abs(yTrue[i]-yPred[int(i/2)])
    return loss

model.compile(optimizer='rmsprop',
              loss=custom_loss,
              metrics=['accuracy'])

# Train the model, iterating on the data in batches of 32 samples
model.fit(data, labels, epochs=10, batch_size=32)

Of course the tensor has no shape information yet: the batch dimension can't be inferred when building the graph, only at training time. Hence `for i in range(n)` raises an error. Is there any way to do this?

The traceback of the output: [image not reproduced]
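The claim that the batch size is only known at training time can be seen in a minimal sketch. It assumes TensorFlow 2.x (the question uses standalone Keras on a TF1 backend, where `K.shape(yPred)[0]` plays the same role as `tf.shape` here), and the names `batch_size_of` and `traced_static_dims` are hypothetical. The static shape `y_pred.shape[0]` is `None` while the graph is traced, whereas `tf.shape(y_pred)[0]` is a tensor holding the real batch size when the graph runs, including a smaller last batch:

```python
import tensorflow as tf

# Hypothetical helper list: records the static batch dimension seen at trace time.
traced_static_dims = []

@tf.function(input_signature=[tf.TensorSpec(shape=(None, 1), dtype=tf.float32)])
def batch_size_of(y_pred):
    # Static shape: unknown (None) while the graph is traced, so
    # `for i in range(y_pred.shape[0])` fails just like in example 3.
    traced_static_dims.append(y_pred.shape[0])
    # Dynamic shape: an int32 tensor whose value is the actual batch size at run time.
    return tf.shape(y_pred)[0]

full_batch = batch_size_of(tf.zeros((32, 1)))
last_batch = batch_size_of(tf.zeros((8, 1)))  # e.g. a smaller final batch
print(full_batch.numpy(), last_batch.numpy(), traced_static_dims)
```

The same traced graph serves both batch sizes, which is why index logic has to be written with shape tensors and vectorized ops rather than Python `range`.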

-------

BTW, here's my actual custom loss function, in case of questions. I left it out above for clarity and simplicity.

def neg_log_likelihood(yTrue,yPred):
    yStatus = yTrue[:,0]
    yTime = yTrue[:,1]    
    n = yTrue.shape[0]    
    for i in range(n):
        s1 = K.greater_equal(yTime, yTime[i])
        s2 = K.exp(yPred[s1])
        s3 = K.sum(s2)
        logsum = K.log(s3)
        loss = K.sum(yStatus[i] * yPred[i] - logsum)
    return loss

Here's an image of the partial negative log-likelihood of the Cox proportional hazards model: [image not reproduced]
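Since the image is not available, here for reference is the textbook form of the Cox partial log-likelihood without tie correction. It assumes η_i = yPred[i], t_i = yTime[i] and δ_i = yStatus[i] (1 for an observed event, 0 for censored); it is the standard formula, not a transcription of the missing image. Note that the log-sum term is weighted by δ_i as well, and the loss to minimize is the negation:

$$
\ell(\eta) = \sum_{i=1}^{n} \delta_i \left( \eta_i - \log \sum_{j:\, t_j \ge t_i} \exp(\eta_j) \right),
\qquad \mathcal{L}_{\text{neg}}(\eta) = -\,\ell(\eta)
$$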

This is to answer a question from the comments and avoid confusion. I don't think it's necessary to understand it in detail to answer the question.


Solution

  • As usual, don't loop. Python loops over tensors have severe performance drawbacks and also cause bugs, such as the unknown batch dimension here. Use only backend functions unless it is truly unavoidable, which it usually isn't.


    Solution for example 3:

    There is something very odd about example 3, though.

    Do you really want to simply ignore half of your model's predictions? (Example 3)

    Assuming you do, just duplicate your tensor along the last dimension, flatten it, and discard the second half. This has exactly the effect you want: element i of the result is pred[i // 2].

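    def custom_loss(true, pred):
        n = K.shape(pred)[0:1]                  #dynamic batch size, as a 1-element int32 tensor

        pred = K.concatenate([pred]*2, axis=-1) #duplicate in the last axis: (n, 2)
        pred = K.flatten(pred)                  #flatten: [p0, p0, p1, p1, ...], length 2n
        pred = K.slice(pred,                    #take only the first half (= n samples)
                       K.constant([0], dtype="int32"),
                       n)
        pred = K.expand_dims(pred, -1)          #back to (n, 1) so it lines up with true (n, 1)

        return K.abs(true - pred)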
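To see that duplicate-flatten-slice really reproduces the `yPred[int(i/2)]` indexing from example 3, here is a NumPy mirror of the same steps checked against the loop. The helper name `half_index_gather` and the sample array are hypothetical:

```python
import numpy as np

def half_index_gather(pred):
    """NumPy mirror of the concatenate/flatten/slice trick for an (n, 1) array."""
    n = pred.shape[0]
    doubled = np.concatenate([pred, pred], axis=-1)  # (n, 2): each row is [p_i, p_i]
    flat = doubled.reshape(-1)                       # [p0, p0, p1, p1, ...], length 2n
    return flat[:n]                                  # element i equals p[i // 2]

pred = np.arange(5, dtype=float).reshape(-1, 1)      # hypothetical predictions 0..4
looped = np.array([pred[i // 2, 0] for i in range(5)])
print(half_index_gather(pred))  # [0. 0. 1. 1. 2.]
print(np.array_equal(half_index_gather(pred), looped))
```

As an alternative (a suggestion, not part of the original answer), `tf.gather(pred, tf.range(tf.shape(pred)[0]) // 2)` expresses the same indexing directly on the dynamic batch size.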
    

    Solution for your loss function:

    If the times are sorted from greatest to lowest, the inner sum over all j with yTime[j] >= yTime[i] is just a cumulative sum.
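A quick NumPy check of the cumulative-sum claim, assuming distinct times (with ties, earlier tied samples would miss part of their risk set). The data are random and hypothetical, and `np.argsort(-times)` stands in for `tf.math.top_k`:

```python
import numpy as np

rng = np.random.default_rng(0)
times = rng.permutation(6).astype(float)  # distinct hypothetical survival times
eta = rng.normal(size=6)                  # hypothetical model outputs (yPred)

# Loop version: for each i, sum exp(eta_j) over the risk set {j : t_j >= t_i}
loop_sums = np.array([np.exp(eta[times >= t]).sum() for t in times])

# Vectorized version: sort by time, greatest first, then take a running sum
order = np.argsort(-times)
cum_sums = np.cumsum(np.exp(eta[order]))

print(np.allclose(cum_sums, loop_sums[order]))
```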

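    Warning: if you have one time per sample, you cannot train with mini-batches! Each sample's risk set (all j with yTime[j] >= yTime[i]) spans the whole dataset, so every batch must contain all samples:
    batch_size = len(labels)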
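Why the warning holds can be sketched in NumPy: the Cox loss is not a sum of independent per-sample terms, because each sample's log-sum runs over every sample with a later or equal time. Splitting the data into two batches shrinks those risk sets, so the batch losses do not add up to the full-data loss. The helper `cox_nll` is a hypothetical naive reference (no tie correction), and every sample is treated as an event for simplicity:

```python
import numpy as np

def cox_nll(eta, time, status):
    """Naive O(n^2) Cox negative partial log-likelihood (no tie correction)."""
    total = 0.0
    for i in range(len(eta)):
        risk = np.exp(eta[time >= time[i]]).sum()  # risk set of sample i
        total += status[i] * (eta[i] - np.log(risk))
    return -total

rng = np.random.default_rng(1)
n = 8
eta = rng.normal(size=n)
time = rng.permutation(n).astype(float)
status = np.ones(n)  # every sample is an event

full = cox_nll(eta, time, status)
halves = cox_nll(eta[:4], time[:4], status[:4]) + cox_nll(eta[4:], time[4:], status[4:])
print(full, halves)
```

With all statuses equal to 1 the full-data value is strictly larger, since splitting can only remove members from each risk set, and the sample with the smallest time loses half of its risk set.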

    It makes more sense to have time as an additional dimension (many times per sample), as in recurrent and 1D conv networks. Anyway, taking your example as stated, yTime has shape (samples_equal_times,):

    import tensorflow as tf
    from keras import backend as K

    def neg_log_likelihood(yTrue,yPred):
        yStatus = yTrue[:,0]
        yTime = yTrue[:,1]
        yPred = K.flatten(yPred)   #(None, 1) -> (None,), to match yStatus and yTime
        n = K.shape(yTrue)[0]      #dynamic batch size, known only at run time

        #sort the times and everything else from greater to lower:
        #note: you can sort the data beforehand and skip this step for performance

        #important: yTime is sorted along its last dimension, so make sure it's (None,) here,
        # or (None, time_length) in the case of many times per sample
        sortedTime, sortedIndices = tf.math.top_k(yTime, n, True)
        sortedStatus = K.gather(yStatus, sortedIndices)
        sortedPreds = K.gather(yPred, sortedIndices)

        #do the calculations
        exp = K.exp(sortedPreds)
        sums = K.cumsum(exp)  #this will have the sum for j >= i in the loop
        logsums = K.log(sums)

        #only events (status 1) contribute; negate so that minimizing the loss maximizes the likelihood
        return -K.sum(sortedStatus * (sortedPreds - logsums))
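To sanity-check the sort-and-cumsum logic without Keras, here is a NumPy mirror of the vectorized loss compared against a direct per-sample loop in the spirit of the question's version (with the log-sum weighted by the status, as in the standard Cox formula). Both helpers are hypothetical, and distinct times are assumed so no tie handling is needed:

```python
import numpy as np

def cox_nll_loop(eta, time, status):
    """Reference: direct per-sample loop, O(n^2)."""
    total = 0.0
    for i in range(len(eta)):
        total += status[i] * (eta[i] - np.log(np.exp(eta[time >= time[i]]).sum()))
    return -total

def cox_nll_vectorized(eta, time, status):
    """Same quantity via sort-by-time-descending plus cumulative sum, no loop."""
    order = np.argsort(-time)                     # plays the role of tf.math.top_k(time, n)
    eta_s, status_s = eta[order], status[order]
    log_risk = np.log(np.cumsum(np.exp(eta_s)))   # log of the sum over {j : t_j >= t_i}
    return -np.sum(status_s * (eta_s - log_risk))

rng = np.random.default_rng(2)
n = 10
eta = rng.normal(size=n)
time = rng.permutation(n).astype(float)            # distinct times
status = rng.integers(0, 2, size=n).astype(float)  # 1 = event, 0 = censored

loop_value = cox_nll_loop(eta, time, status)
vec_value = cox_nll_vectorized(eta, time, status)
print(loop_value, vec_value)
```

The vectorized version never needs the batch size as a Python integer, which is what makes the Keras version work on batches of any size.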