I'm trying to add summaries to my TensorFlow graphs, which run asynchronously. I've got everything working in the single-threaded case, but the summaries seem to disappear once I go to multithreading. Here's a toy example of what I'm trying to do:
import tensorflow as tf  # 1.1.0
import threading


class Worker:
    def __init__(self):
        self.x = tf.Variable([1, -2, 3], dtype=tf.float32, name='x')
        self.y = tf.Variable([-1, 2, -3], dtype=tf.float32, name='y')
        self.dot_product = tf.reduce_sum(tf.multiply(self.x, self.y))
        tf.summary.scalar("Dot_Product", self.dot_product)

    def work(self):
        for i in range(10):
            SESS.run(self.dot_product)

            # Write summary
            summary_str = SESS.run(tf.summary.merge_all())
            WRITER.add_summary(summary_str, i)
            WRITER.flush()
COORD = tf.train.Coordinator()
SESS = tf.Session()
WRITER = tf.summary.FileWriter(SUMMARY_DIR, SESS.graph)
# Single Thread case
w = Worker()
SESS.run(tf.global_variables_initializer())
print(tf.get_collection(tf.GraphKeys.SUMMARIES))
w.work()
This works fine. However, if I go multithreaded:
# Multi-thread case
workers = [Worker() for i in range(4)]
SESS.run(tf.global_variables_initializer())
print(tf.get_collection(tf.GraphKeys.SUMMARIES))

worker_threads = []
for worker in workers:
    job = lambda worker=worker: worker.work()  # bind worker now to avoid late-binding on the loop variable
    t = threading.Thread(target=job)
    t.start()
    worker_threads.append(t)
COORD.join(worker_threads)
Whenever tf.summary.merge_all() is called, I get an error like this, because it can't see any summaries:
Exception in thread Thread-2:
Traceback (most recent call last):
  File "/usr/lib/python3.5/threading.py", line 914, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.5/threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "/home/anjum/PycharmProjects/junk.py", line 43, in <lambda>
    job = lambda: worker.work()
  File "/home/anjum/PycharmProjects/junk.py", line 22, in work
    summary_str = SESS.run(tf.summary.merge_all())
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 778, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 969, in _run
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 408, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 227, in for_fetch
    (fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>
If I put print(tf.get_collection(tf.GraphKeys.SUMMARIES)) inside work(), an empty list is returned, so my summaries are getting lost somewhere.
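Here's a minimal sketch (not part of my original script) that seems to reproduce the behaviour; it relies on TensorFlow's documented rule that the default graph is a property of the current thread, so a freshly spawned thread gets its own, empty default graph:

import threading
import tensorflow as tf

tf.summary.scalar("demo", tf.constant(1.0))

def inspect():
    # A newly spawned thread sees its own default graph, so the
    # SUMMARIES collection looks empty from in here.
    print(tf.get_collection(tf.GraphKeys.SUMMARIES))  # prints []

# The main thread sees the summary op that was just registered.
print(tf.get_collection(tf.GraphKeys.SUMMARIES))

t = threading.Thread(target=inspect)
t.start()
t.join()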
Could someone please explain how to properly use summaries with multithreading?
I think I've figured it out: the summaries have to be merged like this instead. I'm not 100% sure why TensorFlow is so fussy about this, but it seems the default graph is a property of the current thread, so tf.summary.merge_all() called from inside a freshly spawned thread looks at a new, empty graph and returns None; building the merge op once, while still in the main thread, sidesteps that.
class Worker:
    def __init__(self):
        self.x = tf.Variable([1, -2, 3], dtype=tf.float32, name='x')
        self.y = tf.Variable([-1, 2, -3], dtype=tf.float32, name='y')
        self.dot_product = tf.reduce_sum(tf.multiply(self.x, self.y))
        tf.summary.scalar("Dot_Product", self.dot_product)

        # Build the merge op once, at graph-construction time (main thread)
        self.summarise = tf.summary.merge_all()

    def work(self):
        for i in range(10):
            SESS.run(self.dot_product)

            # Write summary
            summary = SESS.run(self.summarise)
            WRITER.add_summary(summary, i)
            WRITER.flush()
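For completeness, here is another workaround sketch I tried (my own variation, not guaranteed to be the best approach): re-enter the main graph inside each thread with SESS.graph.as_default(), so that graph-level lookups such as tf.summary.merge_all() can see the summaries registered in the main thread. It reuses the same SESS and WRITER globals as the example above.

    def work(self):
        # Alternative: make the session's graph the default graph for this
        # thread, so merge_all() finds the SUMMARIES collection.
        with SESS.graph.as_default():
            merged = tf.summary.merge_all()  # built once per thread, not once per step
            for i in range(10):
                SESS.run(self.dot_product)
                summary_str = SESS.run(merged)
                WRITER.add_summary(summary_str, i)
                WRITER.flush()

That said, building self.summarise up front in __init__, as in the class above, avoids adding extra merge ops per thread and keeps the graph from growing, so it still seems like the cleaner fix.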