python jit jax tensorflow-xla

Why does JAX treat the same list as a different data structure depending on whether a new array is appended inside a function?


I am very new to JAX. Please excuse me if this is something obvious or I am making a silly mistake. I am trying to implement a function that does the following. All these functions will be called from other JIT-ed functions, so removing JIT may not be possible.

  1. get_elements takes a 1D JAX array (call it state), looks at each element in it, and calls a function get_condition.
  2. get_condition returns a tuple depending on the element at the given position of state. The tuple may be (1,0), (0,1) or (0,0).
  3. I want to call update_state only if the tuple received from get_condition is (0,1) or (1,0). In that case update_state_vec is called and a new vector of the same length as state is appended to the list.
  4. But I couldn't make jax.lax.cond work here, so I tried to call update_state for each case, while wanting the list to remain unchanged if the condition is (0,0).
  5. In update_state_vec, no_update_state should return the same list
    that it receives, without appending anything.

Here is the entire code:

import jax
import jax.numpy as jnp
from jax import random
from jax import lax
import copy
from copy import deepcopy

import numpy as np


def get_condition(state, x, y):
   L = (jnp.sqrt(len(jnp.asarray(state)))).astype(int)
   state = jnp.reshape(state, (L,L), order="F")
   s1 = state[x, y]

   branches = [lambda : (0,1), lambda : (1,0), lambda : (0,0)]
   conditions = jnp.array([s1==2, s1==4, True])
   result = lax.switch(jnp.argmax(conditions), branches)
   return tuple(x for x in result)




def update_state_vec(state, x, y, condition, list_scattered_states):
   L = (jnp.sqrt(len(state))).astype(int)   
   def update_state_4(list_scattered_states):
       state1 = jnp.array( jnp.reshape(deepcopy(state), (L, L), order="F"))
       state1 = state1.at[x, y].set(4)
       list_scattered_states.append(jnp.ravel(state1, order="F"))
       return list_scattered_states

   def update_state_2(list_scattered_states):
       state1 = jnp.array( jnp.reshape(deepcopy(state), (L, L), order="F"))
       state1 = state1.at[x, y].set(2)
       list_scattered_states.append(jnp.ravel(state1, order="F"))
       return list_scattered_states


   def no_update_state (list_scattered_states):
       #state1 = jnp.ravel(state, order="F")
       #list_scattered_states.append(jnp.ravel(state, order="F"))
       #This doesn't work---------------------------------
       return list_scattered_states



   conditions = jnp.array([condition == (1, 0), condition == (0, 1), condition == (0, 0)])
   print(conditions)
   branches = [update_state_4, update_state_2,no_update_state]

   return(lax.switch(jnp.argmax(conditions), branches, operand=list_scattered_states))
           


def get_elements(state):

   L = (jnp.sqrt(len(state))).astype(int)
   list_scattered_states = []
   for x in range(L):
       for y in range(L):
           condition=get_condition(state, x, y)
           print(condition)
           list_scattered_states = update_state_vec(state, x, y, condition, list_scattered_states)


   return list_scattered_states

Here is an example input:

arr=jnp.asarray([2., 1., 3., 4., 1., 2., 3., 4., 4., 1., 2., 3., 4., 2., 1., 3.])
get_elements(arr)

I get the following error message:

        print(conditions)
         41 branches = [update_state_4, update_state_2,no_update_state]
    ---> 43 return(lax.switch(jnp.argmax(conditions), branches, operand=list_scattered_states))
    TypeError: branch 0 and 2 outputs must have same type structure, got PyTreeDef([*]) and PyTreeDef([]).

So the error comes from the fact that no_update_state returns something that doesn't match the return type of update_state_4 or update_state_2. I am quite clueless at this point; any help would be much appreciated.
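A minimal reproduction of the same error, with the state logic stripped out, might look like the following (append_one and keep are hypothetical names, not from the code above). One branch returns a list holding one array and the other returns the empty list it was given, so their output pytrees differ, and lax.switch rejects them when it traces the branches, regardless of which index would be taken at run time.

```python
import jax.numpy as jnp
from jax import lax


def append_one(lst):
    # Returns a new list with one extra array: output pytree is [*].
    return lst + [jnp.zeros(3)]


def keep(lst):
    # Returns the (empty) input list unchanged: output pytree is [].
    return lst


try:
    # By default switch traces every branch and compares their output
    # structures, so the concrete index 0 does not avoid the check.
    lax.switch(0, [append_one, keep], [])
    raised = False
except TypeError as err:
    raised = True
    print(err)  # describes the mismatched output structures
```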


Solution

  • The root of the issue is that under transformations and control-flow primitives like jit, vmap and lax.switch, JAX requires the shapes (and pytree structure) of outputs to be known statically, i.e. at compile time (see JAX Sharp Bits: dynamic shapes). In your case, the branches you pass to switch return lists of different lengths (one array versus none, which is what PyTreeDef([*]) versus PyTreeDef([]) means in the error), and since jnp.argmax(conditions) is not known at compile time, the compiler has no way to know what memory to allocate for the result.

    Since you're not JIT-compiling or otherwise transforming your code, the easiest way to address this would be to replace the lax.switch statement with this:

      if condition == (1, 0):
        list_scattered_states = update_state_4(list_scattered_states)
      elif condition == (0, 1):
        list_scattered_states = update_state_2(list_scattered_states)
      return list_scattered_states
    

    If you do want your function to be compatible with jit or other JAX transformations, you'll have to rewrite the logic so that the size of list_scattered_states remains constant, e.g. by padding it to the expected size from the beginning.
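One possible sketch of the padding idea follows; it is illustrative only, and the name get_elements_fixed and the returned valid mask are assumptions, not part of the original code. Each site produces at most one new state, so an array of L*L candidate rows plus a boolean mask has a size fixed by state.shape alone and can be jit-compiled. It keeps the update rule from the question (a 4 writes 4, a 2 writes 2, anything else is no update). Note that the if/elif version above only works eagerly: under jit, condition holds tracers, so condition == (1, 0) needs a concrete bool and raises a tracer-to-bool conversion (concretization) error.

```python
import jax
import jax.numpy as jnp


def get_elements_fixed(state):
    """Return (candidates, valid) with shapes fixed by state.shape.

    candidates[k] is the updated state for the k-th visited site and
    valid[k] says whether that site actually produced an update.
    (Hypothetical replacement for get_elements, for illustration only.)
    """
    n = state.shape[0]
    L = int(round(n ** 0.5))  # static: derived from the shape, not the values
    assert L * L == n, "state length must be a perfect square"

    # Visit sites in the original loop order (x outer, y inner). With
    # Fortran ordering, site (x, y) lives at flat index x + L * y.
    xs, ys = jnp.meshgrid(jnp.arange(L), jnp.arange(L), indexing="ij")
    flat_idx = (xs + L * ys).ravel()

    def candidate(i):
        s = state[i]
        # Same rule as the question: 4 -> condition (1, 0) -> write 4,
        # 2 -> condition (0, 1) -> write 2, anything else -> no update.
        is_update = (s == 2) | (s == 4)
        # For non-updating sites this row is meaningless; valid masks it out.
        new_value = jnp.where(s == 4, 4.0, 2.0)
        return state.at[i].set(new_value), is_update

    return jax.vmap(candidate)(flat_idx)


# Usage: the jitted call returns fixed-size outputs; filter eagerly.
arr = jnp.asarray([2., 1., 3., 4., 1., 2., 3., 4., 4., 1., 2., 3., 4., 2., 1., 3.])
candidates, valid = jax.jit(get_elements_fixed)(arr)
list_scattered_states = list(candidates[valid])  # boolean indexing: eager only
```

Inside other jitted code, the idea is to keep carrying valid alongside candidates (for example, masking later results with jnp.where) rather than dropping rows, because boolean indexing like candidates[valid] produces a data-dependent shape and only works outside jit.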