Resetting tensorflow streaming metrics' variables


I have a bunch of streaming metrics (tf.metrics.accuracy and custom streaming micro, macro and weighted F1-scores).

During training, I get the kind of plot below (never mind the overfitting).

This happens because, to compute the validation set's metrics, I call tf.local_variables_initializer to reset the metrics, so that the reported value covers only the validation set.

This has two side effects:

  1. The spikes in the image
  2. Between validations, the training metrics keep aggregating, even when validation happens only every 2 epochs

I could partially solve this by having separate tensors hold each metric (train vs. val), but that would not solve 2.

I therefore have 2 questions:

  • In your experience, is this the behavior you would expect? If not, is there a solution?
  • Is there a way to have metrics stream only over the last n batches?

[Image: spiking plot]


Solution

  • This behaviour is expected if you reset the metrics in between training steps. The training metrics don't aggregate the validation metrics as long as they are two different ops. Here is an example of how to keep those metrics separate and how to reset only one of them.


    A toy example:

    import tensorflow as tf  # TF 1.x graph-mode API

    logits = tf.placeholder(tf.int64, [2,3])
    labels = tf.Variable([[0, 1, 0], [1, 0, 1]])

    # create two separate sets of metric ops, one per name scope
    with tf.name_scope('train'):
       train_acc, train_acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),
                                                     predictions=tf.argmax(logits,1))
    with tf.name_scope('valid'):
       valid_acc, valid_acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),
                                                     predictions=tf.argmax(logits,1))
    

    Training:

    sess = tf.Session()

    # initialize the local variables, since they hold the state used to compute the metrics
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())
    
    # initial state
    print(sess.run(train_acc, {logits:[[0,1,0],[1,0,1]]}))
    print(sess.run(valid_acc, {logits:[[0,1,0],[1,0,1]]}))
    
    #0.0
    #0.0
    

    The initial states are 0.0 as expected.
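To see why the value starts at 0.0 and only changes when an update op runs, it can help to model the metric outside TensorFlow. `tf.metrics.accuracy` keeps its state in two local variables, `total` and `count`; the class below is only a pure-Python stand-in for that idea (the name `StreamingAccuracy` and its methods are hypothetical, not part of any TensorFlow API):

```python
class StreamingAccuracy:
    """Pure-Python stand-in for a streaming accuracy metric (illustrative only)."""

    def __init__(self):
        self.total = 0.0  # running number of correct predictions
        self.count = 0.0  # running number of examples seen

    def update(self, labels, predictions):
        """Plays the role of the update op: accumulate, then return the new value."""
        self.total += sum(1.0 for y, p in zip(labels, predictions) if y == p)
        self.count += len(labels)
        return self.value()

    def value(self):
        """Plays the role of the value tensor: reads the state without changing it."""
        return self.total / self.count if self.count > 0 else 0.0

    def reset(self):
        """Plays the role of re-running the variables' initializer."""
        self.total = 0.0
        self.count = 0.0


# Two independent instances, like the 'train' and 'valid' scopes.
train = StreamingAccuracy()
valid = StreamingAccuracy()
initial = (train.value(), valid.value())
for _ in range(10):
    train.update([1, 0], [1, 0])
after_training = (train.value(), valid.value())
print(initial, after_training)
```

Because each instance owns its own `total` and `count`, updating or resetting one leaves the other untouched, which is the property the rest of this answer relies on.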

    Now run the training metric's update op:

    #training loop
    for _ in range(10):
        sess.run(train_acc_op, {logits:[[0,1,0],[1,0,1]]})  
    print(sess.run(train_acc, {logits:[[0,1,0],[1,0,1]]}))
    # 1.0
    print(sess.run(valid_acc, {logits:[[0,1,0],[1,0,1]]}))
    # 0.0
    

    Only the training accuracy was updated; the validation accuracy is still 0.0. Now run the validation update op:

    for _ in range(10):
        sess.run(valid_acc_op, {logits:[[0,1,0],[0,1,0]]}) 
    print(sess.run(valid_acc, {logits:[[0,1,0],[1,0,1]]}))
    #0.5
    print(sess.run(train_acc, {logits:[[0,1,0],[1,0,1]]}))
    #1.0
    

    Here the valid accuracy got updated to a new value while the training accuracy remained unchanged.

    Let's reset only the validation metric's variables:

    stream_vars_valid = [v for v in tf.local_variables() if 'valid/' in v.name]
    sess.run(tf.variables_initializer(stream_vars_valid))
    
    print(sess.run(valid_acc, {logits:[[0,1,0],[1,0,1]]}))
    #0.0
    print(sess.run(train_acc, {logits:[[0,1,0],[1,0,1]]}))
    #1.0
    

    The valid accuracy got reset to zero while the training accuracy remained unchanged.
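The example above does not cover the second question (streaming only over the last n batches). One possible approach, sketched here under the assumption that you can fetch per-batch labels and predictions yourself (for example from `sess.run`), is to keep per-batch counts in a bounded queue outside the graph. `WindowedAccuracy` is a hypothetical helper, not a TensorFlow API:

```python
from collections import deque


class WindowedAccuracy:
    """Accuracy over the most recent `n_batches` batches only (hypothetical helper)."""

    def __init__(self, n_batches):
        # Each entry is (correct_in_batch, batch_size); the deque drops the
        # oldest entry automatically once maxlen is reached.
        self.window = deque(maxlen=n_batches)

    def update(self, labels, predictions):
        correct = sum(1 for y, p in zip(labels, predictions) if y == p)
        self.window.append((correct, len(labels)))
        return self.value()

    def value(self):
        seen = sum(size for _, size in self.window)
        if seen == 0:
            return 0.0
        return sum(correct for correct, _ in self.window) / seen


metric = WindowedAccuracy(n_batches=2)
metric.update([1, 0], [0, 1])             # 0 of 2 correct
metric.update([1, 0], [1, 1])             # 1 of 2 correct
windowed = metric.update([1, 0], [1, 0])  # 2 of 2 correct; the first batch falls out
print(windowed)
```

With a window like this, old batches stop influencing the reported value on their own, so no explicit reset is needed between validation runs.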