I have a list holding many lists of the same structure (usually there are many more than two sub-lists inside the list; the example shows two for the sake of simplicity). I would like to compute the elementwise sum or product over all sub-lists, so that the result has the same structure as any one of the sub-lists. So far I have tried the following using `jax.tree_util.tree_reduce`, but I get errors that I don't understand.
I need some guidance on how to use `tree_reduce()` in such a case.
```python
import jax
import jax.numpy as jnp

list_1 = [
    [jnp.asarray([1]), jnp.asarray([2, 3])],
    [jnp.asarray([4]), jnp.asarray([5, 6])],
]
list_2 = [
    [jnp.asarray([7]), jnp.asarray([8, 9])],
    [jnp.asarray([10]), jnp.asarray([11, 12])],
]
list_of_lists = [list_1, list_2]

reduced = jax.tree_util.tree_reduce(lambda x, y: x + y, list_of_lists, 0, is_leaf=True)

# Expected
# reduced = [
#     [jnp.asarray([8]), jnp.asarray([10, 12])],
#     [jnp.asarray([14]), jnp.asarray([16, 18])],
# ]
```
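For comparison, a minimal sketch of what `tree_reduce` actually computes on this data. It folds every leaf of the whole nested list into one accumulated value, rather than combining the sub-lists position by position. The `is_leaf` argument is dropped here because it expects a callable that returns a boolean, not `True` itself.

```python
import jax
import jax.numpy as jnp

list_1 = [[jnp.asarray([1]), jnp.asarray([2, 3])],
          [jnp.asarray([4]), jnp.asarray([5, 6])]]
list_2 = [[jnp.asarray([7]), jnp.asarray([8, 9])],
          [jnp.asarray([10]), jnp.asarray([11, 12])]]
list_of_lists = [list_1, list_2]

# tree_reduce flattens the nested list into its 8 array leaves and folds them
# left to right: ((0 + [1]) + [2, 3]) + [4] + ...
# Leaves of shape (1,) broadcast against leaves of shape (2,), so the result
# is a single array of shape (2,), not a nested list.
total = jax.tree_util.tree_reduce(lambda acc, leaf: acc + leaf, list_of_lists, 0)
print(total)  # a single array, the structure is gone
```

This is why `tree_reduce` cannot produce the expected nested result: it has no notion of "the same position in several trees".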
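`tree_reduce` is not the right tool here: it folds all leaves of a single tree into one value instead of combining several trees leaf by leaf (and its `is_leaf` argument expects a callable, not `True`). You can do this with `jax.tree_util.tree_map`, passing the unpacked (splatted) list so that each call receives the matching leaf from every sub-list, and summing those: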
```python
reduced = jax.tree_util.tree_map(lambda *args: sum(args), *list_of_lists)
print(reduced)
```

```
[[Array([8], dtype=int32), Array([10, 12], dtype=int32)],
 [Array([14], dtype=int32), Array([16, 18], dtype=int32)]]
```
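The question also asks for a product. The same `tree_map` pattern works with any reducer over the variadic leaves; below is a sketch using `functools.reduce` with `operator.mul`. It also shows an alternative (an assumption on my part, not part of the answer above) that stacks matching leaves along a new leading axis with `jnp.stack` and then reduces that axis, which keeps each reduction in a single array operation when there are many sub-lists.

```python
import functools
import operator

import jax
import jax.numpy as jnp

list_1 = [[jnp.asarray([1]), jnp.asarray([2, 3])],
          [jnp.asarray([4]), jnp.asarray([5, 6])]]
list_2 = [[jnp.asarray([7]), jnp.asarray([8, 9])],
          [jnp.asarray([10]), jnp.asarray([11, 12])]]
list_of_lists = [list_1, list_2]

# Elementwise product across all sub-lists: tree_map passes the matching leaf
# of every tree as a separate positional argument.
product = jax.tree_util.tree_map(
    lambda *leaves: functools.reduce(operator.mul, leaves), *list_of_lists
)

# Alternative: stack matching leaves along a new axis 0, then reduce that axis.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *list_of_lists)
summed = jax.tree_util.tree_map(lambda x: x.sum(axis=0), stacked)
multiplied = jax.tree_util.tree_map(lambda x: x.prod(axis=0), stacked)

print(product)
print(summed)
```

Both variants require every sub-list to have the same tree structure (`tree_map` raises otherwise); the stacked variant additionally needs matching leaf shapes and materializes one stacked copy of each leaf.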