
TensorFlow summaries with threading


I'm trying to add summaries to my TensorFlow graphs, which run asynchronously. I've got everything working in the single-threaded case, but the summaries seem to disappear once I go multithreaded. Here's a toy example of what I'm trying to do:

import tensorflow as tf  # 1.1.0
import threading


class Worker:
    def __init__(self):
        self.x = tf.Variable([1, -2, 3], tf.float32, name='x')
        self.y = tf.Variable([-1, 2, -3], tf.float32, name='y')
        self.dot_product = tf.reduce_sum(tf.multiply(self.x, self.y))
        tf.summary.scalar("Dot_Product", self.dot_product)

    def work(self):
        for i in range(10):
            SESS.run(self.dot_product)

            # Write summary
            summary_str = SESS.run(tf.summary.merge_all())
            WRITER.add_summary(summary_str, i)
            WRITER.flush()

COORD = tf.train.Coordinator()
SESS = tf.Session()
WRITER = tf.summary.FileWriter(SUMMARY_DIR, SESS.graph)  # SUMMARY_DIR is a log directory defined elsewhere

# Single-thread case
w = Worker()
SESS.run(tf.global_variables_initializer())
print(tf.get_collection(tf.GraphKeys.SUMMARIES))
w.work()

This works fine. However, if I go multithreaded:

# Multi-thread case
workers = [Worker() for i in range(4)]
SESS.run(tf.global_variables_initializer())
print(tf.get_collection(tf.GraphKeys.SUMMARIES))

worker_threads = []
for worker in workers:
    job = lambda: worker.work()
    t = threading.Thread(target=job)
    t.start()
    worker_threads.append(t)
COORD.join(worker_threads)

Whenever tf.summary.merge_all() is called, I get an error like the following, because it can't see any summaries:

Exception in thread Thread-2:
Traceback (most recent call last):
  File "/usr/lib/python3.5/threading.py", line 914, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.5/threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "/home/anjum/PycharmProjects/junk.py", line 43, in <lambda>
    job = lambda: worker.work()
  File "/home/anjum/PycharmProjects/junk.py", line 22, in work
    summary_str = SESS.run(tf.summary.merge_all())
  File "/usr/local/lib/python3.5/dist-
packages/tensorflow/python/client/session.py", line 778, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 969, in _run
fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 408, in __init__
self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 227, in for_fetch
(fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>

If I put print(tf.get_collection(tf.GraphKeys.SUMMARIES)) inside work(), it prints an empty list, so my summaries are getting lost somewhere along the way.
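
A minimal check along these lines (check_summaries is just a throwaway helper, reusing the graph and session from above) shows the same empty list as soon as the collection is queried from a spawned thread:

def check_summaries():
    # Querying the collection from inside the spawned thread returns [],
    # while the same call on the main thread lists the Dot_Product summary.
    print(tf.get_collection(tf.GraphKeys.SUMMARIES))

t = threading.Thread(target=check_summaries)
t.start()
t.join()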

Could someone please explain how to properly use summaries with multithreading?


Solution

  • I think I've figured it out: the summaries have to be merged like this instead, with the merge op created in __init__ rather than inside work(). I'm not 100% sure why TensorFlow is so fussy about this:

    class Worker:
        def __init__(self):
            self.x = tf.Variable([1, -2, 3], tf.float32, name='x')
            self.y = tf.Variable([-1, 2, -3], tf.float32, name='y')
            self.dot_product = tf.reduce_sum(tf.multiply(self.x, self.y))
            tf.summary.scalar("Dot_Product", self.dot_product)
    
            self.summarise = tf.summary.merge_all()
    
        def work(self):
            for i in range(10):
                SESS.run(self.dot_product)
    
                # Write summary
                summary = SESS.run(self.summarise)
                WRITER.add_summary(summary, i)
                WRITER.flush()
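
    My best guess as to why this matters: in TensorFlow 1.x the default graph is a property of the current thread, so tf.summary.merge_all() called inside work() (i.e. inside a freshly spawned thread) seems to look at a default graph that contains no summaries and returns None, which is exactly the None fetch in the traceback. Building the merge op in __init__, while still on the main thread, ties it to the graph the summaries actually live in. With that change the driver loop from the question works; a sketch of it (reusing SESS, WRITER and COORD from above) is below, with target=worker.work instead of the lambda to avoid Python's late-binding pitfall:

    # Multi-thread case, with the merge op now created on the main thread
    workers = [Worker() for _ in range(4)]
    SESS.run(tf.global_variables_initializer())

    worker_threads = []
    for worker in workers:
        # Passing the bound method directly means each thread runs its own
        # worker; a plain `lambda: worker.work()` would call whichever worker
        # the loop variable happens to reference when the thread is scheduled.
        t = threading.Thread(target=worker.work)
        t.start()
        worker_threads.append(t)
    COORD.join(worker_threads)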