Tags: pytorch, linear-algebra, jit, jax, tensorflow-xla

How can I test whether a jitted JAX function creates a new tensor or a view?


I have some basic code like this:

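import jax
import jax.numpy as jnp
from jax import jit

@jit
def concat_permute(indices, in1, in2):
    # Stack both inputs along axis 0, then reorder the rows by `indices`.
    tensor = jnp.concatenate([jnp.atleast_1d(in1), jnp.atleast_1d(in2)])
    return tensor[indices]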

Here are my test tensors:

key = jax.random.PRNGKey(758493)
in1 = jax.random.uniform(key, shape=(15, 5, 3))
in2 = jax.random.uniform(key, shape=(10, 5, 3))
indices = jax.random.choice(key, 25, (25,), replace=False)
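As a quick sanity check, here is a minimal sketch (it redefines the function and test tensors from the question so it runs on its own) that calls the jitted function and compares the result with the same computation done eagerly in NumPy. The names out and reference are only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit

@jit
def concat_permute(indices, in1, in2):
    # Stack both inputs along axis 0, then reorder the rows by `indices`.
    tensor = jnp.concatenate([jnp.atleast_1d(in1), jnp.atleast_1d(in2)])
    return tensor[indices]

key = jax.random.PRNGKey(758493)
in1 = jax.random.uniform(key, shape=(15, 5, 3))
in2 = jax.random.uniform(key, shape=(10, 5, 3))
# 25 distinct values drawn from range(25): a random permutation of the rows.
indices = jax.random.choice(key, 25, (25,), replace=False)

out = concat_permute(indices, in1, in2)

# Same computation on the host: concatenate, then index with the permutation.
reference = np.concatenate([np.asarray(in1), np.asarray(in2)])[np.asarray(indices)]

print(out.shape)                                    # (25, 5, 3)
print(np.array_equal(np.asarray(out), reference))   # True
```

This only confirms the values are right; it says nothing about memory layout, which is what the question is really about.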

And here is the Jaxpr of the function:

{ lambda ; a:i32[25] b:f32[15,5,3] c:f32[10,5,3]. let
    d:f32[25,5,3] = xla_call[
      call_jaxpr={ lambda ; e:i32[25] f:f32[15,5,3] g:f32[10,5,3]. let
          h:f32[15,5,3] = xla_call[
            call_jaxpr={ lambda ; i:f32[15,5,3]. let  in (i,) }
            name=atleast_1d
          ] f
          j:f32[10,5,3] = xla_call[
            call_jaxpr={ lambda ; k:f32[10,5,3]. let  in (k,) }
            name=atleast_1d
          ] g
          l:f32[25,5,3] = concatenate[dimension=0] h j
          m:bool[25] = lt e 0
          n:i32[25] = add e 25
          o:i32[25] = select_n m e n
          p:i32[25,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(25, 1)
          ] o
          q:f32[25,5,3] = gather[
            dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,))
            fill_value=None
            indices_are_sorted=False
            mode=GatherScatterMode.PROMISE_IN_BOUNDS
            slice_sizes=(1, 5, 3)
            unique_indices=False
          ] l p
        in (q,) }
      name=concat_permute
    ] a b c
  in (d,) }

It seems to create a new tensor using my permutation array, but I'm not sure. Is there a clearer way to see whether this operation creates a new tensor or not?

I tried "jax.make_jaxpr" and looked at the results, but I'm still not sure.
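One way to look past the jaxpr is to ask XLA for the optimized HLO it actually compiles, using jax.make_jaxpr for the traced primitives and the ahead-of-time chain jit(...).lower(...).compile().as_text() for the post-optimization module. The sketch below assumes the function and test tensors from the question (redefined so it is self-contained). Neither view reports "view vs. copy" directly: JAX arrays are immutable and XLA decides buffer reuse during compilation, so the HLO only shows which operations (here, a concatenate feeding a gather) run. Depending on the JAX version, the outer call in the jaxpr may print under a different name than xla_call (for example pjit), and as_text() may return None on some backends.

```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def concat_permute(indices, in1, in2):
    tensor = jnp.concatenate([jnp.atleast_1d(in1), jnp.atleast_1d(in2)])
    return tensor[indices]

key = jax.random.PRNGKey(758493)
in1 = jax.random.uniform(key, shape=(15, 5, 3))
in2 = jax.random.uniform(key, shape=(10, 5, 3))
indices = jax.random.choice(key, 25, (25,), replace=False)

# Tracing-level view: the primitives JAX records (concatenate, gather, ...).
jaxpr_text = str(jax.make_jaxpr(concat_permute)(indices, in1, in2))

# Compiler-level view: the HLO module after XLA's optimization passes
# (fusion etc.). as_text() can return None if the backend does not provide it.
compiled = concat_permute.lower(indices, in1, in2).compile()
hlo_text = compiled.as_text()

# Print only the lines that mention the ops of interest.
for line in (hlo_text or "").splitlines():
    if "gather" in line or "concatenate" in line:
        print(line.strip())
```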


Solution

  • The short answer is no: the output of your function will not share memory with the array allocated for the intermediate tensor.

    In XLA, an array is represented by a uniformly strided buffer. When you select elements at arbitrary (here, randomly permuted) indices, the result generally cannot be expressed as a uniformly strided view of the input buffer, so it has to be written into a new buffer.
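The stride argument can be illustrated with NumPy, which, unlike JAX, exposes views directly. This is a sketch under that substitution, not a description of XLA internals: a slice whose rows sit a fixed step apart (offset + k * step) can be returned as a view of the original buffer, while indexing by an arbitrary permutation cannot, so NumPy copies. The helper is_uniform_stride is hypothetical, written only for this illustration.

```python
import numpy as np

def is_uniform_stride(idx):
    """Hypothetical helper: True if idx is an arithmetic progression
    (offset + k * step with a nonzero step), i.e. describable by one stride."""
    idx = np.asarray(idx)
    if idx.size < 2:
        return True
    steps = np.diff(idx)
    # Slicing cannot express a zero step, so require steps[0] != 0.
    return bool(np.all(steps == steps[0]) and steps[0] != 0)

a = np.arange(25 * 5 * 3, dtype=np.float32).reshape(25, 5, 3)

# Basic slicing: rows 0, 2, 4, ... are a fixed stride apart -> a view.
strided = a[::2]

# Fancy indexing with a permutation: no single stride describes it -> a copy.
perm = np.random.default_rng(0).permutation(25)
gathered = a[perm]

print(is_uniform_stride(np.arange(0, 25, 2)))  # True
print(is_uniform_stride([3, 0, 2]))            # False
print(np.shares_memory(a, strided))            # True
print(np.shares_memory(a, gathered))           # False
```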