I recently implemented a two-layer GRU network in JAX and was disappointed by its performance (it was unusably slow).
Below is my minimal working example; the output was produced on Google Colab with a GPU runtime (the notebook is available in Colab).
import flax.linen as jnn
import jax
import torch
import torch.nn as tnn
import numpy as np
import jax.numpy as jnp
def keyGen(seed):
    """Yield an endless stream of JAX PRNG keys derived from ``seed``.

    A private running key is split on every iteration: one half replaces
    the running key, the other half is handed to the caller.  Pull keys
    with ``next(...)``.

    Args:
        seed: Integer seed passed to ``jax.random.PRNGKey``.

    Yields:
        A new PRNG key on every iteration; the generator never terminates.
    """
    state = jax.random.PRNGKey(seed)
    while True:
        # Keep one half as the next running state, hand the other half out.
        state, subkey = jax.random.split(state)
        yield subkey
# Shared PRNG-key stream; every next(key) call returns a fresh key.
key = keyGen(1)

# Problem sizes used by both the JAX and the torch model.
# NOTE(review): hidden_size is used further down (model construction and the
# initial GRU carries) but was never defined in this listing.  The value 200
# is an assumption -- confirm against the original notebook.
hidden_size = 200
seq_length = 1000
in_features = 6
out_features = 4
batch_size = 8
class RNN_jax(jnn.Module):
    """Two stacked GRU cells followed by a dense layer with unit-norm output.

    A single call advances the network by one time step; the caller is
    responsible for threading the two GRU carries through the sequence.

    Attributes:
        out_features: Width of the final dense layer (defaults to 4).
    """

    out_features: int = 4

    # Submodules are created inline in __call__, which flax only permits
    # inside a method decorated with ``compact``.
    @jnn.compact
    def __call__(self, x, carry_gru1, carry_gru2):
        """Run one time step.

        Args:
            x: Input for the current step; features on the last axis.
            carry_gru1: Hidden state of the first GRU cell.
            carry_gru2: Hidden state of the second GRU cell.

        Returns:
            Tuple ``(y, carry_gru1, carry_gru2)`` where ``y`` is the dense
            output scaled to unit Euclidean norm along the last axis and the
            carries are the updated hidden states.
        """
        # NOTE(review): GRUCell() without arguments presumably relies on a
        # flax version that infers the hidden size from the carry; newer
        # releases may require an explicit ``features=`` -- confirm.
        carry_gru1, x = jnn.GRUCell()(carry_gru1, x)
        carry_gru2, x = jnn.GRUCell()(carry_gru2, x)
        x = jnn.Dense(self.out_features)(x)
        # Normalise each output vector on its own (as RNN_torch does with
        # ``dim=-1``) instead of dividing by the norm of the whole batch.
        x = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
        return x, carry_gru1, carry_gru2
class RNN_torch(tnn.Module):
    """Two-layer GRU followed by a linear layer with unit-norm output.

    Processes a whole sequence in one call, starting from a fixed all-zero
    hidden state.

    Args:
        batch_size: Batch size the fixed initial hidden state is built for.
        hidden_size: Hidden width of each GRU layer.
        in_features: Number of input features per time step.
        out_features: Number of output features per time step.
    """

    def __init__(self, batch_size, hidden_size, in_features, out_features):
        # Module.__init__ must run before submodules can be assigned.
        super().__init__()
        # Two layers, matching the (2, batch_size, hidden_size) initial carry.
        self.gru = tnn.GRU(
            input_size=in_features,
            hidden_size=hidden_size,
            num_layers=2,
        )
        self.dense = tnn.Linear(hidden_size, out_features)
        self.init_carry = torch.zeros((2, batch_size, hidden_size))

    def forward(self, X):
        """Return unit-norm outputs for every time step of ``X``.

        Args:
            X: Tensor of shape (seq_length, batch_size, in_features).

        Returns:
            Tensor of shape (seq_length, batch_size, out_features) whose
            vectors along the last axis have Euclidean norm 1.
        """
        X, _ = self.gru(X, self.init_carry)
        X = self.dense(X)
        # keepdim=True lets the norm broadcast over the feature axis for any
        # out_features, instead of repeating it a hard-coded 4 times.
        return X / X.norm(dim=-1, keepdim=True)
# Instantiate both models; the torch model is sized from the settings above.
rnn_jax = RNN_jax()
rnn_torch = RNN_torch(batch_size, hidden_size, in_features, out_features)
# Random inputs and targets laid out as (seq_length, batch_size, features).
# Each next(key) call consumes a fresh PRNG key, so the order of these
# statements determines which key each array is drawn with.
Xj = jax.random.normal(next(key), (seq_length, batch_size, in_features))
Yj = jax.random.normal(next(key), (seq_length, batch_size, out_features))
# The same data as torch tensors, converted via NumPy.
Xt = torch.from_numpy(np.array(Xj))
Yt = torch.from_numpy(np.array(Yj))
# All-zero initial hidden state for each of the two GRU cells.
initial_carry_gru1 = jnp.zeros((batch_size, hidden_size))
initial_carry_gru2 = jnp.zeros((batch_size, hidden_size))
# Initialise the flax parameters from a single time step (Xj[0]).
params = rnn_jax.init(next(key), Xj[0], initial_carry_gru1, initial_carry_gru2)
def forward(params, X):
    """Run the JAX model over a whole sequence, one time step at a time.

    Args:
        params: Flax parameters as returned by ``rnn_jax.init``.
        X: Input sequence, iterated along its first axis; each item has
            shape (batch_size, in_features).

    Returns:
        The per-step outputs stacked along a new leading axis, i.e. an
        array of shape (seq_length, batch_size, out_features).
    """
    carry_gru1, carry_gru2 = initial_carry_gru1, initial_carry_gru2
    Yhat = []
    # NOTE: under jax.jit this Python loop is unrolled at trace time, so the
    # compiled program contains one copy of the model per time step.
    for x in X:  # x.shape = (batch_size, in_features)
        yhat, carry_gru1, carry_gru2 = rnn_jax.apply(params, x, carry_gru1, carry_gru2)
        Yhat.append(yhat)  # yhat.shape = (batch_size, out_features)
    # Previously nothing was returned, so the collected outputs were lost.
    return jnp.stack(Yhat, axis=0)


jitted_forward = jax.jit(forward)
# uncompiled jax version
%time forward(params, Xj)
CPU times: user 7min 17s, sys: 8.18 s, total: 7min 25s Wall time: 7min 17s
# time for compiling
%time jitted_forward(params, Xj)
CPU times: user 8min 9s, sys: 4.46 s, total: 8min 13s Wall time: 8min 12s
# compiled jax version
%timeit jitted_forward(params, Xj)
The slowest run took 204.20 times longer than the fastest. This could mean that an intermediate result is being cached. 10000 loops, best of 5: 115 µs per loop
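# torch version (NOTE: `%timeit lambda: ...` only measures creating the lambda, not running the model; use `%timeit rnn_torch(Xt)` to time the actual forward pass)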
%timeit lambda: rnn_torch(Xt)
10000000 loops, best of 5: 65.7 ns per loop
Why is my JAX implementation so slow? What am I doing wrong?
Also, why does compiling take so long? The sequence is not that long.
Thank you :)
The reason the JAX code compiles slowly is that JAX unrolls Python loops during JIT compilation. So in terms of XLA compilation, your function is actually very large: you call rnn_jax.apply() 1000 times, and compilation times tend to be roughly quadratic in the number of statements.
By contrast, your PyTorch function uses no Python loops, so under the hood it relies on vectorized operations that run much faster.
Any time you use a for loop over data in Python, a good bet is that your code will be slow: this is true whether you're using JAX, torch, numpy, pandas, etc. I'd suggest finding an approach to the problem in JAX that relies on vectorized operations rather than on slow Python looping.