I am new to DL and Keras, and I am currently trying to implement a Sobel-filter-based custom loss function in Keras.
The idea is to calculate the mean squared error between a Sobel-filtered prediction and a Sobel-filtered ground-truth image.
So far, my custom loss function looks like this:
```python
import numpy as np
from scipy import ndimage
from keras import backend as K

def mse_sobel(y_true, y_pred):
    for i in range(0, y_true.shape[0]):
        dx_true = ndimage.sobel(y_true[i,:,:,:], 1)
        dy_true = ndimage.sobel(y_true[i,:,:,:], 2)
        mag_true[i,:,:,:] = np.hypot(dx_true, dy_true)
        mag_true[i,:,:,:] *= 1.0 / np.max(mag_true[i,:,:,:])
        dx_pred = ndimage.sobel(y_pred[i,:,:,:], 1)
        dy_pred = ndimage.sobel(y_pred[i,:,:,:], 2)
        mag_pred[i,:,:,:] = np.hypot(dx_pred, dy_pred)
        mag_pred[i,:,:,:] *= 1.0 / np.max(mag_pred[i,:,:,:])
    return K.mean(K.square(mag_pred - mag_true), axis=-1)
```
Using this loss function leads to this error:

```
in mse_sobel
    for i in range (0, y_true.shape[0]):
TypeError: __index__ returned non-int (type NoneType)
```
Using the debugger, I found out that `y_true.shape` only returns `None`, which is fine. But when I replace `y_true.shape[0]` with, for example, `1`, so that the loop reads `for i in range(0, 1):`, another error occurs:
```
in sobel
    axis = _ni_support._check_axis(axis, input.ndim)
in _check_axis
    raise ValueError('invalid axis')
ValueError: invalid axis
```
Here, I am not sure why the axis is invalid.
Can anyone help me figure out how to implement that loss function? Thank you very much for your help!
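Losses must be built from tensor operations: the Keras backend, or TensorFlow/Theano/CNTK functions. This is the only way to keep backpropagation working; using NumPy, SciPy, etc. breaks the graph. It also explains both errors: while the loss is being built, `y_true` and `y_pred` are symbolic tensors whose batch size is still unknown (`None`), and SciPy, which converts its input with NumPy, does not see the image dimensions of a symbolic tensor, so `ndimage.sobel` rejects the axis.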
Let's import the Keras backend:

```python
import keras.backend as K
```
Defining the filters:

```python
#this contains both X and Y sobel filters in the format (3,3,1,2)
#size is 3 x 3, it considers 1 input channel and has two output channels: X and Y
sobelFilter = K.variable([[[[1.,  1.]], [[0.,  2.]], [[-1.,  1.]]],
                          [[[2.,  0.]], [[0.,  0.]], [[-2.,  0.]]],
                          [[[1., -1.]], [[0., -2.]], [[-1., -1.]]]])
```
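If you want to double-check the kernel layout before using it, the following NumPy sketch builds the same nested list and slices out the two output channels. It is for inspection only; NumPy must not appear inside the loss itself. `sobel_np`, `sobel_x` and `sobel_y` are names chosen just for this illustration.

```python
import numpy as np

# Same nested list as sobelFilter above, as a plain NumPy array (inspection only)
sobel_np = np.array([[[[1.,  1.]], [[0.,  2.]], [[-1.,  1.]]],
                     [[[2.,  0.]], [[0.,  0.]], [[-2.,  0.]]],
                     [[[1., -1.]], [[0., -2.]], [[-1., -1.]]]])

# Layout: (kernel height, kernel width, input channels, output channels)
print(sobel_np.shape)  # (3, 3, 1, 2)

# Output channel 0: horizontal-gradient (X) kernel
sobel_x = sobel_np[:, :, 0, 0]
# Output channel 1: vertical-gradient (Y) kernel
sobel_y = sobel_np[:, :, 0, 1]
print(sobel_x)
print(sobel_y)
```

Channel 0 is the usual X Sobel kernel and channel 1 is its transpose, the Y kernel.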
Here is a function that repeats the filters for each input channel, in case your images are RGB or have more than one channel. It just replicates the Sobel filters for each input channel, giving the shape (3, 3, inputChannels, 2):

```python
def expandedSobel(inputTensor):

    #this considers data_format = 'channels_last'
    inputChannels = K.reshape(K.ones_like(inputTensor[0,0,0,:]),(1,1,-1,1))
    #if you're using 'channels_first', use inputTensor[0,:,0,0] above

    return sobelFilter * inputChannels
```
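The multiplication in `expandedSobel` works through broadcasting: a (3, 3, 1, 2) filter times a (1, 1, C, 1) tensor of ones gives (3, 3, C, 2), i.e. one copy of both kernels per input channel. Here is a NumPy sketch of the same shape arithmetic, assuming a hypothetical batch of four 32x32 RGB images in 'channels_last' layout (`batch`, `ones` and `expanded` are illustrative names):

```python
import numpy as np

sobel_np = np.array([[[[1.,  1.]], [[0.,  2.]], [[-1.,  1.]]],
                     [[[2.,  0.]], [[0.,  0.]], [[-2.,  0.]]],
                     [[[1., -1.]], [[0., -2.]], [[-1., -1.]]]])

# Hypothetical input: 4 RGB images, channels_last
batch = np.zeros((4, 32, 32, 3))

# Mirrors K.reshape(K.ones_like(inputTensor[0,0,0,:]), (1,1,-1,1)): shape (1, 1, 3, 1)
ones = np.ones_like(batch[0, 0, 0, :]).reshape(1, 1, -1, 1)

# Broadcasting (3, 3, 1, 2) * (1, 1, 3, 1) -> (3, 3, 3, 2)
expanded = sobel_np * ones
print(expanded.shape)

# Every input channel carries an identical copy of the X and Y kernels
for c in range(expanded.shape[2]):
    assert np.array_equal(expanded[:, :, c, :], sobel_np[:, :, 0, :])
```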
And this is the loss function:

```python
def sobelLoss(yTrue,yPred):

    #get the sobel filter repeated for each input channel
    filt = expandedSobel(yTrue)

    #calculate the sobel filters for yTrue and yPred
    #this generates twice the number of input channels:
    #an X and a Y channel for each input channel
    sobelTrue = K.depthwise_conv2d(yTrue,filt)
    sobelPred = K.depthwise_conv2d(yPred,filt)

    #now you just apply the mse:
    return K.mean(K.square(sobelTrue - sobelPred))
```
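To sanity-check what `K.depthwise_conv2d` produces, here is a plain NumPy reference of a depthwise cross-correlation with 'valid' padding and stride 1 (the Keras defaults), plus a NumPy mirror of `sobelLoss`. This is an offline checking aid under those assumptions, not a replacement for the Keras loss, since NumPy code is not part of the differentiable graph. `depthwise_conv2d_valid` and `sobel_loss_reference` are hypothetical helper names. Following the TensorFlow definition of depthwise convolution, output channel `c * M + m` holds input channel `c` filtered with kernel `m`, so with the Sobel filter the channels come out as X, Y for each input channel in turn.

```python
import numpy as np

SOBEL = np.array([[[[1.,  1.]], [[0.,  2.]], [[-1.,  1.]]],
                  [[[2.,  0.]], [[0.,  0.]], [[-2.,  0.]]],
                  [[[1., -1.]], [[0., -2.]], [[-1., -1.]]]])

def depthwise_conv2d_valid(x, kernel):
    """Reference depthwise cross-correlation, 'valid' padding, stride 1.

    x: (batch, h, w, C); kernel: (kh, kw, C, M) -> (batch, h-kh+1, w-kw+1, C*M).
    Output channel c*M + m is input channel c filtered by kernel[:, :, c, m].
    """
    b, h, w, c_in = x.shape
    kh, kw, _, m = kernel.shape
    oh, ow = h - kh + 1, w - kw + 1
    out = np.zeros((b, oh, ow, c_in * m))
    for i in range(kh):
        for j in range(kw):
            # Window of x aligned with kernel tap (i, j): (b, oh, ow, C)
            patch = x[:, i:i + oh, j:j + ow, :]
            # (b, oh, ow, C, 1) * (C, M) -> (b, oh, ow, C, M), flattened to C*M
            contrib = patch[..., None] * kernel[i, j]
            out += contrib.reshape(b, oh, ow, c_in * m)
    return out

def sobel_loss_reference(y_true, y_pred, kernel):
    """NumPy mirror of sobelLoss, for offline checks only (not trainable)."""
    diff = depthwise_conv2d_valid(y_true, kernel) - depthwise_conv2d_valid(y_pred, kernel)
    return np.mean(diff ** 2)

# Usage: a single-channel horizontal ramp, where x[0, row, col, 0] == col
ramp = np.tile(np.arange(5.0), (5, 1)).reshape(1, 5, 5, 1)
out = depthwise_conv2d_valid(ramp, SOBEL)
print(out.shape)
print(out[0, :, :, 0])  # X response
print(out[0, :, :, 1])  # Y response
```

On a horizontal ramp the X channel is constant (-8 with this kernel's sign convention) and the Y channel is zero.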
Apply this loss in the model:

```python
model.compile(loss=sobelLoss, optimizer = ....)
```
In my experience, computing the combined Sobel magnitude sqrt(X² + Y²) gives terrible results, and the resulting images look like chessboards. But if you do want it:
```python
def squareSobelLoss(yTrue,yPred):

    #same beginning as the other loss
    filt = expandedSobel(yTrue)
    squareSobelTrue = K.square(K.depthwise_conv2d(yTrue,filt))
    squareSobelPred = K.square(K.depthwise_conv2d(yPred,filt))

    #here, since we've got 6 output channels (for an RGB image)
    #let's reorganize in order to easily sum X² and Y²: change (h,w,6) to (h,w,3,2)
    #caution: this method of reshaping only works in tensorflow
    #if you do need this in other backends, let me know
    newShape = K.shape(squareSobelTrue)
    #new shape: all leading dims, then (channels // 2), then 2
    newShape = K.concatenate([newShape[:-1],
                              newShape[-1:] // 2,
                              K.constant([2], dtype='int32')])

    #sum the last axis (the one that is 2 above, representing X² and Y²)
    squareSobelTrue = K.sum(K.reshape(squareSobelTrue,newShape),axis=-1)
    squareSobelPred = K.sum(K.reshape(squareSobelPred,newShape),axis=-1)

    #since both previous values are already squared, maybe we shouldn't square them again?
    #but you can apply K.sqrt() to both, then take the difference,
    #and then square again, it's up to you...
    return K.mean(K.abs(squareSobelTrue - squareSobelPred))
```
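The reshape trick depends on the channel order produced by the depthwise convolution: for each input channel, its X and Y responses are adjacent, so splitting the last axis into (channels // 2, 2) pairs them correctly before summing. A small NumPy sketch with made-up squared responses for a single 2x2 RGB image shows the same arithmetic; all values and names here are hypothetical.

```python
import numpy as np

# Hypothetical per-channel squared responses (same at every pixel, for illustration)
gx2 = np.array([1., 4., 9.])     # X² for R, G, B
gy2 = np.array([16., 25., 36.])  # Y² for R, G, B

# Depthwise output order: [R_x², R_y², G_x², G_y², B_x², B_y²]
interleaved = np.stack([gx2, gy2], axis=-1).reshape(6)
square_sobel = np.broadcast_to(interleaved, (2, 2, 6))  # (h, w, 6)

# Same shape arithmetic as the Keras code: (h, w, 6) -> (h, w, 3, 2)
new_shape = square_sobel.shape[:-1] + (square_sobel.shape[-1] // 2, 2)
magnitude_sq = square_sobel.reshape(new_shape).sum(axis=-1)  # X² + Y² per channel

print(magnitude_sq.shape)
print(magnitude_sq[0, 0])
```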