Tags: tensorflow, logging, tensorflow-estimator, verbosity, tpu

Mute logging in TPU estimator


I am using BERT's run_classifier and the TensorFlow TPUEstimator. Every time I train my model or predict with the estimator, far too much logging information is printed to my screen. How can I get rid of it? The following message is printed millions of times:

INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
I0423 15:45:17.093261 140624241985408 tpu_estimator.py:540] Dequeue next (1) batch(es) of data from outfeed.

The following lines are also written to my screen millions of times, even though nothing is actually wrong and the model trains properly on the TPU:

E0423 15:44:54.258747 140624241985408 tpu.py:330] Operation of type Placeholder (module_apply_tokens/bert/encoder/layer_6/attention/output/dense/kernel) is not supported on the TPU. Execution will fail if this op is used in the graph. 
ERROR:tensorflow:Operation of type Placeholder (module_apply_tokens/bert/encoder/layer_6/attention/output/dense/bias) is not supported on the TPU. Execution will fail if this op is used in the graph. 

This is the code that produces this output:

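import tensorflow as tf  # TF 1.x: tf.contrib does not exist in TF 2.x
from bert import run_classifier

# model_fn, get_run_config, tokenizer, prediction_examples, label_list and the
# upper-case constants are defined elsewhere in my script.
estimator = tf.contrib.tpu.TPUEstimator(
  use_tpu=True,
  model_fn=model_fn,
  config=get_run_config(OUTPUT_DIR),
  train_batch_size=TRAIN_BATCH_SIZE,
  eval_batch_size=EVAL_BATCH_SIZE,
  predict_batch_size=PREDICT_BATCH_SIZE,
)

input_features = run_classifier.convert_examples_to_features(prediction_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
predict_input_fn = run_classifier.input_fn_builder(features=input_features, seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=True)
# estimator.predict returns a generator; prediction (and its logging) runs as it is iterated
predictions = estimator.predict(predict_input_fn)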

How can I stop the estimator from printing these messages?


Solution

  • You should be able to set the logging verbosity level with

    tf.logging.set_verbosity(v)
    

    as the first line of your main() method, before the estimator is created. The verbosity level v is one of the level constants in TensorFlow's tf_logging module:

    _level_names = {
      FATAL: 'FATAL',
      ERROR: 'ERROR',
      WARN: 'WARN',
      INFO: 'INFO',
      DEBUG: 'DEBUG',
    }
    

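    Only messages at or above the chosen level are printed, so v=tf.logging.FATAL prints the fewest logs. Setting v=tf.logging.ERROR would hide the INFO-level "Dequeue next" lines. It would not hide the "Operation of type Placeholder" lines, because those are logged at ERROR level (note the E prefix and the ERROR:tensorflow tag). Silencing both therefore requires tf.logging.FATAL.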
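The level threshold can be illustrated without TensorFlow installed. This is a minimal sketch that rests on an assumption. It assumes that in TF 1.13 and later (and in tf.compat.v1.logging under TF 2.x), tf.logging writes through the standard-library logger named "tensorflow". It also assumes that tf.logging.set_verbosity(v) amounts to calling setLevel(v) on that logger, with the tf.logging constants equal to the stdlib ones. The `_Collector` handler and the `emitted_messages` helper are hypothetical names used only for this demonstration.

```python
import logging

# Assumption: tf.logging is backed by this stdlib logger, and
# tf.logging.set_verbosity(v) is equivalent to tf_logger.setLevel(v).
tf_logger = logging.getLogger("tensorflow")

# The two noisy messages from the question, at the levels they were logged at.
DEQUEUE_MSG = "Dequeue next (1) batch(es) of data from outfeed."  # INFO
PLACEHOLDER_MSG = (
    "Operation of type Placeholder "
    "(module_apply_tokens/bert/encoder/layer_6/attention/output/dense/kernel) "
    "is not supported on the TPU. Execution will fail if this op is used in the graph."
)  # ERROR


class _Collector(logging.Handler):
    """Hypothetical handler that records which messages pass the level check."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


def emitted_messages(level):
    """Set the verbosity, log both messages, and return the ones that got through."""
    collector = _Collector()
    tf_logger.addHandler(collector)
    tf_logger.propagate = False  # keep the demo quiet on the console
    tf_logger.setLevel(level)  # stands in for tf.logging.set_verbosity(level)
    try:
        tf_logger.info(DEQUEUE_MSG)
        tf_logger.error(PLACEHOLDER_MSG)
    finally:
        tf_logger.removeHandler(collector)
    return collector.messages


levels = {"INFO": logging.INFO, "ERROR": logging.ERROR, "FATAL": logging.FATAL}
for name, level in levels.items():
    print(name, "->", len(emitted_messages(level)), "message(s) emitted")
```

Running it prints "INFO -> 2", "ERROR -> 1" and "FATAL -> 0" (each followed by "message(s) emitted"). These results show that only the FATAL threshold drops the ERROR-level Placeholder message. In the actual script the equivalent would be a single tf.logging.set_verbosity(tf.logging.FATAL) call, or tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL) under TF 2.x, made before the TPUEstimator is built.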