How to: TensorFlow-Probability custom loss that ignores NA values (or otherwise masks loss)


I want to implement a masked loss function in TensorFlow Probability that ignores NAs in the labels.

This is a well-worn task for regular tensors, but I cannot find an example for distributions.

My distributions have shape (batch, time steps, outputs) = (512, 251 days, 1 to 8 time series).

The usual loss function given in examples is the negative of the distribution's log probability:

neg_log_likelihood <- function (x, rv_x) {
  -1*(rv_x %>% tfd_log_prob(x))
}

When I replace NAs with zeros, the model trains fine and converges. When I leave in NAs it produces NaN losses as expected.

I've experimented with many different permutations of tf$where to replace loss with 0, the label with 0, etc. In each of those cases the model stops training and loss stays near some constant. That's the case even when there's just a single NA in the labels.

neg_log_likelihood_missing <-  function (x, rv_x) {
  
  loss =     -1*(  rv_x %>% tfd_log_prob(x) ) 
  
  loss_nonan = tf$where( tf$math$is_finite(x) , loss, 0  )
  
  return( 
    loss_nonan
  )
}

My use of R here is incidental; I can translate examples from Python or any other language. If there's a correct way to do this so that the losses back-propagate properly, I would greatly appreciate it.


Solution

  • If you are using gradient-based inference, you may need the "double where" trick.

    While this gets you a correct value of y:

    y = computation(x)
    tf.where(is_nan(y), 0, y)
    

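    ...the gradient that flows back through the tf.where can still be NaN. The masked branch receives a zero upstream gradient, but zero times a NaN or infinite local derivative is still NaN.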

    Instead write:

    safe_x = tf.where(is_unsafe(x), some_safe_x, x)
    y = computation(safe_x)
    tf.where(is_unsafe(x), 0, y)
    

    ...to get both a safe y out and a safe dy/dx.
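As a concrete illustration of the two snippets above, here is a minimal sketch in TensorFlow 2 (eager mode assumed). The reciprocal `1 / x` stands in for `computation`, the test `x == 0` stands in for `is_unsafe`, and the helper names are made up for this example.

```python
import tensorflow as tf


def reciprocal_single_where(x):
    # Mask only the output: the forward value is fine, the gradient is not.
    y = 1.0 / x
    return tf.where(tf.equal(x, 0.0), tf.zeros_like(y), y)


def reciprocal_double_where(x):
    unsafe = tf.equal(x, 0.0)
    # First where: feed the computation a harmless input at unsafe positions.
    safe_x = tf.where(unsafe, tf.ones_like(x), x)
    y = 1.0 / safe_x
    # Second where: mask the output at the same positions.
    return tf.where(unsafe, tf.zeros_like(y), y)


def value_and_grad(fn, x):
    # x is a constant, so it has to be watched explicitly.
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = fn(x)
    return y, tape.gradient(y, x)


x = tf.constant([0.0, 2.0])
y_single, g_single = value_and_grad(reciprocal_single_where, x)
y_double, g_double = value_and_grad(reciprocal_double_where, x)

print("single where:", y_single.numpy(), g_single.numpy())
print("double where:", y_double.numpy(), g_double.numpy())
```

Both versions return the same values, 0 and 0.5. Only the gradient at the masked entry differs: it is NaN with the single `tf.where` and zero with the double `tf.where`.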

    For the case you're considering, perhaps write:

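    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    class MyMaskedDist(tfd.Distribution):
      ...
      def _log_prob(self, x):
        missing = tf.math.is_nan(x)  # tf.is_nan in TF 1.x
        # Replace missing labels with an in-support value before computing log_prob.
        safe_x = tf.where(missing, self.mode(), x)
        lp = compute_log_prob(safe_x)  # placeholder for the actual log density
        lp = tf.where(missing, tf.zeros([], lp.dtype), lp)
        return lp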
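
If subclassing the distribution is more than you need, the same idea can probably be applied inside the loss function itself. The sketch below is a rough Python counterpart of `neg_log_likelihood_missing` from the question, under two assumptions: `rv_x` exposes `log_prob` and `mode`, and `log_prob` returns one value per label element (no event dimensions are summed away). `UnitNormal` and `masked_neg_log_likelihood` are hypothetical names; `UnitNormal` is only a stand-in so the example runs with TensorFlow alone, and in practice `rv_x` would be a `tfd` distribution.

```python
import math

import tensorflow as tf


class UnitNormal:
    """Hypothetical stand-in for a distribution: Normal(loc, 1)."""

    def __init__(self, loc):
        self.loc = loc

    def mode(self):
        return tf.convert_to_tensor(self.loc)

    def log_prob(self, x):
        return -0.5 * tf.square(x - self.loc) - 0.5 * math.log(2.0 * math.pi)


def masked_neg_log_likelihood(x, rv_x):
    missing = tf.math.is_nan(x)
    # First where: give log_prob a finite, in-support label at missing positions.
    safe_x = tf.where(missing, rv_x.mode(), x)
    nll = -rv_x.log_prob(safe_x)
    # Second where: drop those positions from the loss.
    return tf.where(missing, tf.zeros_like(nll), nll)


x = tf.constant([1.0, float("nan"), 3.0])  # one missing label
loc = tf.Variable([0.5, 0.5, 0.5])         # trainable parameter

with tf.GradientTape() as tape:
    per_element = masked_neg_log_likelihood(x, UnitNormal(loc))
    loss = tf.reduce_sum(per_element)
grad = tape.gradient(loss, loc)

print("per-element loss:", per_element.numpy())
print("gradient wrt loc:", grad.numpy())
```

The missing label contributes zero to both the loss and the gradient, while the observed labels keep their usual gradients. Whether to sum the result or divide by the number of observed labels is a separate choice that this sketch leaves open.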