Why does this Euclidean distance calculation method explode RAM usage?

I'm studying the KNN algorithm for image classification using some material from a 2017 Stanford course. We're given a dataset consisting of many images, which are then represented as 2D NumPy arrays, and we're supposed to write functions that calculate distances between those images. More specifically, given a 2D array of test images and a 2D array of training images, I'm asked to write an L_2 distance function that takes those two sets as inputs and returns a distance matrix, where every row i represents a test image and every column j represents a training image.

The exercise also asked me to do it without any loops and without using the np.abs function. So I gave it a try:

def compute_distances_no_loops(self, X):
    """
    Compute the distance between each test point in X and each training point
    in self.X_train using no explicit loops.

    Input / Output: Same as compute_distances_two_loops
    """
    num_test = X.shape[0]
    num_train = self.X_train.shape[0]
    dists = np.zeros((num_test, num_train))
    all_test_subs_sq = (X[:, np.newaxis] - self.X_train)**2
    dists = np.sqrt(np.sum(all_test_subs_sq, axis=2))
    return dists

Apparently that makes Google's Colab environment crash within 6 seconds after it tries to allocate about 60 GB of RAM. I should clarify that the training set X_train has shape (5000, 3072) and the test set X has shape (500, 3072). I'm not sure what is so RAM-intensive here, but then again I'm not great at working out space complexity.

I googled a bit and found a solution that works without the need for a NASA computer. It expands the squared difference into a sum of squares minus a cross term:

    dists = (np.reshape(np.sum(X**2, axis=1), [num_test, 1])
             + np.sum(self.X_train**2, axis=1)
             - 2 * np.matmul(X, self.X_train.T))
    dists = np.sqrt(dists)
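
For reference, the identity this snippet appears to rely on is the standard expansion of the squared Euclidean distance between a test row and a training row. The symbols x_i, y_j and d_ij are labels introduced here for illustration, not names from the course code:

```latex
d_{ij}^2 = \lVert x_i - y_j \rVert^2
         = \lVert x_i \rVert^2 + \lVert y_j \rVert^2 - 2\, x_i \cdot y_j
```

The three terms correspond to np.sum(X**2, axis=1), np.sum(self.X_train**2, axis=1) and np.matmul(X, self.X_train.T), in that order.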

I'm also not really sure why this solution doesn't explode like mine did. I'd really appreciate any insight here. Thank you very much for reading.


  • In the compute_distances_no_loops() function, X[:, np.newaxis] has the shape (500, 1, 3072) and is broadcast against self.X_train of shape (5000, 3072), so the intermediate array all_test_subs_sq has the shape (500, 5000, 3072). It therefore consists of 500 * 5000 * 3072 = 7,680,000,000 elements. Assuming that the dtype of X and X_train is float64, each element takes 8 bytes, so the total size of the array is 61,440,000,000 bytes, i.e. about 60 GB.

    The other solution you included avoids this problem because it never creates such a large intermediate array. The shape of np.reshape(np.sum(X**2, axis=1), [num_test,1]) is (500, 1) and the shape of np.sum(self.X_train**2, axis=1) is (5000,). When you add them, broadcasting gives an array of shape (500, 5000). np.matmul(X, self.X_train.T) also produces an array of that shape. An array of shape (500, 5000) holds 2,500,000 elements, about 20 MB as float64, and the temporaries X**2 and self.X_train**2 are only as large as the inputs themselves.
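
    As a rough check of the numbers above, the sketch below computes the broadcast shape and byte counts without allocating the large array, then compares both approaches on small random data. It assumes float64 inputs and NumPy 1.20 or newer (for np.broadcast_shapes). The small arrays and the names big_shape, big_bytes, small_bytes, dists_broadcast and dists_expanded are made up for illustration:

```python
import math

import numpy as np

# Shapes from the question: 500 test images, 5000 training images, 3072 features.
num_test, num_train, dim = 500, 5000, 3072
itemsize = np.dtype(np.float64).itemsize  # 8 bytes, assuming float64 inputs

# Shape of X[:, np.newaxis] - X_train, computed without allocating any data.
big_shape = np.broadcast_shapes((num_test, 1, dim), (num_train, dim))
big_bytes = math.prod(big_shape) * itemsize

# Each term of the expanded formula broadcasts to (num_test, num_train).
small_bytes = num_test * num_train * itemsize

print(big_shape, big_bytes / 1e9, "GB")
print((num_test, num_train), small_bytes / 1e6, "MB")

# Sanity check on small random data (hypothetical sizes): both approaches agree.
rng = np.random.default_rng(0)
X = rng.standard_normal((5, 8))
X_train = rng.standard_normal((7, 8))

# Broadcasting approach: builds a (5, 7, 8) intermediate.
dists_broadcast = np.sqrt(np.sum((X[:, np.newaxis] - X_train) ** 2, axis=2))

# Expanded approach: only (5, 7)-shaped results plus input-sized temporaries.
sq = (np.sum(X**2, axis=1)[:, np.newaxis]
      + np.sum(X_train**2, axis=1)
      - 2 * X @ X_train.T)
# Rounding can leave tiny negative values where the true distance is near zero,
# so clip at zero before taking the square root.
dists_expanded = np.sqrt(np.maximum(sq, 0.0))

print(np.allclose(dists_broadcast, dists_expanded))
```

    The clipping step is not part of the solution quoted in the question. It is a defensive addition, since the expanded form can produce slightly negative values through floating-point cancellation.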