Tags: python, scipy, keras, sobel

How to implement a custom Sobel-filter-based loss function using Keras


I am new to DL and Keras, and I am currently trying to implement a custom Sobel-filter-based loss function in Keras.

The idea is to compute the mean squared error between a Sobel-filtered prediction and the Sobel-filtered ground truth image.
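
To make the intent concrete, this is roughly what I mean when applied to plain NumPy arrays outside of a model (just a sketch with made-up image sizes, not the actual loss):

import numpy as np
from scipy import ndimage

def sobel_mag(img):
    # gradient magnitude of a single 2D image, normalized to [0, 1]
    dx = ndimage.sobel(img, axis=0)
    dy = ndimage.sobel(img, axis=1)
    mag = np.hypot(dx, dy)
    return mag / np.max(mag)

a = np.random.rand(64, 64)   # stand-in for a ground truth image
b = np.random.rand(64, 64)   # stand-in for a prediction
mse = np.mean((sobel_mag(a) - sobel_mag(b)) ** 2)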

So far, my custom loss function looks like this:

import numpy as np
from scipy import ndimage
import keras.backend as K

def mse_sobel(y_true, y_pred):

    for i in range (0, y_true.shape[0]):
        dx_true = ndimage.sobel(y_true[i,:,:,:], 1)
        dy_true = ndimage.sobel(y_true[i,:,:,:], 2)
        mag_true[i,:,:,:] = np.hypot(dx_true, dy_true)
        mag_true[i,:,:,:] *= 1.0 / np.max(mag_true[i,:,:,:])

        dx_pred = ndimage.sobel(y_pred[i,:,:,:], 1)
        dy_pred = ndimage.sobel(y_pred[i,:,:,:], 2)
        mag_pred[i,:,:,:] = np.hypot(dx_pred, dy_pred)
        mag_pred[i,:,:,:] *= 1.0 / np.max(mag_pred[i,:,:,:])

    return(K.mean(K.square(mag_pred - mag_true), axis=-1))

Using this loss function leads to this error:

in mse_sobel
for i in range (0, y_true.shape[0]):
TypeError: __index__ returned non-int (type NoneType)

Using the debugger I found out that y_true.shape[0] only returns None - fine. But when I replace y_true.shape[0] with, for example, 1 so that it reads for i in range (0,1):, another error occurs:

in sobel
axis = _ni_support._check_axis(axis, input.ndim)

in _check_axis
raise ValueError('invalid axis')
ValueError: invalid axis

Here, I am not sure why the axis seems to be invalid.
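
For reference, the None can be reproduced outside the model as well - a minimal check (assuming the TensorFlow backend) shows that the batch dimension of these tensors is simply unknown while the graph is being built:

import keras.backend as K

# a placeholder shaped like the loss inputs: the batch size is None at graph-construction time
y = K.placeholder(shape=(None, 64, 64, 3))
print(K.int_shape(y))   # (None, 64, 64, 3)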

Can anyone help me figure out how to implement that loss function? Thank you very much for your help!


Solution

  • Losses must be built from tensor operations, using the Keras backend or TensorFlow/Theano/CNTK functions. This is the only way to keep backpropagation working; using numpy, scipy, etc. breaks the computation graph.

    Let's import the keras backend:

    import keras.backend as K
    

    Defining the filters:

    #this contains both X and Y sobel filters in the format (3,3,1,2)
    #size is 3 x 3, it considers 1 input channel and has two output channels: X and Y
    sobelFilter = K.variable([[[[1.,  1.]], [[0.,  2.]], [[-1.,  1.]]],
                              [[[2.,  0.]], [[0.,  0.]], [[-2.,  0.]]],
                              [[[1., -1.]], [[0., -2.]], [[-1., -1.]]]])
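
    As a quick sanity check (a hypothetical snippet, assuming the TensorFlow backend), you can print the filter's shape and recover the two familiar 3x3 kernels:

    print(K.int_shape(sobelFilter))   # (3, 3, 1, 2)
    kernels = K.eval(sobelFilter)     # back to numpy for inspection
    print(kernels[:, :, 0, 0])        # the X (horizontal gradient) kernel
    print(kernels[:, :, 0, 1])        # the Y (vertical gradient) kernel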
    

    Here is a function that repeats the filter for each input channel, in case your images are RGB or have more than one channel. It simply replicates the Sobel filters for each input channel, giving the shape (3, 3, inputChannels, 2):

    def expandedSobel(inputTensor):
    
        #this considers data_format = 'channels_last'
        inputChannels = K.reshape(K.ones_like(inputTensor[0,0,0,:]),(1,1,-1,1))
        #if you're using 'channels_first', use inputTensor[0,:,0,0] above
    
        return sobelFilter * inputChannels
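
    For example (a hypothetical check with a made-up 3-channel input), the expanded filter ends up with shape (3, 3, 3, 2):

    rgb = K.zeros((1, 8, 8, 3))              # dummy batch of RGB images
    print(K.int_shape(expandedSobel(rgb)))   # (3, 3, 3, 2)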
    

    And this is the loss function:

    def sobelLoss(yTrue,yPred):
    
        #get the sobel filter repeated for each input channel
        filt = expandedSobel(yTrue)
    
        #calculate the sobel filters for yTrue and yPred
        #this generates twice the number of input channels 
        #a X and Y channel for each input channel
        sobelTrue = K.depthwise_conv2d(yTrue,filt)
        sobelPred = K.depthwise_conv2d(yPred,filt)
    
        #now you just apply the mse:
        return K.mean(K.square(sobelTrue - sobelPred))
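
    A quick smoke test on dummy data (hypothetical shapes, TensorFlow backend assumed) shows that the loss evaluates to a single scalar:

    import numpy as np
    yt = K.variable(np.random.rand(2, 32, 32, 3))   # fake ground truth batch
    yp = K.variable(np.random.rand(2, 32, 32, 3))   # fake prediction batch
    print(K.eval(sobelLoss(yt, yp)))                # one scalar value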
    

    Apply this loss in the model:

    model.compile(loss=sobelLoss, optimizer = ....)
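
    For a fuller picture, a minimal compile sketch (the architecture, optimizer and input size below are just placeholders, assuming 'channels_last' RGB data):

    from keras.models import Sequential
    from keras.layers import Conv2D

    model = Sequential()
    model.add(Conv2D(3, (3, 3), padding='same', activation='sigmoid',
                     input_shape=(64, 64, 3)))
    model.compile(loss=sobelLoss, optimizer='adam')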
    

    My experience shows that calculating the unified Sobel magnitude sqrt(X² + Y²) brings terrible results, and the resulting images look like chess boards. But if you do want it:

    def squareSobelLoss(yTrue,yPred):
    
        #same beginning as the other loss
        filt = expandedSobel(yTrue)
        squareSobelTrue = K.square(K.depthwise_conv2d(yTrue,filt))
        squareSobelPred = K.square(K.depthwise_conv2d(yPred,filt))
    
        #here, since we've got 6 output channels (for an RGB image)
        #let's reorganize in order to easily sum X² and Y²: change (h,w,6) to (h,w,3,2)
        #caution: this method of reshaping only works in tensorflow
        #if you do need this in other backends, let me know
        newShape = K.shape(squareSobelTrue)
        newShape = K.concatenate([newShape[:-1],
                                  newShape[-1:]//2,
                                  K.variable([2],dtype='int32')])
    
        #sum the last axis (the one that is 2 above, representing X² and Y²)                      
        squareSobelTrue = K.sum(K.reshape(squareSobelTrue,newShape),axis=-1)
        squareSobelPred = K.sum(K.reshape(squareSobelPred,newShape),axis=-1)
    
        #since both previous values are already squared, maybe we shouldn't square them again? 
        #but you can apply the K.sqrt() in both, and then make the difference, 
        #and then another square, it's up to you...    
        return K.mean(K.abs(squareSobelTrue - squareSobelPred))
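
    The same kind of smoke test works here (dummy RGB batches, TensorFlow backend assumed): internally the reshape turns the (batch, h, w, 6) filter output into (batch, h, w, 3, 2) before X² and Y² are summed, and the loss again comes out as a single scalar:

    import numpy as np
    yt = K.variable(np.random.rand(2, 32, 32, 3))
    yp = K.variable(np.random.rand(2, 32, 32, 3))
    print(K.eval(squareSobelLoss(yt, yp)))   # one scalar value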