I am working through a tutorial on matrix multiplication in JAX with the data sharded in different ways across multiple GPUs. I found that not only does the computation time differ depending on how the data is sharded, the results are also slightly different.
Here are my observations: the three methods below give slightly different values for the product, and they also take different amounts of time (Method 0 is the slowest, and Method 1 is faster than Method 2).
Can anyone help me understand these two observations? One additional question: if the way of sharding matters this much, do mainstream machine learning frameworks handle it automatically, so that different shardings don't end up producing different models?
Method 0: perform the matrix multiplication on a single GPU device (i.e., no sharding at all).
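For context, here is a sketch of the setup the later snippets rely on (the 8-device jax.sharding.PositionalSharding is inferred from the reshape(4, 2) calls and may differ from the actual tutorial setup), together with the Method 0 baseline:
import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding

# 8 devices; reshape(4, 2) in the snippets below arranges them as a 4x2 grid.
sharding = PositionalSharding(jax.devices())

# Method 0: everything lives on a single device, no sharding involved.
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
x_single = jax.device_put(x, jax.devices()[0])
w = jnp.dot(x_single, x_single)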
Method 1:
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))  # rows split into 4 blocks, each block replicated on 2 devices
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))  # columns split into 2 blocks, each block replicated on 4 devices
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)
w = jnp.dot(y, z)
Method 2:
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
y = jax.device_put(x, sharding.reshape(4, 2))  # fully sharded over the 4x2 device grid: each device holds one 2048x4096 tile
z = jax.device_put(x, sharding.reshape(4, 2))  # same 4x2 sharding for the right-hand side, no replication
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)
w = jnp.dot(y, z)
Regarding your observation of differing results: this is to be expected with floating point operations. Each floating point operation can introduce a small rounding error, and when you express the "same" computation in different ways, those errors accumulate differently.
Here's an example of this using NumPy:
import numpy as np
x = np.random.rand(10000).astype('float32')
x_reversed = x[::-1]
# Mathematically these two dot products are identical, but the float32
# additions happen in a different order, so the rounding errors differ.
np.dot(x, x) == np.dot(x_reversed, x_reversed)
# False
If we were dealing with real numbers, we'd expect these two to be identical. But because we're representing our computation with floating point values, the two approaches return slightly different results. This is similar to the situation in your question: different sharding layouts lead to different ordering of the dot product accumulations, which leads to slightly different results.
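To make the connection to sharding more concrete, here is a small sketch (plain NumPy, with an arbitrary block size of 2500) that computes the same dot product from four partial sums, the way a sharded matmul combines per-device partial results:
import numpy as np
x = np.random.rand(10000).astype('float32')
full = np.dot(x, x)
# Same dot product, but accumulated from four partial results,
# roughly what happens when the contraction dimension is sharded.
partials = [np.dot(x[i:i + 2500], x[i:i + 2500]) for i in range(0, 10000, 2500)]
blocked = np.sum(np.array(partials))  # float32 sum of the four partial results
print(full == blocked)  # usually False: same math, different accumulation order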
Regarding your observation about computation speed: the results seem reasonable. Method 0 is the slowest because it only uses a single device. Method 1 is faster than Method 2 because the pre-replication means each device already holds the operand blocks it needs to compute its piece of the output, whereas in Method 2 both operands are fully sharded and the devices have to exchange shards during the computation itself.
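If you want to verify the timing yourself, a rough sketch like the one below is one way to do it (run it once per sharding layout); block_until_ready() matters because JAX dispatches work asynchronously, and the warm-up call keeps compilation out of the measurement:
import time
import jax
import jax.numpy as jnp

def time_matmul(lhs, rhs, n_iters=10):
    f = jax.jit(jnp.dot)
    f(lhs, rhs).block_until_ready()  # warm-up: compile before timing
    start = time.perf_counter()
    for _ in range(n_iters):
        f(lhs, rhs).block_until_ready()
    return (time.perf_counter() - start) / n_iters

# e.g. after setting up Method 1's y and z:
# print(time_matmul(y, z))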