My Task:
I'm trying to calculate the pairwise distance between every two samples in two big tensors (for k-Nearest-Neighbours). That is, given a tensor `test` with shape `(b1, c, h, w)` and a tensor `train` with shape `(b2, c, h, w)`, I need `||test[i] - train[j]||` for every `i`, `j` (where both `test[i]` and `train[j]` have shape `(c, h, w)`, as those are samples in the batch).
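If everything fit in memory, this would be a single `torch.cdist` call on the flattened samples; a minimal sketch of the result I'm after (toy sizes for illustration):

```python
import torch

b1, b2, c, h, w = 8, 16, 3, 32, 32       # toy sizes for illustration
test = torch.randn(b1, c, h, w)
train = torch.randn(b2, c, h, w)

# flatten each sample to a vector of length c*h*w, then compute all pairwise L2 distances
distances = torch.cdist(test.flatten(1), train.flatten(1))
print(distances.shape)                   # torch.Size([8, 16]), i.e. (b1, b2)
```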
The Problem
Both `train` and `test` are very big, so I can't fit them into RAM.
My current solution
For a start, I did not construct these tensors in one go. As I build them, I split the data tensor and save the chunks separately to disk, so I end up with files `{Test\test_1, ..., Test\test_n}` and `{Train\train_1, ..., Train\train_m}`.
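The saving step looks roughly like this (a sketch; `chunk_iter` and the exact file layout are stand-ins for how I actually produce the data):

```python
import torch

# hypothetical helper: save each chunk as it is produced,
# so the full tensor never has to be held in RAM at once
def save_chunks(chunk_iter, prefix):
    for idx, chunk in enumerate(chunk_iter):   # each chunk: (chunk_size, c, h, w)
        torch.save(chunk, f'{prefix}_{idx}')
```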
Then, in a nested for loop, I load every `Test\test_i` and `Train\train_j`, calculate the current distances, and save them. This semi-pseudo-code should explain:
```python
import torch

# note: the backslash must be escaped; f'Test\test_{i}' would embed a tab (\t)
test_files = [f'Test\\test_{i}' for i in range(n)]
train_files = [f'Train\\train_{j}' for j in range(m)]

# flatten each sample to a vector, then compute pairwise L2 distances
dist = lambda t1, t2: torch.cdist(t1.flatten(1), t2.flatten(1))

all_distances = []
for test_file in test_files:
    test_i = torch.load(test_file)             # shape (b1_chunk, c, h, w)
    dist_of_i_from_all_j = torch.empty(test_i.shape[0], 0)
    for train_file in train_files:
        train_j = torch.load(train_file)       # shape (b2_chunk, c, h, w)
        # append this train chunk's distances as new columns
        dist_of_i_from_all_j = torch.cat((dist_of_i_from_all_j, dist(test_i, train_j)), dim=1)
    all_distances.append(dist_of_i_from_all_j)
# and now I can take the k-smallest from all_distances
```
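The k-smallest step at the end can then be done per test chunk with `torch.topk(..., largest=False)`, continuing the loop above (`k` is a placeholder):

```python
k = 10  # placeholder value
for dists in all_distances:        # each has shape (b1_chunk, b2_total)
    vals, idxs = torch.topk(dists, k, dim=1, largest=False)
    # idxs[r] holds the indices of the k nearest train samples for test row r
```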
What I thought might work
I came across the FAISS repository, in which they explain that this process can be sped up (maybe?) using their solutions, though I'm not quite sure how. Regardless, any approach would help!
Did you check the FAISS documentation?
If what you need is the L2 norm (`torch.cdist` uses `p=2` as its default), then it is quite straightforward. The code below is an adaptation of the FAISS docs to your example: the train vectors go into the index, since those are what you search for neighbours in, and the test samples are the queries.
```python
import faiss
import numpy as np

d = 64                          # dimension of each flattened sample
nb = 100000                     # database size (number of train samples)
nq = 10000                      # number of queries (test samples)
np.random.seed(1234)            # make reproducible

x_train = np.random.random((nb, d)).astype('float32')
x_train[:, 0] += np.arange(nb) / 1000.
x_test = np.random.random((nq, d)).astype('float32')
x_test[:, 0] += np.arange(nq) / 1000.

index = faiss.IndexFlatL2(d)    # build the index
print(index.is_trained)
index.add(x_train)              # add the train vectors to the index
print(index.ntotal)

k = 100                         # take the 100 closest neighbors
D, I = index.search(x_test, k)  # actual search; D holds squared L2 distances
print(I[:5])                    # neighbors of the first 5 queries
print(I[-5:])                   # neighbors of the last 5 queries
```
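Since your data is already saved in chunks, note that `index.add` can be called repeatedly and `index.search` works per batch of queries, so you can build the index and run the search chunk by chunk. A sketch, assuming `test_files`/`train_files` are the lists from your question and each chunk is flattened to `(n, c*h*w)`; be aware that `IndexFlatL2` still keeps all added vectors in RAM:

```python
import faiss
import numpy as np
import torch

d = c * h * w                       # flattened sample dimension (c, h, w as above)
index = faiss.IndexFlatL2(d)

# add the train chunks one at a time
for train_file in train_files:
    chunk = torch.load(train_file).flatten(1)                # (b2_chunk, d)
    index.add(np.ascontiguousarray(chunk.numpy(), dtype='float32'))

# search with the test chunks one at a time
k = 100
for test_file in test_files:
    queries = torch.load(test_file).flatten(1)               # (b1_chunk, d)
    D, I = index.search(np.ascontiguousarray(queries.numpy(), dtype='float32'), k)
    # D, I have shape (b1_chunk, k); I indexes vectors in the order they were added
```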