
How to find the k nearest values to an input value x in a pandas DataFrame column in O(log n)?


I have a DataFrame with two columns: 1) ID: random integers representing sample IDs, and 2) A: random floats.

import random

import numpy as np
import pandas as pd

size_df1 = 1000
df1 = pd.DataFrame(np.random.random_sample(size_df1), columns=['A'])
df1['ID'] = random.sample(range(0, size_df1), size_df1)

Given an input like x=0.21, how can I find the k (say k=10) nearest values to x in df1['A'] in O(log n), where n is the number of rows in df1? Note that this should be done without replacement: each time I find these 10 nearest values in df1['A'], I need to remove them (or somehow mark them) so they aren't used for the next x. Can this be solved in O(log n) at all? Thanks


Solution

  • You can easily find the k smallest values with .nsmallest(), and the closest values are the ones with the smallest absolute difference:

    >>> (df1['A'] - 0.21).abs().nsmallest(10)
    969    0.000014
    889    0.000442
    779    0.003299
    259    0.003637
    843    0.003700
    84     0.003818
    651    0.004264
    403    0.004360
    648    0.004421
    543    0.005088
    Name: A, dtype: float64
    

    You can then reuse its index if you want to access the matching rows:

    >>> df1.loc[(df1['A'] - 0.21).abs().nsmallest(10).index]
                A   ID
    969  0.210014  237
    889  0.210442  225
    779  0.206701  127
    259  0.213637  883
    843  0.206300  330
    84   0.206182   17
    651  0.205736   64
    403  0.205640  388
    648  0.214421  964
    543  0.204912  616
    

    Note that the documentation for nsmallest says:

    Faster than .sort_values().head(n) for small n relative to the size of the Series object.
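
    As a quick sanity check (a sketch, not part of the original answer): both expressions select the same rows when all differences are distinct, nsmallest just avoids sorting the entire Series:

    diffs = (df1['A'] - 0.21).abs()
    # nsmallest tracks only the 10 best candidates instead of sorting all n
    assert diffs.nsmallest(10).equals(diffs.sort_values().head(10))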

    A word on complexity, since your values aren’t sorted:

    • the bare minimum complexity is O(n), even to find just the single closest value
    • you could do a binary-search-style lookup to get O(log(n)), but that requires sorting first, so it's in fact O(n log(n)) overall (the bisect sketch below illustrates the search step).
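
    To illustrate the binary-search step on its own, here is a minimal sketch using the standard bisect module on a pre-sorted plain list (separate from the pandas solution that follows):

    import bisect

    def nearest_position(sorted_vals, x):
        # insertion point for x: everything before it is < x
        i = bisect.bisect_left(sorted_vals, x)
        if i == 0:
            return 0
        if i == len(sorted_vals):
            return len(sorted_vals) - 1
        # pick the closer of the two neighbours around the insertion point
        return i if sorted_vals[i] - x < x - sorted_vals[i - 1] else i - 1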

    Suppose your dataframe is sorted on A:

    >>> df1.sort_values('A', inplace=True)
    

    Then we can use the sorted-search function searchsorted, which returns the row position (not the index label):

    >>> df1['A'].searchsorted(0.21)
    197
    
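    Since this is a row position in the sorted frame rather than an index label, use .iloc (not .loc) to inspect it; with the position from the output above:

    # the values straddling x = 0.21 in the sorted column
    left, right = df1['A'].iloc[196], df1['A'].iloc[197]
    assert left < 0.21 <= right
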

    This means we can use that position to take the k candidates on each side, and then apply our previous method to this small 2k-row DataFrame:

    def find_closest(df, val, k):
        # k smallest absolute differences to val (linear in len(df))
        return df.loc[df['A'].sub(val).abs().nsmallest(k).index]

    def find_closest_sorted(df, val, k):
        # binary search for val's insertion point in the sorted column
        closest = df['A'].searchsorted(val)
        if closest < k:
            # fewer than k values to the left: take everything up to closest + k
            return find_closest(df.iloc[:closest + k], val, k)

        # the k nearest values must lie among the k on either side
        return find_closest(df.iloc[closest - k:closest + k], val, k)
    
    >>> find_closest_sorted(df1, 0.21, 10)
                A   ID
    969  0.210014  237
    889  0.210442  225
    779  0.206701  127
    259  0.213637  883
    843  0.206300  330
    84   0.206182   17
    651  0.205736   64
    403  0.205640  388
    648  0.214421  964
    543  0.204912  616
    

    Overall, the complexity here should be:

    • O(n log(n)) for sorting (which can be amortized over many lookups)
    • O(log(n)) for the sorted search
    • O(k) for the final step.
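
    The question also asks for lookups without replacement. As a minimal sketch (the helper name is hypothetical, and it simply drops matched rows after each lookup), you could do:

    def find_closest_without_replacement(df, xs, k):
        # repeatedly take the k rows nearest to each x, removing them
        # so they cannot be matched again
        df = df.sort_values('A')          # sort once up front
        results = []
        for x in xs:
            match = find_closest_sorted(df, x, k)
            results.append(match)
            df = df.drop(match.index)     # keeps df sorted, but costs O(n)
        return results

    Note that df.drop is itself O(n), so this doesn't achieve O(log(n)) per query; a structure with logarithmic deletion (e.g. sortedcontainers.SortedList) would be needed for that.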