
Updating internal state for TensorFlow custom metrics (aka using non-update_state vars in metric calculation)


Versions: Python 3.8.2 (I've also tried 3.6.8, but I don't think the Python version matters here), TensorFlow 2.3.0, NumPy 1.18.5

I'm training a model for a classification problem with a sparse labels tensor. How would I go about defining a metric that counts the number of times that the "0" label has appeared up until that point? What I'm trying to do in the code example below is to store all the labels that the metric has seen in an array and constantly concatenate the existing array with the new y_true every time update_state is called. (I know I could just store a count variable and use +=, but in the actual usage scenario, concatenating is ideal and memory is not an issue.) Here's minimal code to reproduce the problem:

import tensorflow as tf

class ZeroLabels(tf.keras.metrics.Metric):
    """Accumulates a list of all y_true sparse categorical labels (ints) and calculates the number of times the '0' label has appeared."""
    def __init__(self, *args, **kwargs):
        super(ZeroLabels, self).__init__(name="ZeroLabels")
        self.labels = self.add_weight(name="labels", shape=(), initializer="zeros", dtype=tf.int32)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """I'm using sparse categorical crossentropy, so labels are 1D array of integers."""
        if self.labels.shape == (): # if this is the first time update_state is being called
            self.labels = y_true
        else:
            self.labels = tf.concat((self.labels, y_true), axis=0)

    def result(self):
        return tf.reduce_sum(tf.cast(self.labels == 0, dtype=tf.int32))

    def reset_states(self):
        self.labels = tf.constant(0, dtype=tf.int32)

This code works on its own, but it throws the following error when I try to train a model using this metric:

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2

I thought this might have something to do with the fact that self.labels isn't directly part of the graph when update_state is called. Here are some other things I've tried:

  • storing a tf.int32, shape=() count variable and incrementing that instead of concatenating the new labels
  • converting everything to NumPy using .numpy() and concatenating those instead (I was hoping to force TensorFlow to not use the graph)
  • using try and except blocks with the NumPy conversion above
  • creating an entirely new class (rather than subclassing tf.keras.metrics.Metric) that uses NumPy wherever possible, but this approach causes loading issues, even when I use custom_objects in tf.keras.models.load_model
  • using the @tf.autograph.experimental.do_not_convert decorator on all methods
  • modifying a global variable (using the global keyword) rather than an attribute
  • using non-TensorFlow attributes (i.e. not using self.labels = self.add_weight(...))

If it helps, here's a more general version of this question: How can we incorporate tensors that aren't passed in as parameters to update_state in the update_state calculation? Any help would be greatly appreciated. Thank you in advance!


Solution

  • The main problem is the assignment on the first iteration, when there is no initial value yet:

    if self.labels.shape == ():
        self.labels = y_true
    else:
        self.labels = tf.concat((self.labels, y_true), axis=0)
    

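    In both branches, the variable 'labels' created in the constructor is discarded: the attribute is simply rebound to a plain tf.Tensor (y_true in the first branch, the result of tf.concat in the second). Because update_state is traced into a graph during training, that tensor is a graph tensor that leaks out of the function, which is most likely what triggers the TypeError above. To modify the variable's contents while keeping the same object, you have to use tf.Variable methods (assign, assign_add). Moreover, to be able to change a tf.Variable's shape, you have to create it with a partially unknown shape, in this case (None,), because you're concatenating 1-D label vectors on axis=0.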
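A minimal sketch of the assign-based pattern outside of any Keras metric: a tf.Variable declared with shape [None] (mirroring the (None,) shape discussed above) grows inside a tf.function because it is mutated with assign() instead of being rebound. The function name `append` and the sample batches are hypothetical, chosen only for illustration.

```python
import tensorflow as tf

# A 1-D int32 variable whose length is left unspecified (shape [None]),
# starting out empty. It is created once, outside any tf.function.
labels = tf.Variable(tf.zeros([0], dtype=tf.int32),
                     shape=tf.TensorShape([None]),
                     trainable=False)

@tf.function
def append(batch):
    # Mutate the existing variable in place with assign(); the Python name
    # `labels` keeps pointing at the same tf.Variable object.
    labels.assign(tf.concat([labels, batch], axis=0))

append(tf.constant([0, 2, 0], dtype=tf.int32))
append(tf.constant([1, 0], dtype=tf.int32))

print(labels.numpy())  # [0 2 0 1 0]
print(int(tf.reduce_sum(tf.cast(tf.equal(labels, 0), tf.int32))))  # 3
```

Note that labels.shape still reports the declared (None,); the actual length is only known at run time (e.g. via tf.shape or tf.size).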

    So:

    class ZeroLabels(tf.keras.metrics.Metric):
        def __init__(self, *args, **kwargs):
            super(ZeroLabels, self).__init__(name="ZeroLabels")
    
            # Define a 1-D int32 variable whose length is unknown (shape (None,)),
            # starting out empty. This lets the variable grow from batch to batch.
            self.labels = tf.Variable(tf.zeros([0], dtype=tf.int32),
                                      shape=tf.TensorShape([None]),
                                      trainable=False)
    
        def update_state(self, y_true, y_pred, sample_weight=None):
            # Flatten y_true (Keras may pass it as (batch, 1)) and cast it to the
            # variable's dtype, then assign the previous values joined with it
            y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
            self.labels.assign(tf.concat([self.labels.value(), y_true], axis=0))
    
        def result(self):
            return tf.reduce_sum(tf.cast(tf.equal(self.labels.value(), 0), dtype=tf.int32))
    
        def reset_states(self):
            # To reset the metric, assign an empty int32 tensor again
            self.labels.assign(tf.zeros([0], dtype=tf.int32))
    

    But if you only want to count the 0s in the dataset, I suggest using an integer variable that counts those elements instead: with the concatenating version, the labels array grows after every processed batch, so summing over all of its elements takes more and more time (and memory), slowing down your training.
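To make the cost argument concrete, here is a rough sketch with made-up data (100 hypothetical batches of 32 labels drawn from {0, 1, 2}), using bare variables rather than the metric classes: both approaches report the same number of zeros, but the concatenated variable ends up holding every label seen so far while the counter stays a single scalar.

```python
import tensorflow as tf

# 100 hypothetical batches of 32 sparse labels drawn from {0, 1, 2}.
batches = [tf.random.uniform([32], maxval=3, dtype=tf.int32, seed=i)
           for i in range(100)]

# Approach 1 (concatenating): keep every label seen so far.
all_labels = tf.Variable(tf.zeros([0], dtype=tf.int32),
                         shape=tf.TensorShape([None]), trainable=False)
# Approach 2 (counting): keep only the running number of zeros.
zero_count = tf.Variable(0, dtype=tf.int32, trainable=False)

for batch in batches:
    all_labels.assign(tf.concat([all_labels, batch], axis=0))
    zero_count.assign_add(tf.reduce_sum(tf.cast(tf.equal(batch, 0), tf.int32)))

# Counting zeros in the stored labels scans all of them on every call...
zeros_from_list = int(tf.reduce_sum(tf.cast(tf.equal(all_labels, 0), tf.int32)))
# ...and gives the same answer as the scalar counter.
print(zeros_from_list == int(zero_count.numpy()))           # True
print(int(tf.size(all_labels)), int(tf.size(zero_count)))   # 3200 1
```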

    class ZeroLabels_2(tf.keras.metrics.Metric):
        """Counts how many times the '0' label has appeared in y_true (sparse categorical labels)."""
        def __init__(self, *args, **kwargs):
            super(ZeroLabels_2, self).__init__(name="ZeroLabels")
    
            # Define a scalar integer weight, initialised to 0
            self.zero_count = self.add_weight(name="zero_count", shape=(), initializer="zeros", dtype=tf.int32)
    
        def update_state(self, y_true, y_pred, sample_weight=None):
            # Add the number of zeros in this batch to the running count
            self.zero_count.assign_add(tf.reduce_sum(tf.cast(tf.equal(y_true, 0), dtype=tf.int32)))
    
        def result(self):
            # Simply return the variable's content
            return self.zero_count.value()
    
        def reset_states(self):
            self.zero_count.assign(0)
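The same pattern also covers the more general question of using values that aren't passed to update_state: fixed inputs can be stored as plain attributes in __init__, and anything that changes between batches lives in a weight created once with add_weight and updated with assign/assign_add. The sketch below is a hypothetical generalisation (the ClassCount class, its class_id parameter, and the toy data are all assumptions for illustration, written against the tf.keras API), run through model.fit so that update_state is actually traced into a graph. get_config is included so that class_id should survive saving and reloading with custom_objects.

```python
import numpy as np
import tensorflow as tf


class ClassCount(tf.keras.metrics.Metric):
    """Hypothetical example metric: counts sparse labels equal to `class_id`."""

    def __init__(self, class_id=0, name="class_count", **kwargs):
        super().__init__(name=name, **kwargs)
        # Fixed, non-update_state input: a plain Python int set at construction.
        self.class_id = class_id
        # Mutable state: a scalar int32 weight created once and tracked by Keras
        # (the default reset logic sets it back to 0 at the start of each epoch).
        self.count = self.add_weight(
            name="count", shape=(), initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Keras may pass sparse labels as (batch,) or (batch, 1): flatten them.
        y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
        hits = tf.reduce_sum(tf.cast(tf.equal(y_true, self.class_id), tf.int32))
        # Mutate the existing weight in place instead of rebinding the attribute.
        self.count.assign_add(hits)

    def result(self):
        return tf.identity(self.count)

    def get_config(self):
        # Lets the class_id argument round-trip through model saving/loading.
        config = super().get_config()
        config.update({"class_id": self.class_id})
        return config


# Toy data (made up): 64 samples, 4 features, 3 classes; 16 of the labels are 0.
x = np.random.rand(64, 4).astype("float32")
y = np.array([0, 1, 2, 1] * 16, dtype="int32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[ClassCount(class_id=0, name="zeros_seen")],
)
history = model.fit(x, y, batch_size=16, epochs=1, verbose=0)
print(history.history["zeros_seen"])  # one entry per epoch; the count should be 16
```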
    

    I hope this helps (and apologies for my English).