I'm working in R and trying to get started with neural networks, using the keras package.
I'd like to use a custom loss function for training my NN. It's possible to do this by writing the custom loss function as lossFn <- function(y_true, y_pred) { ... } and passing it to the compile method as model %>% compile(loss = lossFn, ...).
Now in order to use the gradient descent method of training the NN, the loss function needs to be differentiable. I understand that you'd usually accomplish this by restricting yourself to using backend functions in your loss function, e.g.
lossFn <- function(y_true, y_pred) {
  K <- backend()
  K$mean(K$square(y_true - y_pred), axis = 1L)
}
or something like that.
Now, my problem is that I cannot express my loss function this way; I need to use functions that aren't available in the backend.
So my idea was that I'd work out the gradient myself on paper, and then provide it to compile as another argument, say compile(loss = lossFn, gradient = gradientFn, ...), with gradientFn suitably defined.
The documentation for keras (the R package!) does not indicate that this is possible. At the same time, it does not suggest it's not. And googling has turned up little that is relevant.
So my question is, is it possible?
An addendum: since Google has suggested that there are other training methods for NNs that do not rely on the gradient of the loss function, I should add that I'm not too hung up on the specific training method. My ultimate goal isn't to manually supply the gradient of a custom loss function; it's to use a custom loss function to train the NN. The gradient is just a technical obstacle for me right now.
Thanks!
This is certainly possible in Keras; you'll just have to move up the stack a little: implement a train_step method and then call optimizer$apply_gradients() yourself.
Chapter 7 in the Deep Learning with R book covers this use case: https://github.com/t-kalinowski/deep-learning-with-R-2nd-edition-code/blob/9f8b6d08dbb8d6565e4f5396e509aaea3e242b84/ch07.R#L608
Also, this keras guide may be useful, even though it's in Python and you're working in R. (The Python interface is very similar to the R interface). https://keras.io/guides/writing_a_training_loop_from_scratch/
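As a rough illustration of the train_step idea above, here is a minimal sketch of a hand-written training step in R that uses a gradient worked out on paper. It assumes a recent 2.x release of the keras and tensorflow R packages (for zip_lists() and %as%) and eager execution. The names gradient_fn, train_step, x and y, the tiny model, and the toy data are all made up for the example, and the mean squared error gradient is only a stand-in for whatever you derive for your own loss. The forward pass is recorded on a tf$GradientTape, and the hand-derived dL/dy_pred is passed to tape$gradient() through its output_gradients argument, so TensorFlow only has to backpropagate through the model, not through the loss.

```r
library(tensorflow)
library(keras)

# Hypothetical hand-derived gradient dL/dy_pred, written in plain R.
# Stand-in loss: mean((y_pred - y_true)^2), whose gradient is
# 2 * (y_pred - y_true) / n. Replace this with your own derivation.
gradient_fn <- function(y_true, y_pred) {
  2 * (y_pred - y_true) / length(y_true)
}

# Tiny illustrative model and optimizer (assumptions, not from the question).
model <- keras_model_sequential() %>%
  layer_dense(units = 1, input_shape = c(3))
optimizer <- optimizer_sgd(learning_rate = 0.01)

train_step <- function(x, y_true) {
  # Record the forward pass so TensorFlow can backpropagate through the model.
  with(tf$GradientTape() %as% tape, {
    y_pred <- model(x, training = TRUE)
  })
  # Evaluate the hand-derived gradient outside TensorFlow (eager mode only).
  dl_dy <- tf$constant(gradient_fn(y_true, as.array(y_pred)),
                       dtype = y_pred$dtype)
  # Chain rule: seed backprop with dL/dy_pred instead of differentiating a loss.
  grads <- tape$gradient(y_pred, model$trainable_variables,
                         output_gradients = dl_dy)
  optimizer$apply_gradients(zip_lists(grads, model$trainable_variables))
  invisible(NULL)
}

# Toy data, for illustration only.
x <- tf$constant(matrix(rnorm(30), ncol = 3), dtype = "float32")
y <- matrix(rnorm(10), ncol = 1)
for (epoch in 1:5) train_step(x, y)
```

Because gradient_fn runs in ordinary R on values pulled out with as.array(), this step has to run eagerly; wrapping it in tf_function() would presumably break that conversion. If you also want to monitor the loss value, you can compute it in plain R from the same y_true and y_pred inside train_step.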