Tags: tensorflow, keras, deep-learning, loss-function

What is the difference between Keras MSE loss and my own loss function?


I am new to Keras. I tried to write a custom loss function, but something is wrong in my code: training runs, but the estimation results are strange. What should I change?

I simply tried to implement MSE as a custom loss function.

This is the loss function:

def loss_function(ytrue, ypred):

    qx_true = ytrue[:, 0]
    qx_pred = ytrue[:, 0]
    qy_true = ytrue[:, 1]
    qy_pred = ytrue[:, 1]
    qz_true = ytrue[:, 2]
    qz_pred = ytrue[:, 2]
    qw_true = ytrue[:, 3]
    qw_pred = ytrue[:, 3]
    tx_true = ytrue[:, 4]
    tx_pred = ypred[:, 4]
    ty_true = ytrue[:, 5]
    ty_pred = ypred[:, 5]
    tz_true = ytrue[:, 6]
    tz_pred = ypred[:, 6]

    loss = ((tx_true - tx_pred) * (tx_true - tx_pred) 
        + (ty_true - ty_pred) * (ty_true - ty_pred) 
        + (tz_true - tz_pred) * (tz_true - tz_pred) 
        + (qx_true - qx_pred) * (qx_true - qx_pred) 
        + (qy_true - qy_pred) * (qy_true - qy_pred) 
        + (qz_true - qz_pred) * (qz_true - qz_pred) 
        + (qw_true - qw_pred) * (qw_true - qw_pred)) / 7

    return loss

And this is where I pass it to the model:

    model.add(Dense(7, name='output'))
    model.compile(loss=loss_function, optimizer=keras.optimizers.Adam())

When I use the built-in Keras loss function instead, it works:

    model.add(Dense(7, name='output'))
    model.compile(loss=keras.losses.MSE, optimizer=keras.optimizers.Adam())

Each target passed to the loss function has seven values: four quaternion components (qx, qy, qz, qw in columns 0 to 3) and three translation components (tx, ty, tz in columns 4 to 6). With keras.losses.MSE training works, and I am trying to compute the same thing myself.
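Written out per sample, the loss intended here is the mean of the seven squared errors (hats mark the network's predictions, unhatted symbols the targets). Averaging squared errors over the last axis is what keras.losses.MSE computes per sample, so the two should agree exactly:

```latex
L = \frac{1}{7}\Big[(q_x - \hat q_x)^2 + (q_y - \hat q_y)^2 + (q_z - \hat q_z)^2 + (q_w - \hat q_w)^2
  + (t_x - \hat t_x)^2 + (t_y - \hat t_y)^2 + (t_z - \hat t_z)^2\Big]
```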

Which part is wrong? Thanks.


Solution

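  • I believe the problem is that the four predicted quaternion components are read from ytrue instead of ypred. Each of those squared differences is then always zero, so the quaternion outputs never receive a training signal. This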

    qx_true = ytrue[:, 0]
    qx_pred = ytrue[:, 0]
    qy_true = ytrue[:, 1]
    qy_pred = ytrue[:, 1]
    qz_true = ytrue[:, 2]
    qz_pred = ytrue[:, 2]
    qw_true = ytrue[:, 3]
    qw_pred = ytrue[:, 3]
    

    should be

    qx_true = ytrue[:, 0]
    qx_pred = ypred[:, 0]
    qy_true = ytrue[:, 1]
    qy_pred = ypred[:, 1]
    qz_true = ytrue[:, 2]
    qz_pred = ypred[:, 2]
    qw_true = ytrue[:, 3]
    qw_pred = ypred[:, 3]
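A quick way to confirm the fix is to compare both versions against a plain mean of squared errors. The sketch below is a minimal NumPy illustration under one assumption: NumPy arrays stand in for the Keras tensors, which support the same slicing and arithmetic. The names buggy_loss, fixed_loss and reference_mse and the random batch are hypothetical, and fixed_loss uses column slicing, which is equivalent to the per-column variables above.

```python
import numpy as np

def buggy_loss(ytrue, ypred):
    # Mirrors the question: quaternion "predictions" (columns 0-3) are taken
    # from ytrue, so their squared differences are always exactly zero.
    q_err = ytrue[:, 0:4] - ytrue[:, 0:4]
    t_err = ytrue[:, 4:7] - ypred[:, 4:7]
    return (np.sum(q_err ** 2, axis=1) + np.sum(t_err ** 2, axis=1)) / 7

def fixed_loss(ytrue, ypred):
    # Corrected version: all 7 columns compare ytrue with ypred.
    err = ytrue - ypred
    return np.sum(err ** 2, axis=1) / 7

def reference_mse(ytrue, ypred):
    # Mean of squared errors over the last axis, one value per sample,
    # which is what keras.losses.MSE computes per sample.
    return np.mean(np.square(ytrue - ypred), axis=-1)

rng = np.random.default_rng(0)
ytrue = rng.normal(size=(5, 7))  # hypothetical batch: 4 quaternion + 3 translation values
ypred = rng.normal(size=(5, 7))

print(np.allclose(fixed_loss(ytrue, ypred), reference_mse(ytrue, ypred)))  # True

# Shifting only the predicted quaternion columns leaves the buggy loss unchanged,
# so those four outputs get no gradient during training.
ypred_other_q = ypred.copy()
ypred_other_q[:, 0:4] += 10.0
print(np.allclose(buggy_loss(ytrue, ypred), buggy_loss(ytrue, ypred_other_q)))  # True
print(np.allclose(fixed_loss(ytrue, ypred), fixed_loss(ytrue, ypred_other_q)))  # False
```

Inside the Keras model itself, the same corrected loss can also be written without per-column variables as tf.reduce_mean(tf.square(ytrue - ypred), axis=-1), which averages over the seven output columns just like keras.losses.MSE.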