I am iterating through each attention head and applying either function f1 or f2, depending on the value of the parameter self.alpha.
I only want to evaluate one of f1 or f2, rather than evaluating both and then selecting one output based on a conditional.
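```python
import jax
import jax.numpy as jnp
from jax import nn  # assumption: nn may instead be flax.linen, which re-exports relu

def f1(x):
    print('f1')
    return x / x.shape[2]

def f2(x):
    print('f2')
    temp = nn.relu(x)
    return temp / (jnp.sum(temp, axis=-1, keepdims=True) + 1e-5)

def choose_attention(alpha, x):
    # lax.cond(pred, true_fun, false_fun, operand): f2 if alpha is nonzero, else f1.
    # Comparing with 0 gives a boolean predicate even if alpha is a float array.
    return jax.lax.cond(alpha[0, 0, 0, 0] != 0, f2, f1, operand=x)

# Runs inside a Flax module method, so self.alpha and attn_weights are defined there.
results = []
func = [f1, f2]
for i in range(self.alpha.shape[1]):  # one iteration per head
    print(i)
    alpha_i = self.alpha[:, i:i+1, :, :]
    x_i = attn_weights[:, i:i+1, :, :]
    # Index for this head: 0 -> f1, 1 -> f2
    result_i = jax.lax.switch(alpha_i[0, 0, 0, 0].astype(int), func, x_i)
    results.append(result_i)
final_result = jnp.concatenate(results, axis=1)
```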
My print statements output `0 f1 f2 1 2 3 4 5 6 7 8 9 10 11`, so it looks as though both f1 and f2 are being evaluated.
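`jax.lax.switch` does what you want: it chooses between functions based on a runtime value, and at runtime only the selected branch is executed. Your use of `print` statements is misleading you: Python `print` runs at trace time rather than at runtime, and every branch is traced even if it is never executed. That is also why f1 and f2 are printed only once, during the first iteration: JAX caches the traced branches and reuses them for the later heads, which have the same input shapes.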
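To see the difference between trace time and runtime, here is a minimal sketch (assuming a JAX version that provides `jax.debug.print`). Each branch keeps a Python `print`, which fires only while tracing, and adds `jax.debug.print`, which is staged into the compiled computation and fires only when that branch actually executes. The function name `choose` and the input shape are illustrative, not from the question.

```python
import jax
import jax.numpy as jnp

def f1(x):
    print("tracing f1")            # Python print: runs only while JAX traces f1
    jax.debug.print("running f1")  # staged print: runs when this branch executes
    return x / x.shape[2]

def f2(x):
    print("tracing f2")
    jax.debug.print("running f2")
    temp = jax.nn.relu(x)
    return temp / (jnp.sum(temp, axis=-1, keepdims=True) + 1e-5)

@jax.jit
def choose(index, x):
    # Both branches are traced when choose is compiled,
    # but only the branch selected by index runs at runtime.
    return jax.lax.switch(index, [f1, f2], x)

x = jnp.ones((1, 1, 4, 4))
out0 = choose(0, x)  # first call: traces both branches, executes only f1
out1 = choose(1, x)  # same shapes, so the compiled version is reused; executes only f2
```

The tracing messages for both functions appear only on the first call, while the `running ...` messages show which branch was actually executed on each call.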
For some background on the execution model of JAX programs, I would suggest the How to Think in JAX tutorial in the JAX documentation.
Side note: for better performance, I would also suggest avoiding Python `for` loops over array values. Instead, express your algorithm with NumPy-style explicit vectorization, or use `jax.vmap` to vectorize your code automatically.
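As a sketch of the explicit-vectorization route, assume `alpha` is 4-D with heads on axis 1 and the per-head index stored at `alpha[0, i, 0, 0]`, and that `attn_weights` also has heads on axis 1. The hypothetical `per_head_where` computes both branches for all heads at once and uses `jnp.where` to pick one per head. The hypothetical `per_head_loop` is the loop-plus-`switch` version, included for comparison. Note the trade-off: vectorizing this way evaluates both f1 and f2. The same happens if you `jax.vmap` a `lax.switch` over a batched index, because JAX then converts it into a select that runs every branch.

```python
import jax
import jax.numpy as jnp

def f1(x):
    return x / x.shape[2]

def f2(x):
    temp = jax.nn.relu(x)
    return temp / (jnp.sum(temp, axis=-1, keepdims=True) + 1e-5)

def per_head_loop(alpha, attn_weights):
    # Reference version: Python loop over heads, one lax.switch per head.
    results = []
    for i in range(alpha.shape[1]):
        index = alpha[0, i, 0, 0].astype(int)
        x_i = attn_weights[:, i:i + 1, :, :]
        results.append(jax.lax.switch(index, [f1, f2], x_i))
    return jnp.concatenate(results, axis=1)

@jax.jit
def per_head_where(alpha, attn_weights):
    # Explicit vectorization: evaluate both branches on all heads, then select.
    # lax.switch clamps out-of-range indices, so any index >= 1 selects f2.
    use_f2 = alpha[0, :, 0, 0].astype(int) >= 1
    use_f2 = use_f2[None, :, None, None]  # broadcast over (batch, head, q, k)
    return jnp.where(use_f2, f2(attn_weights), f1(attn_weights))

# Illustrative inputs: 4 heads, alternating f1 / f2.
alpha = jnp.array([0.0, 1.0, 0.0, 1.0]).reshape(1, 4, 1, 1)
attn_weights = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 3, 5))
looped = per_head_loop(alpha, attn_weights)
vectorized = per_head_where(alpha, attn_weights)
```

Both functions operate per head without mixing heads (f1 divides by a fixed size and f2 normalizes along the last axis), which is what makes it safe to apply them to the full array at once. Whether computing both branches is a net win over the loop depends on how expensive each branch is.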