tensorboard tensorflow2.0

Tensorboard Graph: Profiler session started


I wanted to show my network graph in TensorBoard using TensorFlow 2. I followed this tutorial and wrote code along these lines:

for epoch in range(epochs):
    # Bracket the function call with
    # tf.summary.trace_on() and tf.summary.trace_export().
    tf.summary.trace_on(graph=True, profiler=True)
    # Call only one tf.function when tracing.
    z = train_step(x, y)
    with writer.as_default():
        tf.summary.trace_export(name="train_graph", step=0, profiler_outdir=logdir)

When I ran this, the message "Profiler session started." was printed several times. When I then opened TensorBoard, the Graph view reported that an error had occurred and showed nothing.


Solution

  • I found the answer here:

    Actually, you can enable graph export in v2. You'll need to call tf.summary.trace_on() before the code you want to trace the graph for (e.g. L224 if you just want the train step), and then call tf.summary.trace_off() after the code completes. Since you only need one trace of the graph, I would recommend wrapping these calls with if global_step_val == 0: so that you don't produce traces every step.

    In other words, the graph only needs to be traced once, so it makes no sense to trace it at every epoch. The fix is to guard the trace calls so that they run only on the first epoch:

    for epoch in range(epochs):
        if epoch == 0:
            tf.summary.trace_on(graph=True, profiler=True)
        z = train_step(x, y)
        if epoch == 0:
            with writer.as_default():
                tf.summary.trace_export(name="train_graph", step=0, profiler_outdir=logdir)
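The same first-step guard can also be packaged as a context manager, so the training loop keeps a single `with` line. This is only a minimal sketch: `trace_first_step`, `start_trace` and `export_trace` are hypothetical names that do not come from the answer, and the two callbacks simply record events here so the snippet runs without TensorFlow. In real code they would presumably wrap the `tf.summary.trace_on(...)` and `tf.summary.trace_export(...)` calls shown above.

```python
from contextlib import contextmanager


@contextmanager
def trace_first_step(step, start_trace, export_trace):
    """Bracket the body with start/export callbacks, but only when step == 0."""
    if step == 0:
        start_trace()
    yield
    # Reached only if the body finished without raising.
    if step == 0:
        export_trace()


# Stand-ins for the tf.summary calls (assumption: they only log what happened).
events = []

for step in range(3):
    with trace_first_step(
        step,
        start_trace=lambda: events.append("trace_on"),
        export_trace=lambda: events.append("trace_export"),
    ):
        events.append(f"train_step {step}")

print(events)
```

Only step 0 is bracketed by the two callbacks; the later steps run the body alone.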
    

    Personally, I prefer this decorator approach:

    def run_once(f):
        def wrapper(*args, **kwargs):
            if not wrapper.has_run:
                wrapper.has_run = True
                return f(*args, **kwargs)
        wrapper.has_run = False
        return wrapper
    
    @run_once
    def _start_graph_tensorflow():
        # Plain function (no self), so the bare call in the loop below works.
        tf.summary.trace_on(graph=True, profiler=True)  # https://www.tensorflow.org/tensorboard/graphs
    
    @run_once
    def _end_graph_tensorflow():
        # Uses the same writer and logdir as the earlier snippets.
        with writer.as_default():
            tf.summary.trace_export(name="graph", step=0, profiler_outdir=logdir)
    
    for epoch in range(epochs):
        _start_graph_tensorflow()
        z = train_step(x, y)
        _end_graph_tensorflow()
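To check that the decorator really limits each helper to a single call, it can be exercised without TensorFlow. The sketch below rests on a few assumptions: `start_trace` and `export_trace` are hypothetical stand-ins for the two helpers above and only append to a list, and `functools.wraps` is an optional addition (not in the original answer) that preserves the wrapped function's name.

```python
import functools


def run_once(f):
    """Return a wrapper that calls f on the first invocation only."""
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if not wrapper.has_run:
            wrapper.has_run = True
            return f(*args, **kwargs)
        return None  # every later call is a no-op

    wrapper.has_run = False
    return wrapper


calls = []  # records what actually ran, in order


@run_once
def start_trace():
    calls.append("trace_on")


@run_once
def export_trace():
    calls.append("trace_export")


for epoch in range(3):
    start_trace()
    calls.append(f"train_step {epoch}")
    export_trace()

print(calls)
```

Note that the `has_run` flag lives on the wrapper itself. If the decorator is applied to a method, the flag is shared by every instance of the class, so the method body runs once per process rather than once per object.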