I'm looking for a tool that reports the runtime of an XLA-HLO computational graph. I know there is an HLO cost model (an analytical model) that reports the FLOPs of each operator node in the computational graph. But is there any tool that reports the expected runtime, or any value related to runtime, for an XLA-HLO computational graph?
I'd like either the source code of such a tool or a sample of how to use it. Thanks :)
If you are using JAX, you can use the ahead-of-time (AOT) lowering and compilation APIs to get XLA's cost analysis of a compiled computation, which gives a sense of how resource-heavy it is. For example:
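```python
import jax
import numpy as np

def f(M, x):
    # Ten chained matrix-vector products.
    for i in range(10):
        x = M @ x
    return x

M = np.random.randn(1000, 1000)
x = np.random.randn(1000)

# Lower to HLO, compile with XLA, then query XLA's cost analysis.
print(jax.jit(f).lower(M, x).compile().cost_analysis())
```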
Output:

```
[{'bytes accessed': 40080000.0,
  'bytes accessed operand 0 {}': 40000000.0,
  'bytes accessed operand 1 {}': 40000.0,
  'bytes accessed output {}': 40000.0,
  'flops': 20000000.0,
  'optimal_seconds': 0.0,
  'utilization operand 0 {}': 10.0,
  'utilization operand 1 {}': 10.0}]
```
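The cost analysis reports work (`flops`) and memory traffic (`bytes accessed`) rather than time, and `optimal_seconds` is 0.0 in the output above. One rough way to turn these numbers into an expected runtime is a roofline estimate: take the larger of `flops / peak_flops` and `bytes / peak_bandwidth`. The sketch below does that and also times the compiled executable so you can compare the estimate against reality. The peak values `PEAK_FLOPS_PER_SEC` and `PEAK_BYTES_PER_SEC` are hypothetical placeholders, not real device specs, and the helper names are illustrative. Since the return type of `cost_analysis()` may differ between JAX versions (a list of dicts as above, or a single dict), the helper accepts both.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical hardware peaks (assumptions): replace with your device's numbers.
PEAK_FLOPS_PER_SEC = 1e12   # assumed 1 TFLOP/s
PEAK_BYTES_PER_SEC = 100e9  # assumed 100 GB/s memory bandwidth


def f(M, x):
    for _ in range(10):
        x = M @ x
    return x


def get_cost(compiled):
    """Return the cost-analysis dict, accepting list or dict return types."""
    ca = compiled.cost_analysis()
    if isinstance(ca, (list, tuple)):  # some JAX versions return a list of dicts
        ca = ca[0] if ca else {}
    return ca or {}


def roofline_seconds(cost):
    """Idealized lower bound: the slower of compute time and memory time."""
    compute_s = cost.get("flops", 0.0) / PEAK_FLOPS_PER_SEC
    memory_s = cost.get("bytes accessed", 0.0) / PEAK_BYTES_PER_SEC
    return max(compute_s, memory_s)


def measure_seconds(compiled, *args, repeats=10):
    """Median wall-clock time of the compiled executable."""
    compiled(*args).block_until_ready()  # warm-up run
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        compiled(*args).block_until_ready()  # wait for async dispatch to finish
        times.append(time.perf_counter() - start)
    return float(np.median(times))


# Use float32 arrays so lowering and calling see the same dtypes.
M = jnp.asarray(np.random.randn(1000, 1000), dtype=jnp.float32)
x = jnp.asarray(np.random.randn(1000), dtype=jnp.float32)

compiled = jax.jit(f).lower(M, x).compile()
cost = get_cost(compiled)
estimated = roofline_seconds(cost)
measured = measure_seconds(compiled, M, x)
print(f"roofline estimate: {estimated:.2e} s, measured: {measured:.2e} s")
```

The roofline number ignores launch overhead, caching, and imperfect overlap of compute and memory, so treat it as an optimistic bound; the measured time is the ground truth for your hardware.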