I am training a neural network using Flax. My training data has a significant number of NaNs in the outputs. I want to ignore these and use only the non-NaN values for training. To achieve this, I tried using jnp.nanmean to compute the loss:
def nanloss(params, inputs, targets):
    pred = model.apply(params, inputs)
    return jnp.nanmean((pred - targets) ** 2)

def train_step(state, inputs, targets):
    loss, grads = jax.value_and_grad(nanloss)(state.params, inputs, targets)
    state = state.apply_gradients(grads=grads)
    return state, loss
However, after one training step the loss is NaN.
Is what I am trying to achieve possible? If so, how can I fix this?
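I suspect you are hitting the issue discussed in the JAX FAQ entry "Gradients contain NaN where using where". You've handled the NaNs in the forward computation, but they still sneak into the gradient because of how autodiff is implemented: the backward pass multiplies a zero cotangent by the NaN in (pred - targets), and 0 * NaN is still NaN. The first update then writes NaN into the parameters, which is why the loss is NaN after one step.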
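One way to check whether this is what is happening is to reproduce it without the model. The sketch below is only an illustration: the arrays are made up, and the loss is differentiated with respect to the predictions directly rather than through model.apply. The loss value comes out finite, but the gradient has a NaN at the position where the target is NaN.

```python
import jax
import jax.numpy as jnp

def nanmean_loss(pred, targets):
    # Same loss as in the question, written as a function of the predictions.
    return jnp.nanmean((pred - targets) ** 2)

# Made-up data: the second target is missing.
pred = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.5, jnp.nan, 2.0])

loss, grad = jax.value_and_grad(nanmean_loss)(pred, targets)

print("loss is finite:", bool(jnp.isfinite(loss)))
print("gradient has NaN:", bool(jnp.any(jnp.isnan(grad))))
```

Once such a gradient is passed to apply_gradients, the parameters themselves become NaN, so every later prediction is NaN as well.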
If this is in fact the issue, you can fix it by filtering out the NaN values before computing the loss, for example:
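import jax.numpy as jnp

def nanloss(params, inputs, targets):
    pred = model.apply(params, inputs)
    # Mark every position where either side is NaN.
    mask = jnp.isnan(pred) | jnp.isnan(targets)
    # Replace masked entries with a finite value so no NaN reaches the gradient.
    pred = jnp.where(mask, 0, pred)
    targets = jnp.where(mask, 0, targets)
    # Average only over the unmasked entries.
    return jnp.mean((pred - targets) ** 2, where=~mask)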
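To convince yourself that the masked version behaves, here is a small self-contained check. It is a sketch under assumptions: predict is a hypothetical linear stand-in for model.apply, and the parameters and data are invented for the example. It computes the loss and gradients on targets that contain a NaN and confirms that every gradient leaf is finite.

```python
import jax
import jax.numpy as jnp

def predict(params, inputs):
    # Hypothetical linear model standing in for model.apply.
    return inputs @ params["w"] + params["b"]

def masked_loss(params, inputs, targets):
    pred = predict(params, inputs)
    mask = jnp.isnan(pred) | jnp.isnan(targets)
    # Zero out masked entries on both sides before they enter the arithmetic.
    pred = jnp.where(mask, 0, pred)
    targets = jnp.where(mask, 0, targets)
    return jnp.mean((pred - targets) ** 2, where=~mask)

# Invented parameters and data; the second target is missing.
params = {"w": jnp.array([0.5, -1.0]), "b": jnp.array(0.1)}
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
targets = jnp.array([0.0, jnp.nan, 1.0])

loss, grads = jax.value_and_grad(masked_loss)(params, inputs, targets)
grads_finite = all(
    bool(jnp.all(jnp.isfinite(g))) for g in jax.tree_util.tree_leaves(grads)
)

print("loss is finite:", bool(jnp.isfinite(loss)))
print("all gradients finite:", grads_finite)
```

One thing to keep in mind: if a batch contains no valid targets at all, the mean is taken over zero elements and is NaN again, so you may want to guard against fully masked batches.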