I am new to JAX and trying to learn to use it for running some code on a GPU. In my example I want to search for regular grids in a point cloud (for indexing X-ray diffraction data).
With `test_mats` of shape `(400_000, 3, 3)` the memory usage seems to be 15 MB, but with `(500_000, 3, 3)` I get an error saying it wants to allocate 19 GB.
I can't tell whether this is a glitch in JAX or whether I am doing something wrong. My example code and output are below. I guess the problem is that it wants to create a temporary array of shape `(N, 3, gvec.shape[1])` before doing the reduction, but I don't know how to see the memory profile of what happens inside the jitted/vmapped function.
```python
import sys
import os
import jax
import jax.random
import jax.profiler
print('jax.version.__version__',jax.version.__version__)
import scipy.spatial.transform
import numpy as np
# (3,N) integer grid spot positions
hkls = np.mgrid[-3:4, -3:4, -3:4].reshape(3,-1)
Umat = scipy.spatial.transform.Rotation.random( 10, random_state=42 ).as_matrix()
a0 = 10.13
gvec = np.swapaxes( Umat.dot(hkls)/a0, 0, 1 ).reshape(3,-1)   # shape (3, 3430)

def count_indexed_peaks_hkl( ubi, gve, tol ):
    """ See how many gve this ubi can account for """
    hkl_real = ubi.dot( gve )
    hkl_int = jax.numpy.round( hkl_real )
    drlv2 = ((hkl_real - hkl_int)**2).sum(axis=0)
    npks = jax.numpy.where( drlv2 < tol*tol, 1, 0 ).sum()
    return npks

def testsize( N ):
    print("Testing size",N)
    jfunc = jax.vmap( jax.jit(count_indexed_peaks_hkl), in_axes=(0,None,None))
    key = jax.random.PRNGKey(0)
    test_mats = jax.random.orthogonal(key, 3, (N,) )*a0
    dev_gvec = jax.device_put( gvec )
    scores = jfunc( test_mats, gvec, 0.01 )
    jax.profiler.save_device_memory_profile(f"memory_{N}.prof")
    os.system(f"~/go/bin/pprof -top {sys.executable} memory_{N}.prof")

testsize(400000)
testsize(500000)
```
Output is:

```
gpu4-03:~/Notebooks/JAXFits % python mem.py
jax.version.__version__ 0.4.16
Testing size 400000
File: python
Type: space
Showing nodes accounting for 15.26MB, 99.44% of 15.35MB total
Dropped 25 nodes (cum <= 0.08MB)
      flat  flat%   sum%        cum   cum%
   15.26MB 99.44% 99.44%    15.26MB 99.44%  __call__
         0     0% 99.44%    15.35MB   100%  [python]
         0     0% 99.44%     1.53MB 10.00%  _pjit_batcher
         0     0% 99.44%    15.30MB 99.70%  _pjit_call_impl
         0     0% 99.44%    15.30MB 99.70%  _pjit_call_impl_python
         0     0% 99.44%    15.30MB 99.70%  _python_pjit_helper
         0     0% 99.44%    15.35MB   100%  bind
         0     0% 99.44%    15.35MB   100%  bind_with_trace
         0     0% 99.44%    15.30MB 99.70%  cache_miss
         0     0% 99.44%    15.30MB 99.70%  call_impl_cache_miss
         0     0% 99.44%     1.53MB 10.00%  call_wrapped
         0     0% 99.44%    13.74MB 89.51%  deferring_binary_op
         0     0% 99.44%    15.35MB   100%  process_primitive
         0     0% 99.44%    15.30MB 99.70%  reraise_with_filtered_traceback
         0     0% 99.44%    15.35MB   100%  testsize
         0     0% 99.44%     1.53MB 10.00%  vmap_f
         0     0% 99.44%    15.31MB 99.74%  wrapper
Testing size 500000
2023-12-14 10:26:23.630474: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 19.18GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
Traceback (most recent call last):
  File "~/Notebooks/JAXFits/mem.py", line 38, in <module>
    testsize(500000)
  File "~/Notebooks/JAXFits/mem.py", line 33, in testsize
    scores = jfunc( test_mats, gvec, 0.01 )
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 20596777216 bytes.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
The vmapped function is attempting to create an intermediate array of shape `[N, 3, 3430]`. For `N=400_000` with `float32` this amounts to 15GB, and for `N=500_000` it amounts to 19GB.
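One way to see the shapes of the intermediates inside the vmapped function, without triggering the large allocation, is to trace it with a small batch using `jax.make_jaxpr` (an addition here, not part of the original answer). The sketch below reuses the question's setup; `largest_intermediate` is a hypothetical helper that only inspects the top-level equations of the jaxpr. The arithmetic at the end gives about 15.3 GiB for `N=400_000` and 19.2 GiB for `N=500_000`, consistent with the 19.18GiB in the allocator message. The 15 MB in the question's profile appears to be a snapshot of the arrays still alive after the call (`test_mats` at roughly 13.7 MiB plus `scores` at roughly 1.5 MiB), not the peak temporary used inside the compiled computation.

```python
import numpy as np
import scipy.spatial.transform
import jax
import jax.numpy as jnp

# Test data as in the question: gvec has shape (3, 3430)
hkls = np.mgrid[-3:4, -3:4, -3:4].reshape(3, -1)
Umat = scipy.spatial.transform.Rotation.random(10, random_state=42).as_matrix()
a0 = 10.13
gvec = np.swapaxes(Umat.dot(hkls) / a0, 0, 1).reshape(3, -1)


def count_indexed_peaks_hkl(ubi, gve, tol):
    """Per-matrix score from the question."""
    hkl_real = ubi.dot(gve)
    hkl_int = jnp.round(hkl_real)
    drlv2 = ((hkl_real - hkl_int) ** 2).sum(axis=0)
    return jnp.where(drlv2 < tol * tol, 1, 0).sum()


def largest_intermediate(closed_jaxpr):
    """Hypothetical helper: shape of the biggest array output by a top-level equation."""
    shapes = [v.aval.shape for eqn in closed_jaxpr.jaxpr.eqns for v in eqn.outvars]
    return max(shapes, key=lambda s: int(np.prod(s)))


# Trace the vmapped function with only N=8 matrices, so nothing large is allocated
small_mats = jax.random.orthogonal(jax.random.PRNGKey(0), 3, (8,)) * a0
vmapped = jax.vmap(count_indexed_peaks_hkl, in_axes=(0, None, None))
jaxpr = jax.make_jaxpr(vmapped)(small_mats, gvec, 0.01)
print(jaxpr)                        # every intermediate with its dtype and shape
print(largest_intermediate(jaxpr))  # holds 8 * 3 * 3430 elements

# Scale that intermediate up to the question's batch sizes (float32 = 4 bytes)
for n in (400_000, 500_000):
    print(n, f"{n * 3 * gvec.shape[1] * 4 / 2**30:.2f} GiB")
```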
Your best option in this situation is probably to split your computation into sequentially executed batches using `lax.map` or similar. Unfortunately, there's not currently any automatic way to do that kind of chunked `vmap`, but there is a relevant feature request at https://github.com/google/jax/issues/11319, and there are some useful suggestions in that thread.
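A minimal sketch of that `lax.map` batching, assuming the question's `count_indexed_peaks_hkl` and `gvec`. The names `batched_count`, `chunked_count` and `batch_size` are hypothetical, and padding the input with zero matrices is only one way (an assumption, not taken from the linked issue) to handle an `N` that is not a multiple of `batch_size`. Because `jax.lax.map` processes the chunks one after another, the large temporary should have shape `(batch_size, 3, 3430)` rather than `(N, 3, 3430)`; this is untested on a GPU.

```python
from functools import partial

import numpy as np
import scipy.spatial.transform
import jax
import jax.numpy as jnp

# Test data as in the question: gvec has shape (3, 3430)
hkls = np.mgrid[-3:4, -3:4, -3:4].reshape(3, -1)
Umat = scipy.spatial.transform.Rotation.random(10, random_state=42).as_matrix()
a0 = 10.13
gvec = np.swapaxes(Umat.dot(hkls) / a0, 0, 1).reshape(3, -1)


def count_indexed_peaks_hkl(ubi, gve, tol):
    """Per-matrix score from the question."""
    hkl_real = ubi.dot(gve)
    hkl_int = jnp.round(hkl_real)
    drlv2 = ((hkl_real - hkl_int) ** 2).sum(axis=0)
    return jnp.where(drlv2 < tol * tol, 1, 0).sum()


# Score a whole batch of matrices at once; gve and tol are shared by all of them
batched_count = jax.vmap(count_indexed_peaks_hkl, in_axes=(0, None, None))


@partial(jax.jit, static_argnames="batch_size")
def chunked_count(mats, gve, tol, batch_size=10_000):
    """Hypothetical helper: run batched_count on batch_size matrices at a time."""
    n = mats.shape[0]
    n_chunks = -(-n // batch_size)  # ceiling division
    pad = n_chunks * batch_size - n
    # Pad with zero matrices so the batch splits evenly; their scores are sliced off below
    padded = jnp.pad(mats, ((0, pad), (0, 0), (0, 0)))
    chunks = padded.reshape(n_chunks, batch_size, 3, 3)
    # lax.map loops over the leading axis one chunk at a time, so each step only
    # needs a (batch_size, 3, 3430) temporary instead of an (N, 3, 3430) one
    scores = jax.lax.map(lambda m: batched_count(m, gve, tol), chunks)
    return scores.reshape(-1)[:n]


mats = jax.random.orthogonal(jax.random.PRNGKey(0), 3, (1_000,)) * a0
print(chunked_count(mats, gvec, 0.01, batch_size=64).shape)  # (1000,)
```

For the question's case, something like `chunked_count(test_mats, gvec, 0.01, batch_size=50_000)` would keep that temporary near 1.9 GiB (`50_000 * 3 * 3430 * 4` bytes). Padding keeps every chunk the same shape, so the loop body is traced and compiled once rather than once per distinct chunk size.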