Tags: python, torch, knn, nearest-neighbor, faiss

How to perform operations on very big torch tensors without splitting them


My Task:

I'm trying to calculate the pairwise distance between every two samples in two big tensors (for k-Nearest-Neighbours). That is, given a tensor test with shape (b1,c,h,w) and a tensor train with shape (b2,c,h,w), I need || test[i] - train[j] || for every i, j (where both test[i] and train[j] have shape (c,h,w), as those are samples in the batches).
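For concreteness, this is the computation on toy-sized tensors that do fit in memory (the sizes here are made up):

import torch

b1, b2, c, h, w = 4, 6, 3, 8, 8      # tiny example sizes
test = torch.randn(b1, c, h, w)
train = torch.randn(b2, c, h, w)

# flatten each sample to a vector, then take all pairwise L2 distances:
# distances[i, j] == || test[i] - train[j] ||
distances = torch.cdist(test.flatten(1), train.flatten(1))  # shape (b1, b2)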

The Problem

Both train and test are very big, so I can't fit them into RAM.

My current solution

To start, I did not construct these tensors in one go. As I build them, I split the data tensor into chunks and save each chunk to disk separately, so I end up with files {Test\test_1, ..., Test\test_n} and {Train\train_1, ..., Train\train_m}. Then, in a nested for loop, I load every Test\test_i and Train\train_j, calculate the current distances, and save them.

This semi-pseudo-code might explain it:

test_files = [rf'Test\test_{i}' for i in range(n)]      # raw strings, since '\t' is a tab escape
train_files = [rf'Train\train_{j}' for j in range(m)]
# pairwise L2 distances between flattened samples -> shape (chunk_b1, chunk_b2)
dist = lambda t1, t2: torch.cdist(t1.flatten(1), t2.flatten(1))
all_distances = []
for test_path in test_files:
    test_i = torch.load(test_path)        # shape (chunk_b1, c, h, w)
    dists_of_i_from_all_j = []
    for train_path in train_files:
        train_j = torch.load(train_path)  # shape (chunk_b2, c, h, w)
        dists_of_i_from_all_j.append(dist(test_i, train_j))
    # concatenate along the train dimension -> shape (chunk_b1, b2)
    all_distances.append(torch.cat(dists_of_i_from_all_j, dim=1))
# and now I can take the k-smallest from all_distances
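(To save memory, I could also keep only a running top-k per test sample instead of all distances. A sketch on top of the code above, reusing dist, test_files and train_files; tracking the neighbors' global indices is omitted for brevity:)

k = 5  # number of neighbors to keep (example value)
for test_path in test_files:
    test_i = torch.load(test_path)                  # shape (chunk_b1, c, h, w)
    best_d = None                                   # running k smallest distances per row
    for train_path in train_files:
        train_j = torch.load(train_path)            # shape (chunk_b2, c, h, w)
        d = dist(test_i, train_j)                   # shape (chunk_b1, chunk_b2)
        merged = d if best_d is None else torch.cat((best_d, d), dim=1)
        # keep only the k smallest distances seen so far for every test sample
        best_d, _ = torch.topk(merged, min(k, merged.shape[1]), dim=1, largest=False)
    # best_d: distances to the k nearest train samples for this test chunk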

What I thought might work

I came across the FAISS repository, where they explain that this process can be sped up (maybe?) using their solutions, though I'm not quite sure how. Regardless, any approach would help!


Solution

  • Did you check the FAISS documentation?

    If what you need is the L2 norm (torch.cdist uses p=2 as the default) then it is quite straightforward. The code below is an adaptation of the FAISS docs to your example:

    import faiss
    import numpy as np

    d = 64                           # dimension of each flattened sample
    nb = 100000                      # database size (the train set)
    nq = 10000                       # number of queries (the test set)
    np.random.seed(1234)             # make reproducible
    x_train = np.random.random((nb, d)).astype('float32')
    x_train[:, 0] += np.arange(nb) / 1000.
    x_test = np.random.random((nq, d)).astype('float32')
    x_test[:, 0] += np.arange(nq) / 1000.

    index = faiss.IndexFlatL2(d)     # build the index
    print(index.is_trained)
    index.add(x_train)               # add the train vectors to the index
    print(index.ntotal)

    k = 100                          # take the 100 closest neighbors
    D, I = index.search(x_test, k)   # actual search
    print(I[:5])                     # neighbors of the 5 first queries
    print(I[-5:])                    # neighbors of the 5 last queries
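    Note that IndexFlatL2 returns squared L2 distances (take a square root if you want values comparable to torch.cdist), and it can be filled incrementally, which fits your chunked files: add the saved train chunks one at a time and query with one test chunk at a time. A rough sketch, assuming the test_files/train_files lists and the (b, c, h, w) chunk shapes from your question:

    import faiss
    import torch

    d = c * h * w                        # dimension of a flattened sample, as in the question
    index = faiss.IndexFlatL2(d)

    # fill the index chunk by chunk, so the full train set is never loaded as one tensor
    for train_path in train_files:
        chunk = torch.load(train_path).flatten(1)       # (chunk_b2, c*h*w)
        index.add(chunk.numpy().astype('float32'))      # FAISS expects float32 numpy arrays

    # query one test chunk at a time
    k = 100
    for test_path in test_files:
        queries = torch.load(test_path).flatten(1).numpy().astype('float32')
        D, I = index.search(queries, k)                 # D: squared L2 distances
        # I[r] holds the train-set indices of the k nearest neighbors of query r

    The flat index still keeps every added vector in RAM; if even that is too large, FAISS also offers compressed indexes such as IndexIVFPQ.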