How to understand and debug memory usage with JAX?


I am new to JAX and trying to learn to use it for running some code on a GPU. In my example I want to search for regular grids in a point cloud (for indexing X-ray diffraction data).

With test_mats of shape [400_000, 3, 3] the reported memory usage is about 15 MB. But with test_mats of shape [500_000, 3, 3] I get an error saying it wants to allocate 19 GB.

I can't tell whether this is a glitch in JAX, or because I am doing something wrong. My example code and output are below. I guess the problem is that it wants to create a temporary array of (N, 3, gvec.shape[1]) before doing the reduction, but I don't know how to see the memory profile for what happens inside the jitted/vmapped function.

import sys
import os
import jax
import jax.random
import jax.profiler

print('jax.version.__version__',jax.version.__version__)

import scipy.spatial.transform
import numpy as np

# (3,N) integer grid spot positions
hkls = np.mgrid[-3:4, -3:4, -3:4].reshape(3,-1)

Umat = scipy.spatial.transform.Rotation.random( 10, random_state=42 ).as_matrix()
a0 = 10.13
gvec = np.swapaxes( Umat.dot(hkls)/a0, 0, 1 ).reshape(3,-1)

def count_indexed_peaks_hkl( ubi, gve, tol ):
    """ See how many gve this ubi can account for """
    hkl_real = ubi.dot( gve )
    hkl_int = jax.numpy.round( hkl_real )
    drlv2 = ((hkl_real - hkl_int)**2).sum(axis=0)
    npks = jax.numpy.where( drlv2 < tol*tol, 1, 0 ).sum()
    return npks

def testsize( N ):
    print("Testing size",N)
    jfunc = jax.vmap( jax.jit(count_indexed_peaks_hkl), in_axes=(0,None,None))
    key = jax.random.PRNGKey(0)
    test_mats = jax.random.orthogonal(key, 3, (N,) )*a0
    dev_gvec = jax.device_put( gvec )
    scores = jfunc( test_mats, gvec, 0.01 )
    jax.profiler.save_device_memory_profile(f"memory_{N}.prof")
    os.system(f"~/go/bin/pprof -top {sys.executable} memory_{N}.prof")

testsize(400000)
testsize(500000)

Output is:

gpu4-03:~/Notebooks/JAXFits % python mem.py 
jax.version.__version__ 0.4.16
Testing size 400000
File: python
Type: space
Showing nodes accounting for 15.26MB, 99.44% of 15.35MB total
Dropped 25 nodes (cum <= 0.08MB)
      flat  flat%   sum%        cum   cum%
   15.26MB 99.44% 99.44%    15.26MB 99.44%  __call__
         0     0% 99.44%    15.35MB   100%  [python]
         0     0% 99.44%     1.53MB 10.00%  _pjit_batcher
         0     0% 99.44%    15.30MB 99.70%  _pjit_call_impl
         0     0% 99.44%    15.30MB 99.70%  _pjit_call_impl_python
         0     0% 99.44%    15.30MB 99.70%  _python_pjit_helper
         0     0% 99.44%    15.35MB   100%  bind
         0     0% 99.44%    15.35MB   100%  bind_with_trace
         0     0% 99.44%    15.30MB 99.70%  cache_miss
         0     0% 99.44%    15.30MB 99.70%  call_impl_cache_miss
         0     0% 99.44%     1.53MB 10.00%  call_wrapped
         0     0% 99.44%    13.74MB 89.51%  deferring_binary_op
         0     0% 99.44%    15.35MB   100%  process_primitive
         0     0% 99.44%    15.30MB 99.70%  reraise_with_filtered_traceback
         0     0% 99.44%    15.35MB   100%  testsize
         0     0% 99.44%     1.53MB 10.00%  vmap_f
         0     0% 99.44%    15.31MB 99.74%  wrapper
Testing size 500000
2023-12-14 10:26:23.630474: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator
(GPU_0_bfc) ran out of memory trying to allocate 19.18GiB with freed_by_count=0. The caller
indicates that this is not a failure, but this may mean that there could be performance 
gains if more memory were available.
Traceback (most recent call last):
  File "~/Notebooks/JAXFits/mem.py", line 38, in <module>
    testsize(500000)
  File "~/Notebooks/JAXFits/mem.py", line 33, in testsize
    scores = jfunc( test_mats, gvec, 0.01 )
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to 
allocate 20596777216 bytes.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following 
exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Solution

  • The vmapped function is attempting to create an intermediate array of shape [N, 3, 3430], where 3430 is gvec.shape[1]. With float32 (4 bytes per element), this amounts to about 15GB for N=400_000 and about 19GB for N=500_000.

    Your best option in this situation is probably to split your computation into sequentially executed batches using lax.map or similar. Unfortunately there is currently no automatic way to do that kind of chunked vmap, but there is a relevant feature request at https://github.com/google/jax/issues/11319, and there are some useful suggestions in that thread.
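One way the batching suggestion above could look is sketched below. This is a minimal illustration, not a tuned implementation: `chunked_scores` and `chunk_size` are hypothetical names that are not part of JAX or of the original code, and the small input arrays are made up only to keep the example quick to run. The idea is to reshape the N matrices into chunks, vmap over one chunk at a time, and let `jax.lax.map` run the chunks one after another, so the temporary has shape (chunk_size, 3, M) rather than (N, 3, M).

```python
import jax
import jax.numpy as jnp
import numpy as np


def count_indexed_peaks_hkl(ubi, gve, tol):
    """Count how many g-vectors in gve (3, M) the matrix ubi (3, 3) can index."""
    hkl_real = ubi @ gve
    hkl_int = jnp.round(hkl_real)
    drlv2 = ((hkl_real - hkl_int) ** 2).sum(axis=0)
    return jnp.where(drlv2 < tol * tol, 1, 0).sum()


def chunked_scores(test_mats, gve, tol, chunk_size):
    """Score all N matrices, processing chunk_size of them per step.

    chunk_size is a hypothetical tuning knob (not a JAX parameter).
    """
    n = test_mats.shape[0]
    n_chunks = -(-n // chunk_size)  # ceiling division
    pad = n_chunks * chunk_size - n
    # Pad with zero matrices so the chunks are equal-sized; the scores of
    # the padding entries are sliced off again before returning.
    padded = jnp.pad(test_mats, ((0, pad), (0, 0), (0, 0)))
    chunks = padded.reshape(n_chunks, chunk_size, 3, 3)
    # vmap handles one chunk in parallel ...
    score_chunk = jax.vmap(count_indexed_peaks_hkl, in_axes=(0, None, None))
    # ... and lax.map applies it to the chunks sequentially.
    scores = jax.lax.map(lambda c: score_chunk(c, gve, tol), chunks)
    return scores.reshape(-1)[:n]


# Small made-up inputs, only to show the call pattern.
gve_demo = np.mgrid[-1:2, -1:2, -1:2].reshape(3, -1).astype(np.float32)
mats_demo = jax.random.orthogonal(jax.random.PRNGKey(0), 3, (10,))
jitted = jax.jit(chunked_scores, static_argnames="chunk_size")
demo_scores = jitted(mats_demo, gve_demo, 0.01, chunk_size=4)
print(demo_scores.shape)
```

Assuming the per-chunk temporary is fully materialized as float32, its size is roughly chunk_size * 3 * M * 4 bytes, so with M = 3430 a chunk_size of 10_000 would need on the order of 0.4 GB. Larger chunks mean fewer sequential steps but more peak memory, so the value is worth tuning for your GPU. Newer JAX releases than the 0.4.16 used in the question may also accept a batch_size argument to lax.map that does similar chunking for you; check the documentation for your installed version.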