
Reduce list of lists in JAX


I have a list holding many lists of the same structure (usually there are many more than two sub-lists; the example shows only two for simplicity). I would like to compute the sum or product over all sub-lists, so that the result has the same structure as any one of the sub-lists. So far I have tried the following using jax.tree_util.tree_reduce, but I get errors that I don't understand.

I could use some guidance on how to use tree_reduce() in a case like this.

import jax
import jax.numpy as jnp

list_1 = [
    [jnp.asarray([1]), jnp.asarray([2, 3])],
    [jnp.asarray([4]), jnp.asarray([5, 6])],
]

list_2 = [
    [jnp.asarray([7]), jnp.asarray([8, 9])],
    [jnp.asarray([10]), jnp.asarray([11, 12])],
]
    
list_of_lists = [list_1, list_2]
   
reduced = jax.tree_util.tree_reduce(lambda x, y: x + y, list_of_lists, 0, is_leaf=True)
    
# Expected
# reduced = [
#     [jnp.asarray([8]), jnp.asarray([10, 12])],
#     [jnp.asarray([14]), jnp.asarray([16, 18])],
# ]
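Two things go wrong in the attempt above. First, is_leaf expects a callable that takes a node and returns a bool (or None), not the value True. Second, even with a valid is_leaf, tree_reduce folds all leaves of a single tree into one value; it does not combine matching leaves across several trees. A minimal sketch of that collapse, reusing list_1 from the question (the name collapsed is only illustrative), with the length-1 arrays broadcasting against the length-2 ones:

```python
import jax
import jax.numpy as jnp

list_1 = [
    [jnp.asarray([1]), jnp.asarray([2, 3])],
    [jnp.asarray([4]), jnp.asarray([5, 6])],
]

# is_leaf must be a callable (or None), so it is simply omitted here.
# tree_reduce flattens list_1 to its leaves [1], [2, 3], [4], [5, 6] and
# folds them left to right: (([1] + [2, 3]) + [4]) + [5, 6].
# The shape-(1,) arrays broadcast, so everything collapses into one array.
collapsed = jax.tree_util.tree_reduce(lambda x, y: x + y, list_1)

# collapsed holds [12, 14]: the nested list structure is gone.
print(collapsed)
```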

Solution

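  • You can do this with jax.tree_util.tree_map, unpacking (splatting) the outer list so that each sub-list is passed as a separate tree, and summing the corresponding leaves: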

    reduced = jax.tree_util.tree_map(lambda *args: sum(args), *list_of_lists)
    print(reduced)
    
    [[Array([8], dtype=int32), Array([10, 12], dtype=int32)],
     [Array([14], dtype=int32), Array([16, 18], dtype=int32)]]
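For the product (or any other elementwise operation), and for long lists of sub-lists, two variations on the same idea may help. This is a sketch under the question's setup; tree_fold and tree_stack_reduce are hypothetical helper names, not part of jax.tree_util. tree_fold runs functools.reduce over the outer list and combines each pair of sub-lists leaf by leaf with tree_map, which is the pairwise fold the tree_reduce attempt was reaching for. tree_stack_reduce stacks the matching leaves with jnp.stack and reduces along the new axis with a reducer such as jnp.sum or jnp.prod.

```python
import functools

import jax
import jax.numpy as jnp

list_1 = [
    [jnp.asarray([1]), jnp.asarray([2, 3])],
    [jnp.asarray([4]), jnp.asarray([5, 6])],
]
list_2 = [
    [jnp.asarray([7]), jnp.asarray([8, 9])],
    [jnp.asarray([10]), jnp.asarray([11, 12])],
]
list_of_lists = [list_1, list_2]


def tree_fold(op, trees):
    """Hypothetical helper: combine same-structure pytrees pairwise with `op`."""
    # functools.reduce walks the outer list; tree_map applies `op` to the
    # matching leaves of the running result and the next sub-list.
    return functools.reduce(
        lambda acc, tree: jax.tree_util.tree_map(op, acc, tree), trees
    )


def tree_stack_reduce(reducer, trees):
    """Hypothetical helper: stack matching leaves on axis 0, then reduce it."""
    # Each call receives one leaf per sub-list, all with the same shape.
    return jax.tree_util.tree_map(
        lambda *leaves: reducer(jnp.stack(leaves), axis=0), *trees
    )


summed = tree_fold(jnp.add, list_of_lists)
product = tree_fold(jnp.multiply, list_of_lists)
summed_stacked = tree_stack_reduce(jnp.sum, list_of_lists)
product_stacked = tree_stack_reduce(jnp.prod, list_of_lists)

print(summed)
print(product)
```

All four results keep the structure of list_1. The stacked version issues one reduction per leaf position instead of one operation per pair of sub-lists, which may be preferable when there are many sub-lists, at the cost of materializing a stacked copy of each leaf.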