python-3.x sparse-matrix knn euclidean-distance csr

Find euclidean distance between rows of two huge CSR matrices


I have two sparse matrices, A and B. A is 120000*5000 and B is 30000*5000. I need to find the Euclidean distance between each row in B and all rows of A, and then find the 5 rows in A with the lowest distance to the selected row in B. Because the data is very large, I am using CSR; otherwise I get a memory error. For each row in A, the code calculates (x_b - x_a)^2 5000 times, sums the results, and then takes the square root. This process is taking a very long time, around 11 days! Is there any way I can do this more efficiently? I just need the 5 rows with the lowest distance to each row in B.

I am implementing K-Nearest Neighbours and A is my training set and B is my test set.


Solution

  • Well - I don't know whether you can 'vectorize' that code so that it runs in native code instead of Python. Getting the inner loops to run in native code is always the trick to speeding up NumPy and SciPy.

    If you could run that code as native code on a 1 GHz CPU, at 1 floating-point instruction per clock cycle, it would be done in a little under 10 hours: (5000 * 2 * 30000 * 120000) / 1024 ** 3 is roughly 33,500 seconds, or about 9.3 hours.

    Raise that to 1.5 GHz x 2 physical CPU cores x 4-way SIMD instructions with multiply + accumulate (Intel AVX extensions, available in most CPUs) and you could get that number crunching down to about one hour, at 2 x 100% on a modest Core i5 machine. But that would require full SIMD optimization in native code, which is far from a trivial task (although, if you decide to go down this path, further questions on S.O. could get help from people eager to get their hands wet with SIMD coding :-) ). Interfacing such C code with SciPy is not hard using Cython, for example (you only need that part to get to the 10-hour figure above).

    Now... as for algorithm optimization, and keeping things Python :-)
    Fact is, you don't need to fully calculate the distance to every row in A - you just need to keep track of the 5 nearest rows found so far, and any time the running sum of squares gets larger than the squared distance of the 5th nearest row (so far), you simply abort the calculation for that row.

    You could use Python's heapq operations for that:

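    import heapq
    import math

    def get_closer_rows(b_row, a, k=5):
        # Max-heap (via negated values) of the k smallest squared distances so far.
        result = [(-math.inf, -1)] * k
        for i, a_row in enumerate(a):
            worst = -result[0][0]  # k-th smallest squared distance found so far
            distance_sq = 0
            count = 0
            # Assumes each row yields its elements one by one (e.g. a dense 1-D sequence).
            for element_a, element_b in zip(a_row, b_row):
                distance_sq += (element_a - element_b) ** 2
                # Only compare every 64 elements to keep the check cheap.
                if not count % 64 and distance_sq > worst:
                    break
                count += 1
            else:
                # Keeps this row only if it beats the current worst entry.
                heapq.heappushpop(result, (-distance_sq, i))
        # Nearest first, as (distance, row index in a) pairs.
        return sorted((math.sqrt(-neg_sq), i) for neg_sq, i in result)

    closer_rows_to_b = []
    for row in b:
        closer_rows_to_b.append(get_closer_rows(row, a))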
    

    Note the auxiliary "count" variable, which avoids the expensive retrieval and comparison of values after every single multiplication. Now, if you can run this code with PyPy instead of regular Python, I believe it could get the full benefit of JIT compilation, and you could see a noticeable improvement over your times if you are running the code in pure Python (i.e. non NumPy/SciPy-vectorized code).
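
Since the rows in question are CSR rows, the same early-abort idea can also be applied to the stored nonzeros only, instead of walking all 5000 columns. The following is a minimal sketch of that variation, not the code from the answer above. It assumes each row is available as a pair of parallel lists: column indices in ascending order and the matching values (in SciPy, this is what the `indices` and `data` arrays of a CSR matrix hold for each `indptr` slice, once the indices are sorted). The names `sparse_sq_distance` and `nearest_rows` are hypothetical.

```python
import heapq
import math


def sparse_sq_distance(idx_a, val_a, idx_b, val_b, limit=math.inf):
    """Squared Euclidean distance between two sparse rows.

    Each row is given as sorted column indices plus matching values.
    Returns None as soon as the running sum exceeds `limit`.
    """
    total = 0.0
    i = j = 0
    # Merge the two sorted index lists, like the merge step of merge sort.
    while i < len(idx_a) and j < len(idx_b):
        if idx_a[i] == idx_b[j]:
            diff = val_a[i] - val_b[j]
            i += 1
            j += 1
        elif idx_a[i] < idx_b[j]:
            diff = val_a[i]  # column present only in row a
            i += 1
        else:
            diff = val_b[j]  # column present only in row b
            j += 1
        total += diff * diff
        if total > limit:
            return None  # cannot be among the nearest rows
    # Whatever is left in either row faces implicit zeros in the other.
    for v in val_a[i:]:
        total += v * v
    for v in val_b[j:]:
        total += v * v
    return None if total > limit else total


def nearest_rows(b_row, a_rows, k=5):
    """Return up to k (distance, row_index) pairs, nearest first."""
    heap = []  # max-heap of squared distances, stored negated
    idx_b, val_b = b_row
    for i, (idx_a, val_a) in enumerate(a_rows):
        limit = -heap[0][0] if len(heap) == k else math.inf
        d = sparse_sq_distance(idx_a, val_a, idx_b, val_b, limit)
        if d is None:
            continue
        if len(heap) < k:
            heapq.heappush(heap, (-d, i))
        else:
            heapq.heapreplace(heap, (-d, i))  # drop the current worst
    return sorted((math.sqrt(-neg), i) for neg, i in heap)


# Tiny illustration with made-up data (4 columns).
a_rows = [
    ([0, 2], [1.0, 2.0]),          # dense form: [1, 0, 2, 0]
    ([1], [3.0]),                  # dense form: [0, 3, 0, 0]
    ([0, 2, 3], [1.0, 2.0, 1.0]),  # dense form: [1, 0, 2, 1]
]
b_row = ([0, 2], [1.0, 1.0])       # dense form: [1, 0, 1, 0]
result = nearest_rows(b_row, a_rows, k=2)  # rows 0 and 2 are the nearest
```

Because every term added to the sum is non-negative, the running total can only grow, so abandoning a row once it passes the current 5th-best squared distance never discards a true neighbour. Whether this beats a dense loop depends on how sparse the rows actually are.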