tensorflow2.0, jax, xla

Looking for a tool to predict the runtime of an XLA HLO computational graph


I'm looking for a tool that prints the expected runtime of a given XLA HLO computational graph. I know there is an HLO cost model (an analytical model) that reports the FLOPs of each operator node in the graph. But is there any tool that prints the expected runtime, or some other value related to runtime, for an XLA HLO computational graph?

I need either its source code or an example of how to use such a tool. Thanks :)


Solution

  • If you are using JAX, you can use the ahead-of-time lowering and compilation APIs to get a sense of how resource-heavy a computation is. For example:

    import jax
    import numpy as np
    
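    def f(M, x):
      # Ten chained matrix-vector products; the Python loop is unrolled at trace time
      for i in range(10):
        x = M @ x
      return x
    
    # float64 inputs are converted to float32 unless jax_enable_x64 is set,
    # so M occupies 1000 * 1000 * 4 bytes = 4 MB
    M = np.random.randn(1000, 1000)
    x = np.random.randn(1000)
    
    # Lower and compile f for these argument shapes, then print XLA's static cost estimate
    print(jax.jit(f).lower(M, x).compile().cost_analysis())
    
    # Output (the exact keys and format vary with the JAX version and backend):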
    [{'bytes accessed': 40080000.0,
      'bytes accessed operand 0 {}': 40000000.0,
      'bytes accessed operand 1 {}': 40000.0,
      'bytes accessed output {}': 40000.0,
      'flops': 20000000.0,
      'optimal_seconds': 0.0,
      'utilization operand 0 {}': 10.0,
      'utilization operand 1 {}': 10.0}]
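The cost analysis reports FLOPs and bytes accessed, but in the output above 'optimal_seconds' is 0.0, so it gives no usable time estimate in this case. One common way to turn these two numbers into a rough runtime figure is a roofline-style estimate: runtime is at least the larger of (FLOPs / peak compute rate) and (bytes accessed / peak memory bandwidth). This roofline approach is not something XLA's cost analysis does for you; it is a swapped-in technique, shown here as a minimal sketch. The function name `roofline_estimate` and the peak numbers (1 TFLOP/s, 100 GB/s) are hypothetical placeholders; substitute the specs of your own device. The cost values are copied from the output above.

```python
def roofline_estimate(cost, peak_flops_per_s, peak_bytes_per_s):
    """Hypothetical helper: rough lower bound on runtime from a cost_analysis() result.

    Returns (seconds, bound) where bound is "compute" or "memory".
    """
    # Older JAX versions return a list with one dict per computation; newer ones a dict
    if isinstance(cost, (list, tuple)):
        cost = cost[0]
    compute_s = cost.get("flops", 0.0) / peak_flops_per_s
    memory_s = cost.get("bytes accessed", 0.0) / peak_bytes_per_s
    bound = "memory" if memory_s >= compute_s else "compute"
    return max(compute_s, memory_s), bound


# Values copied from the cost_analysis() output above
cost = [{"bytes accessed": 40080000.0, "flops": 20000000.0}]

# Hypothetical hardware peaks: replace with your device's actual specs
PEAK_FLOPS = 1e12      # 1 TFLOP/s
PEAK_BW = 100e9        # 100 GB/s

estimate, bound = roofline_estimate(cost, PEAK_FLOPS, PEAK_BW)
print(f"estimated lower bound: {estimate * 1e6:.1f} us ({bound}-bound)")
```

With these assumed peaks the matrix-vector chain comes out memory-bound (about 400 us), which matches intuition: each step reads the whole 4 MB matrix but does only 2 FLOPs per element. Treat the result as a lower bound only, since it ignores launch overhead, caching, and how well XLA fuses the operations. For an actual measurement, call the compiled executable (the object returned by `.compile()` is callable with the same arguments) and time it after calling `.block_until_ready()` on the result, discarding the first run.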