python tensorflow deep-learning profiler

How to calculate the FLOPs of a transformer in TensorFlow?


I know that

    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())

can calculate the FLOPs, but where can I find the graph of the transformer?
Please help me.


Solution

  • The graph should be the tf.Graph of the model that you are profiling. See the TensorFlow documentation on graphs for more information, and the TensorFlow Profiler tutorials for worked examples.
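
If the transformer is a Keras model under TensorFlow 2.x, one way to obtain such a tf.Graph is to trace the model into a concrete function and profile the traced graph. The following is a minimal sketch, not a definitive recipe. It assumes TF 2.x, where the profiler from the question lives under tf.compat.v1.profiler; the helper name count_flops and the tiny Dense model are hypothetical stand-ins for your own transformer.

```python
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


def count_flops(model, input_shape):
    """Hypothetical helper: profiler float-op count for one forward pass."""
    # Trace the model once with a fixed input signature to get a graph.
    spec = tf.TensorSpec(input_shape, tf.float32)
    concrete = tf.function(lambda x: model(x)).get_concrete_function(spec)

    # Freeze variables into constants so the weights live in the graph itself.
    frozen = convert_variables_to_constants_v2(concrete)
    graph_def = frozen.graph.as_graph_def()

    # Import into a plain tf.Graph; this is the `graph` argument from the question.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        opts["output"] = "none"  # suppress the profiler's console report
        info = tf.compat.v1.profiler.profile(graph=graph, cmd="op", options=opts)
    return info.total_float_ops


if __name__ == "__main__":
    # Stand-in model (assumption); replace with your transformer.
    model = tf.keras.Sequential(
        [tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(4)]
    )
    print("FLOPs:", count_flops(model, (1, 8)))
```

The input shape must be fully specified (including the batch and, for a transformer, the sequence length), because the profiler derives its counts from static tensor shapes. The result only covers ops for which TensorFlow registers a float-operation statistic, so treat it as an estimate.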