Tags: matlab, machine-learning, classification, knn

Finding K-nearest neighbors and its implementation


I am working on classifying simple data using KNN with Euclidean distance. I have seen an example of what I would like to do, done with the MATLAB knnsearch function, as shown below:

load fisheriris 
x = meas(:,3:4);
gscatter(x(:,1),x(:,2),species)
newpoint = [5 1.45];
[n,d] = knnsearch(x,newpoint,'k',10);
line(x(n,1),x(n,2),'color',[.5 .5 .5],'marker','o','linestyle','none','markersize',10)

The above code takes a new point, i.e. [5 1.45], and finds the 10 closest points to it. Can anyone please show me a MATLAB algorithm with a detailed explanation of what the knnsearch function does? Is there any other way to do this?


Solution

  • The basis of the K-Nearest Neighbour (KNN) algorithm is that you have a data matrix that consists of N rows and M columns, where N is the number of data points that we have and M is the dimensionality of each data point. For example, if we placed Cartesian co-ordinates inside a data matrix, this is usually an N x 2 or an N x 3 matrix. With this data matrix, you provide a query point and you search for the k points within this data matrix that are closest to this query point.

    We usually use the Euclidean distance between the query and the rest of the points in your data matrix to calculate our distances. However, other distances such as the L1 (city-block / Manhattan) distance are also used. After this operation, you will have N Euclidean or Manhattan distances, one between the query and each point in the data set. Once you have these, you search for the k nearest points to the query by sorting the distances in ascending order and retrieving the k points with the smallest distances.
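As a worked example of the two distances, take the query point from the question, [5 1.45], and the data point [4.9 1.5]. The difference vector is [-0.1 0.05], so the Euclidean distance is sqrt(0.1^2 + 0.05^2), about 0.1118, and the Manhattan distance is 0.1 + 0.05 = 0.15. A minimal sketch of the same arithmetic, using the built-in norm only as a cross-check:

```matlab
% Query point from the question and one data point to compare it with
newpoint = [5 1.45];
p = [4.9 1.5];

diffv = p - newpoint;                 % element-wise difference, approx. [-0.1 0.05]

% Euclidean (L2): square, sum, square root
d_euc = sqrt(sum(diffv.^2));          % approx. 0.1118

% Manhattan / city-block (L1): absolute values, then sum
d_man = sum(abs(diffv));              % approx. 0.15

% Cross-check against the built-in vector norms (2-norm and 1-norm)
assert(abs(d_euc - norm(diffv)) < 1e-12);
assert(abs(d_man - norm(diffv, 1)) < 1e-12);

fprintf('Euclidean: %.4f, Manhattan: %.4f\n', d_euc, d_man);
```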

    Supposing your data matrix is stored in x and newpoint is a query point with M columns (i.e. 1 x M), this is the general procedure you would follow, in point form:

    1. Find the Euclidean or Manhattan distance between newpoint and every point in x.
    2. Sort these distances in ascending order.
    3. Return the k data points in x that are closest to newpoint.

    Let's do each step slowly.


    Step #1

    One way that someone may do this is perhaps in a for loop like so:

    N = size(x,1);
    dists = zeros(N,1);
    for idx = 1 : N
        dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2));
    end
    

    If you wanted to implement the Manhattan distance, this would simply be:

    N = size(x,1);
    dists = zeros(N,1);
    for idx = 1 : N
        dists(idx) = sum(abs(x(idx,:) - newpoint));
    end
    

    dists would be an N-element vector that contains the distances between each data point in x and newpoint. We do an element-by-element subtraction between newpoint and a data point in x, square the differences, then sum them all together. Taking the square root of this sum completes the Euclidean distance. For the Manhattan distance, you would perform an element-by-element subtraction, take the absolute values, then sum all of the components together. This is probably the simplest of the implementations to understand, but it is likely the least efficient, especially for larger data sets and higher-dimensional data.

    Another possible solution would be to replicate newpoint so that it forms a matrix the same size as x, do an element-by-element subtraction between the two matrices, sum over all of the columns for each row, and take the square root. Therefore, we can do something like this:

    N = size(x, 1);
    dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2));
    

    For the Manhattan distance, you would do:

    N = size(x, 1);
    dists = sum(abs(x - repmat(newpoint, N, 1)), 2);
    

    repmat takes a matrix or vector and repeats it a given number of times along each dimension. In our case, we want to take our newpoint vector and stack it N times on top of itself to create an N x M matrix, where each row is M elements long. We subtract this matrix from x, then square each component. Once we do this, we sum over all of the columns for each row and finally take the square root of each result. For the Manhattan distance, we do the subtraction, take the absolute value and then sum.
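To see what repmat builds, here is a minimal sketch on a small hypothetical data set (the four points in x are made up for illustration). It also checks that the vectorized version returns the same distances as the for loop from Step #1:

```matlab
% Small hypothetical data set: N = 4 points, M = 2 dimensions
x = [0 0; 3 4; 1 1; -2 2];
newpoint = [0 0];
N = size(x, 1);

% repmat stacks newpoint N times: a 4 x 2 matrix whose rows are all [0 0]
P = repmat(newpoint, N, 1);

% Vectorized Euclidean distance, one value per row of x
dists_vec = sqrt(sum((x - P).^2, 2));

% Reference: the explicit for loop from Step #1
dists_loop = zeros(N, 1);
for idx = 1 : N
    dists_loop(idx) = sqrt(sum((x(idx,:) - newpoint).^2));
end

disp([dists_vec dists_loop])   % both columns: 0, 5, 1.4142, 2.8284
```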

    However, the most efficient way to do this, in my opinion, is to use bsxfun. It performs the same expansion of newpoint against every row of x in a single function call, without you having to build the replicated matrix yourself. Therefore, the code would simply be this:

    dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
    

    To me this looks much cleaner and to the point. For the Manhattan distance, you would do:

    dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
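On MATLAB R2016b and newer, the arithmetic operators perform this expansion automatically (implicit expansion), and GNU Octave broadcasts the same way, so x - newpoint can be written directly. A small sketch on hypothetical data checking that both forms agree (the implicit-expansion lines will error on MATLAB releases older than R2016b):

```matlab
% Hypothetical data: N = 4 points in M = 2 dimensions, plus one query point
x = [0 0; 3 4; 1 1; -2 2];
newpoint = [1 0];

% bsxfun version (also works on older MATLAB releases)
d_bsx = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
m_bsx = sum(abs(bsxfun(@minus, x, newpoint)), 2);

% Implicit expansion: the 1 x M row is expanded against the N x M matrix
d_imp = sqrt(sum((x - newpoint).^2, 2));
m_imp = sum(abs(x - newpoint), 2);

disp([d_imp m_imp])   % Euclidean: 1, 4.4721, 1, 3.6056; Manhattan: 1, 6, 1, 5
```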
    

    Step #2

    Now that we have our distances, we simply sort them in ascending order with sort:

    [d,ind] = sort(dists);
    

    d contains the distances sorted in ascending order, while ind tells you, for each position in the sorted result, where that value appears in the unsorted array (so d is equal to dists(ind)). We take the first k elements of ind and use them to index into our x data matrix to return the points that are closest to newpoint.
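A tiny example with made-up distances makes the meaning of ind concrete: ind(i) is the position in the original vector of the i-th smallest value, so indexing dists with ind reproduces d.

```matlab
% Made-up distances for three data points
dists = [3; 1; 2];

[d, ind] = sort(dists);   % ascending order by default

% d   = [1; 2; 3]  -> the sorted distances
% ind = [2; 3; 1]  -> the smallest distance came from point 2, then point 3, then point 1
disp([d ind])

% Indexing the unsorted vector with ind gives back the sorted one
assert(isequal(d, dists(ind)));
```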

    Step #3

    The final step is to now return those k data points that are closest to newpoint. We can do this very simply by:

    ind_closest = ind(1:k);
    x_closest = x(ind_closest,:);
    

    ind_closest should contain the indices in the original data matrix x that are the closest to newpoint. Specifically, ind_closest contains which rows you need to sample from in x to obtain the closest points to newpoint. x_closest will contain those actual data points.
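Putting Steps #1 to #3 together on a small hypothetical data set gives the sketch below. One thing to watch is that ind(1:k) errors when k is larger than the number of data points, so this sketch clamps k with min (an illustrative addition, not part of the code above):

```matlab
% Hypothetical data set (N = 4, M = 2) and query point
x = [0 0; 3 4; 1 1; -2 2];
newpoint = [0 0];
k = 2;

% Step #1: Euclidean distance from newpoint to every row of x
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));   % [0; 5; 1.4142; 2.8284]

% Step #2: sort ascending, keeping the original row indices
[d, ind] = sort(dists);

% Step #3: keep the first k indices (clamped so k never exceeds N) and look up the rows
k = min(k, numel(ind));
ind_closest = ind(1:k);          % [1; 3]
x_closest = x(ind_closest, :);   % [0 0; 1 1]
```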


    For your copying and pasting pleasure, this is what the code looks like:

    dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
    %// Or do this for Manhattan
    % dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
    [d,ind] = sort(dists);
    ind_closest = ind(1:k);
    x_closest = x(ind_closest,:);
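The choice of metric can change which points count as nearest. A sketch with two hypothetical data points where the Euclidean and Manhattan versions of the snippet above disagree for k = 1:

```matlab
% Two hypothetical data points and a query at the origin
x = [3 0; 2 2];
newpoint = [0 0];
k = 1;

% Euclidean: [3 0] -> 3, [2 2] -> 2.8284, so [2 2] is nearest
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[~, ind] = sort(dists);
x_closest_euc = x(ind(1:k), :);

% Manhattan: [3 0] -> 3, [2 2] -> 4, so [3 0] is nearest
dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
[~, ind] = sort(dists);
x_closest_man = x(ind(1:k), :);

disp(x_closest_euc)   % 2 2
disp(x_closest_man)   % 3 0
```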
    

    Running through your example, let's see our code in action:

    load fisheriris 
    x = meas(:,3:4);
    newpoint = [5 1.45];
    k = 10;
    
    %// Use Euclidean
    dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
    [d,ind] = sort(dists);
    ind_closest = ind(1:k);
    x_closest = x(ind_closest,:);
    

    By inspecting ind_closest and x_closest, this is what we get:

    >> ind_closest
    
    ind_closest =
    
       120
        53
        73
       134
        84
        77
        78
        51
        64
        87
    
    >> x_closest
    
    x_closest =
    
        5.0000    1.5000
        4.9000    1.5000
        4.9000    1.5000
        5.1000    1.5000
        5.1000    1.6000
        4.8000    1.4000
        5.0000    1.7000
        4.7000    1.4000
        4.7000    1.4000
        4.7000    1.5000
    

    If you run knnsearch, you will see that your variable n matches up with ind_closest (points at exactly tied distances may come back in a different order). However, knnsearch's second output d contains the distances from newpoint to each of its k nearest neighbours, not the data points themselves. If you want those distances from our code, simply do the following after the code I wrote:

    dist_sorted = d(1:k);
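As a sanity check on hypothetical data, the first k sorted distances should equal the distances recomputed directly from the returned points x_closest:

```matlab
% Hypothetical data set and query point
x = [0 0; 3 4; 1 1; -2 2];
newpoint = [1 1];
k = 3;

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d, ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest, :);

% Distances of the k returned points, computed two ways
dist_sorted = d(1:k);
dist_check = sqrt(sum(bsxfun(@minus, x_closest, newpoint).^2, 2));
assert(max(abs(dist_sorted - dist_check)) < 1e-12);

disp(dist_sorted.')   % approximately 0, 1.4142, 3.1623
```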
    

    Note that the above answer uses only one query point against a data set of N examples. Very frequently, KNN is applied to multiple query points simultaneously. Suppose that we have Q query points that we want to test. For each query point we can return its k closest points, each of dimensionality M, which gives a k x M x Q array where slice ii holds the neighbours of query point ii. Alternatively, we can return the IDs of the k closest points, which gives a Q x k matrix. Let's compute both.

    A naive way to do this would be to apply the above code inside a loop over every query point.

    Something like the following would work. We allocate a k x M x Q array for the points and a Q x k matrix for the IDs, then apply the bsxfun-based approach to each query point in turn, filling in one slice of the array and one row of the ID matrix per iteration. We will use the Fisher Iris data set as before with the same dimensionality, and four query points, so Q = 4 and M = 2:

    %// Load the data and create the query points
    load fisheriris;
    x = meas(:,3:4);
    newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];
    
    %// Define k and the output matrices
    Q = size(newpoints, 1);
    M = size(x, 2);
    k = 10;
    x_closest = zeros(k, M, Q);
    ind_closest = zeros(Q, k);
    
    %// Loop through each point and do logic as seen above:
    for ii = 1 : Q
        %// Get the point
        newpoint = newpoints(ii, :);
    
        %// Use Euclidean
        dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
        [d,ind] = sort(dists);
    
        %// New - Output the IDs of the match as well as the points themselves
        ind_closest(ii, :) = ind(1 : k).';
        x_closest(:, :, ii) = x(ind_closest(ii, :), :);
    end
    

    Though this is very nice, we can do even better. There is an efficient way to compute the squared Euclidean distance between two sets of vectors (if you want to do this with the Manhattan distance, I'll leave that as an exercise). Following this blog, suppose A is a Q1 x M matrix where each row is a point of dimensionality M (Q1 points in total), and B is a Q2 x M matrix where each row is also a point of dimensionality M (Q2 points in total). We can then efficiently compute a distance matrix D, where the element D(i, j) at row i and column j is the distance between row i of A and row j of B, using the following matrix formulation:

    nA = sum(A.^2, 2); %// Sum of squares for each row of A
    nB = sum(B.^2, 2); %// Sum of squares for each row of B
    D = bsxfun(@plus, nA, nB.') - 2*A*B.'; %// Compute squared distance matrix
    D = sqrt(max(D, 0)); %// Clamp tiny negative round-off values, then take the square root
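The formulation rests on expanding the squared Euclidean distance between a row a of A and a row b of B (both 1 x M, so a b^T is their dot product). Writing n_A and n_B for the column vectors nA and nB in the code, and \mathbf{1} for a column of ones of the indicated length:

```latex
\begin{align*}
\|a - b\|_2^2 &= \sum_{m=1}^{M} (a_m - b_m)^2 = \|a\|_2^2 + \|b\|_2^2 - 2\, a b^{\top}, \\
D^{(2)} &= n_A \mathbf{1}_{Q_2}^{\top} + \mathbf{1}_{Q_1} n_B^{\top} - 2\, A B^{\top}, \qquad
D = \sqrt{D^{(2)}} \text{ (element-wise)}.
\end{align*}
```

In the code, bsxfun(@plus, nA, nB.') forms the first two terms as a Q1 x Q2 matrix, and 2*A*B.' is the last term.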
    

    Therefore, if we let A be the matrix of query points and B be the data set consisting of your original data, we can determine the k closest points to each query by sorting each row of D individually and taking the k locations in each row with the smallest distances. We can then use these locations to retrieve the actual points themselves.

    Therefore:

    %// Load the data and create the query points
    load fisheriris;
    x = meas(:,3:4);
    newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];
    
    %// Define k and other variables
    k = 10;
    Q = size(newpoints, 1);
    M = size(x, 2);
    
    nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A
    nB = sum(x.^2, 2); %// Sum of squares for each row of B
    D = bsxfun(@plus, nA, nB.') - 2*newpoints*x.'; %// Compute squared distance matrix
    D = sqrt(max(D, 0)); %// Clamp tiny negative round-off values, then take the square root
    
    %// Sort the distances 
    [d, ind] = sort(D, 2);
    
    %// Get the indices of the closest distances
    ind_closest = ind(:, 1:k);
    
    %// Also get the nearest points
    x_closest = permute(reshape(x(ind_closest.', :).', M, k, []), [2 1 3]); %// .' groups the rows by query point
    

    The logic for computing the distance matrix is the same as before; only some variables have changed to suit the example. We then sort each row independently using the two-input version of sort, so ind contains the IDs per row and d contains the corresponding distances. We figure out which indices are the closest to each query point by simply truncating ind to its first k columns. We then use reshape and permute to gather the associated closest points. Indexing x with ind_closest.' (a k x Q matrix, read column by column) stacks the k closest points of query 1, then the k closest points of query 2, and so on, which gives a Q * k x M matrix. Transposing it, reshaping it and permuting the result turns this into the k x M x Q array we specified, where slice ii holds the neighbours of query point ii. If you want the actual distances themselves, you can pull them out of the distance matrix D. To do this in one shot, use sub2ind to obtain linear indices into D. The values of ind_closest already give us the columns we need to access, and the rows are simply 1 repeated k times, 2 repeated k times, and so on up to Q, where k is the number of points we want to return. Since d is already sorted row by row, d(:, 1:k) gives the same values.

    row_indices = repmat((1:Q).', 1, k);
    linear_ind = sub2ind(size(D), row_indices, ind_closest);
    dist_sorted = D(linear_ind);
    

    When we run the above code for the above query points, these are the indices, points and distances we get:

    >> ind_closest
    
    ind_closest =
    
       120   134    53    73    84    77    78    51    64    87
       123   119   118   106   132   108   131   136   126   110
       107    62    86   122    71   127   139   115    60    52
        99    65    58    94    60    61    80    44    54    72
    
    >> x_closest
    
    x_closest(:,:,1) =
    
        5.0000    1.5000
        5.1000    1.5000
        4.9000    1.5000
        4.9000    1.5000
        5.1000    1.6000
        4.8000    1.4000
        5.0000    1.7000
        4.7000    1.4000
        4.7000    1.4000
        4.7000    1.5000
    
    
    x_closest(:,:,2) =
    
        6.7000    2.0000
        6.9000    2.3000
        6.7000    2.2000
        6.6000    2.1000
        6.4000    2.0000
        6.3000    1.8000
        6.1000    1.9000
        6.1000    2.3000
        6.0000    1.8000
        6.1000    2.5000
    
    
    x_closest(:,:,3) =
    
        4.5000    1.7000
        4.2000    1.5000
        4.5000    1.6000
        4.9000    2.0000
        4.8000    1.8000
        4.8000    1.8000
        4.8000    1.8000
        5.1000    2.4000
        3.9000    1.4000
        4.5000    1.5000
    
    
    x_closest(:,:,4) =
    
        3.0000    1.1000
        3.6000    1.3000
        3.3000    1.0000
        3.3000    1.0000
        3.9000    1.4000
        3.5000    1.0000
        3.5000    1.0000
        1.6000    0.6000
        4.0000    1.3000
        4.0000    1.3000
    
    >> dist_sorted
    
    dist_sorted =
    
        0.0500    0.1118    0.1118    0.1118    0.1803    0.2062    0.2500    0.3041    0.3041    0.3041
        0.3000    0.3162    0.3606    0.4123    0.6000    0.7280    0.9055    0.9487    1.0198    1.0296
        0.9434    1.0198    1.0296    1.0296    1.0630    1.0630    1.0630    1.1045    1.1045    1.1180
        2.6000    2.7203    2.8178    2.8178    2.8320    2.9155    2.9155    2.9275    2.9732    2.9732
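To check the batch version end to end, the sketch below runs it on small hypothetical data (chosen so that no two candidate neighbours are tied) and compares the IDs, the k x M x Q array of points, and the distances against the per-query loop shown earlier. It indexes x with ind_closest.' so that the rows come out grouped by query point before the reshape:

```matlab
% Hypothetical data (N = 5, M = 2) and query points (Q = 3)
x = [0 0; 3 4; 1 1; -2 2; 5 0];
newpoints = [0 0; 4 1; -2 4];
k = 2;
Q = size(newpoints, 1);
M = size(x, 2);

% Batch version: distance matrix, row-wise sort, truncate to k columns
nA = sum(newpoints.^2, 2);
nB = sum(x.^2, 2);
D = sqrt(max(bsxfun(@plus, nA, nB.') - 2*newpoints*x.', 0));
[d, ind] = sort(D, 2);
ind_closest = ind(:, 1:k);

% Group rows by query point, then build the k x M x Q array
x_closest = permute(reshape(x(ind_closest.', :).', M, k, []), [2 1 3]);

% Distances of the k nearest points for every query, via linear indices
linear_ind = sub2ind(size(D), repmat((1:Q).', 1, k), ind_closest);
dist_sorted = D(linear_ind);

% Reference: the per-query loop
ind_loop = zeros(Q, k);
x_loop = zeros(k, M, Q);
for ii = 1 : Q
    dists = sqrt(sum(bsxfun(@minus, x, newpoints(ii, :)).^2, 2));
    [~, ind1] = sort(dists);
    ind_loop(ii, :) = ind1(1:k).';
    x_loop(:, :, ii) = x(ind_loop(ii, :), :);
end

disp(ind_closest)   % rows: 1 3 / 5 3 / 4 3
```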
    

    To compare this with knnsearch, pass the whole matrix of query points as the second argument (each row is one query point). You will see that the indices and sorted distances from this implementation match those of knnsearch, up to the order of points that are exactly tied in distance.


    Hope this helps you. Good luck!