Tags: tensorflow, keras, deep-learning, keras-layer, loss-function

Custom loss function in keras involving huge matrix multiplication


I am having difficulty writing a custom loss function in Keras. I have layer weights "W" and a matrix "M", and I want to compute trace((W * M) * W') as my loss, where the trace is the sum of the diagonal elements. In NumPy, I would have done one of the following:

np.trace(np.dot(np.dot(W, M), W.T))

or:
import numpy as np
from numpy import linalg as LA

def custom_regularizer(W, M):
    # Accumulate M[i, j] * ||w_i - w_j||^2 over pairs of columns of W.
    sum_reg = 0
    for i in range(W.shape[1]):
        for j in range(i, W.shape[1]):
            vector = W[:, i] - W[:, j]
            sum_reg = sum_reg + M[i, j] * (LA.norm(vector) ** 2)
    return sum_reg
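As a sanity check on toy sizes (the data below is hypothetical, standing in for the real 200000-row W), the trace form can be rewritten with the identity trace(A · Bᵀ) = Σ(A ∘ B), so the full outer matrix is never built:

```python
import numpy as np

# Toy sizes stand in for the real 200000-row W; the data is hypothetical.
rng = np.random.default_rng(0)
W = rng.standard_normal((8, 3))
M = rng.standard_normal((3, 3))

# Direct form: materializes the full 8 x 8 outer matrix before the trace.
direct = np.trace(np.dot(np.dot(W, M), W.T))

# Identity trace(A @ B.T) == (A * B).sum(): only the 8 x 3
# intermediate W @ M is ever built.
efficient = ((W @ M) * W).sum()

assert np.isclose(direct, efficient)
```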

For Keras, I have written the following loss function:

import tensorflow as tf
from keras import backend as K

def custom_loss(W):

    def lossFunction(y_true, y_pred):
        # M is captured from the enclosing scope.
        loss = tf.trace(K.dot(K.dot(W, K.constant(M)), K.transpose(W)))
        return loss

    return lossFunction

The problem is that Keras computes the whole outer matrix, whose dimension is 200000 * 200000, which gives a memory error. Is there any way to get just the sum of the diagonal elements without computing the whole matrix?

How can I do the same in a Keras loss function?


Solution

  • If you follow some clever tricks to compute the trace, you should not run out of memory. For instance, you can refer to this.
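One such trick is the identity trace(A · Bᵀ) = Σ(A ∘ B), which gives trace(W · M · Wᵀ) = Σ((W · M) ∘ W): only the N × d intermediate W · M is ever materialized, never the N × N product. A minimal sketch using TensorFlow 2 eager APIs (the closure shape mirrors the question's `custom_loss`, but passing M explicitly is my own choice here):

```python
import numpy as np
import tensorflow as tf

def custom_loss(W, M):
    """Loss closure computing trace(W @ M @ W.T) without the N x N product."""
    M_const = tf.constant(M, dtype=W.dtype)

    def lossFunction(y_true, y_pred):
        # trace(W @ M @ W.T) == sum((W @ M) * W): only the (N, d)
        # intermediate W @ M is built, never the (N, N) outer matrix.
        WM = tf.matmul(W, M_const)      # shape (N, d)
        return tf.reduce_sum(WM * W)    # scalar loss

    return lossFunction
```

With N = 200000 this needs memory proportional to N · d instead of N², which is what avoids the out-of-memory error.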