Explain the K Nearest Neighbor classification algorithm in detail. Implement KNN from scratch without using scikit-learn.

Medium Last updated on May 7, 2022, 1:23 a.m.

The K Nearest Neighbor (KNN) classifier is a non-parametric classifier that simply stores the training data $\mathcal{D}$ and classifies each new instance $x$ by a majority vote over its set of $K$ nearest neighbors $N_{K}(x)$, computed using any distance function $d : \mathbb{R}^{D} \times \mathbb{R}^{D} \to \mathbb{R}$. The decision rule of the KNN classifier can be written as:

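$$f_{KNN}(x) = \arg\max_{\hat{y} \in \mathcal{Y}} \sum_{i \in N_{K}(x)} \mathbb{I}[y_{i} = \hat{y}]$$

where $\mathcal{Y}$ is the set of class labels, $y_i$ is the label of training point $x_i$, and $\mathbb{I}[\cdot]$ is the indicator function (1 if its argument is true, 0 otherwise).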
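The voting rule can be written almost literally as a sum of indicators over the neighbors' labels. This is a minimal sketch that assumes the labels of the $K$ neighbors have already been found; the function name `knn_vote` is hypothetical, and ties are broken in favor of the class listed first in `classes` (an arbitrary choice).

```python
import numpy as np

def knn_vote(neighbor_labels, classes):
    """Return the class y_hat that maximizes sum_i I[y_i == y_hat] over the neighbors."""
    neighbor_labels = np.asarray(neighbor_labels)
    # One vote count per candidate class: the sum of indicators I[y_i == y_hat]
    votes = [np.sum(neighbor_labels == y_hat) for y_hat in classes]
    # np.argmax returns the first maximum, so ties go to the earliest class in `classes`
    return classes[int(np.argmax(votes))]

print(knn_vote([1, 0, 1, 2, 1], classes=[0, 1, 2]))  # 1
```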

Using the KNN classifier requires selecting the distance function $d$ and the number of neighbors $K$, typically through hyper-parameter tuning. In general, KNN can work with any distance function $d$ that satisfies non-negativity, $d(x, x') \geq 0$, and identity of indiscernibles, $d(x, x) = 0$. Similarly, KNN can work with any similarity function $s$ that satisfies non-negativity, $s(x, x') \geq 0$, and attains its maximum on indiscernibles, $s(x, x) = \max_{x'} s(x, x')$.
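To illustrate the similarity-function variant, here is a sketch that assumes a Gaussian (RBF) similarity $s(x, x') = \exp(-\|x - x'\|^2 / (2\sigma^2))$, which is always positive and equals 1 when $x' = x$. The kernel choice, `sigma`, the helper name `rbf_similarity`, and the toy points are illustrative assumptions. Because this similarity decreases monotonically with Euclidean distance, picking the $K$ most similar points selects the same neighbors as picking the $K$ nearest ones.

```python
import numpy as np

def rbf_similarity(x, X, sigma=1.0):
    # s(x, x') = exp(-||x - x'||^2 / (2 sigma^2)): always > 0, maximal (= 1) when x' == x
    sq_dist = np.sum((X - x) ** 2, axis=1)
    return np.exp(-sq_dist / (2 * sigma ** 2))

# Toy data, assumed for illustration
X_train = np.array([[0.0, 0.0], [1.0, 1.0], [3.0, 3.0], [0.5, 0.0]])
x_query = np.array([0.2, 0.1])
K = 2

sims = rbf_similarity(x_query, X_train)
most_similar = np.argsort(-sims)[:K]                 # indices of the K largest similarities
dists = np.linalg.norm(X_train - x_query, axis=1)
nearest = np.argsort(dists)[:K]                      # indices of the K smallest distances

print(sorted(most_similar.tolist()), sorted(nearest.tolist()))  # [0, 3] [0, 3]
```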

The standard choice of distance metric is the Minkowski distance (the $\ell_p$ norm of the difference vector). Given two data vectors $x, x' \in \mathbb{R}^D$:

$$d_p(x, x') = \|x - x'\|_{p}$$

$$d_p(x, x') = \bigg( \sum_{d=1}^{D} |x_d - x'_{d}|^p \bigg)^{1/p}$$

This formula gives the Euclidean distance for $p = 2$ and the Manhattan distance for $p = 1$; in the limit $p \to \infty$ it becomes the Chebyshev distance, $\max_d |x_d - x'_d|$.
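The Minkowski family can be computed with a few NumPy operations. This is a minimal sketch (the function name `minkowski_distance` is hypothetical); the $p = \infty$ case is handled separately as the maximum absolute coordinate difference, matching the limit described above.

```python
import numpy as np

def minkowski_distance(x, x_prime, p=2.0):
    """l_p distance between two vectors; p=np.inf gives the Chebyshev distance."""
    diff = np.abs(np.asarray(x, dtype=float) - np.asarray(x_prime, dtype=float))
    if np.isinf(p):
        return float(diff.max())              # limit p -> infinity: max_d |x_d - x'_d|
    return float(np.sum(diff ** p) ** (1.0 / p))

x, x_prime = [0.0, 0.0], [3.0, 4.0]
print(minkowski_distance(x, x_prime, p=1))       # 7.0 (Manhattan)
print(minkowski_distance(x, x_prime, p=2))       # 5.0 (Euclidean)
print(minkowski_distance(x, x_prime, p=np.inf))  # 4.0 (Chebyshev)
```

For vectors, `np.linalg.norm(x - x_prime, ord=p)` computes the same quantity.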

Now, let's go over the steps of the algorithm:

K Nearest Neighbor Algorithm

• Given a distance function $d$, compute the distance $d_{i} = d(x_{i}, x^{*})$ from a target point $x^{*}$ to each training point $x_{i}$.

• Sort the distances $\{d_{i}, i = 1{:}N\}$ and choose the data cases with the $K$ smallest distances to form the neighbor set $N_K(x^{*})$. Note: we can also use a similarity function; in that case, select the $K$ most similar data cases (highest similarity scores).

• Once the $K$ neighbors are selected, apply the classification rule: assign $x^{*}$ the label that occurs most often among the labels $y_i$ of its neighbors.
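The three steps map directly onto a brute-force prediction for a single target point. This sketch assumes Euclidean distance and uses `collections.Counter` for the vote; the function name `knn_predict_one` and the toy data are hypothetical.

```python
import numpy as np
from collections import Counter

def knn_predict_one(X_train, y_train, x_star, K=3):
    """Brute-force KNN for one target point x_star (Euclidean distance assumed)."""
    # Step 1: distance d_i = d(x_i, x_star) to every training point
    distances = np.linalg.norm(X_train - x_star, axis=1)
    # Step 2: indices of the K smallest distances form N_K(x_star)
    neighbor_idx = np.argsort(distances)[:K]
    # Step 3: majority vote over the neighbors' labels
    # (Counter.most_common breaks ties by first-encountered order, i.e. closer neighbors first)
    votes = Counter(y_train[neighbor_idx].tolist())
    return votes.most_common(1)[0][0]

# Toy data, assumed for illustration
X_train = np.array([[1.0, 1.0], [1.5, 2.0], [5.0, 5.0], [6.0, 5.5], [1.2, 0.8]])
y_train = np.array([0, 0, 1, 1, 0])
print(knn_predict_one(X_train, y_train, np.array([1.1, 1.0]), K=3))  # 0
print(knn_predict_one(X_train, y_train, np.array([5.5, 5.0]), K=3))  # 1
```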

Note: The steps above describe the brute-force approach. Instead of a brute-force nearest neighbor search, data structures such as ball trees can be built over the training data to support nearest neighbor search with lower amortized computational complexity.
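As an illustration of tree-based search, the sketch below swaps in SciPy's `scipy.spatial.KDTree`, which is a k-d tree rather than a ball tree, and checks that it returns the same $K$ neighbors as the brute-force search. The random data and the value of `K` are assumptions for the example.

```python
import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(0)
X_train = rng.normal(size=(1000, 3))   # random training points, assumed for illustration
x_star = np.zeros(3)
K = 5

# Build the tree once over the training data, then reuse it for every query
tree = KDTree(X_train)
# K nearest neighbors of x_star; p=2 selects the Euclidean (Minkowski p=2) distance
dist_tree, idx_tree = tree.query(x_star, k=K, p=2)

# Brute-force search for comparison
dist_brute = np.linalg.norm(X_train - x_star, axis=1)
idx_brute = np.argsort(dist_brute)[:K]

print(np.array_equal(np.sort(idx_tree), np.sort(idx_brute)))  # True
```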

How to determine the optimal value of K

A common rule of thumb for choosing $K$ is $K = \frac{\sqrt{N}}{2}$, where $N$ is the number of samples in the training dataset. Another suggestion is to keep $K$ odd to avoid ties in the vote (an odd $K$ rules out ties only in binary classification). However, if ties arise frequently, the classes overlap heavily in feature space, and a simple classifier such as KNN is unlikely to perform well.
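The rule of thumb and the odd-$K$ suggestion can be combined in a small helper. This is a sketch: the helper name `heuristic_k`, rounding to the nearest integer, and bumping an even value up to the next odd number are all assumptions, not part of the rule itself.

```python
import math

def heuristic_k(n_train):
    """Rule of thumb K = sqrt(N) / 2, rounded, then nudged up to an odd value."""
    k = max(1, round(math.sqrt(n_train) / 2))
    # An odd K avoids tied votes in binary classification
    if k % 2 == 0:
        k += 1
    return k

print(heuristic_k(400))   # sqrt(400) / 2 = 10 -> 11
print(heuristic_k(100))   # sqrt(100) / 2 = 5  -> 5
```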

In practice, it is common to choose $K$ with the help of an error curve or an accuracy curve. Evaluate the classifier on held-out validation data for each value of $K$ in a chosen range, plot the error rate against $K$, and choose the $K$ with the minimum error rate. Alternatively, plot accuracy against $K$ and choose the $K$ with the maximum accuracy; since accuracy equals 1 minus the error rate, both curves point to the same $K$.
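A minimal sketch of the validation-curve procedure, assuming synthetic two-class Gaussian data, a 200/100 train/validation split, and odd $K$ from 1 to 31 (all illustrative choices). The helper `knn_predict` is hypothetical: a vectorized brute-force Euclidean KNN that expects non-negative integer labels.

```python
import numpy as np

def knn_predict(X_train, y_train, X_query, K):
    """Brute-force Euclidean KNN with majority vote (labels must be non-negative ints)."""
    # Pairwise distances, shape (n_query, n_train)
    dists = np.linalg.norm(X_query[:, None, :] - X_train[None, :, :], axis=2)
    neighbor_idx = np.argsort(dists, axis=1)[:, :K]
    neighbor_labels = y_train[neighbor_idx]
    # np.bincount counts votes per class; argmax picks the most common label
    return np.array([np.bincount(row).argmax() for row in neighbor_labels])

# Synthetic data (assumption): two overlapping Gaussian blobs, 150 points each
rng = np.random.default_rng(42)
X = np.vstack([rng.normal(0.0, 1.0, size=(150, 2)), rng.normal(1.5, 1.0, size=(150, 2))])
y = np.array([0] * 150 + [1] * 150)

# Random train/validation split
perm = rng.permutation(len(X))
train_idx, val_idx = perm[:200], perm[200:]
X_tr, y_tr, X_val, y_val = X[train_idx], y[train_idx], X[val_idx], y[val_idx]

k_values = list(range(1, 32, 2))   # odd K from 1 to 31
error_rates = []
for K in k_values:
    y_pred = knn_predict(X_tr, y_tr, X_val, K)
    error_rates.append(float(np.mean(y_pred != y_val)))

best_k = k_values[int(np.argmin(error_rates))]
print("error rate per K:", {k: round(e, 3) for k, e in zip(k_values, error_rates)})
print("K with minimum validation error:", best_k)
```

Plotting `error_rates` (or `1 - error_rate` for accuracy) against `k_values`, for example with matplotlib, gives the curves described above.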

KNN is generally used as a baseline model when the amount of data is small. Because it has no model parameters to learn, it can work well when there is not enough data to fit a parametric model. To understand KNN further, let's look at its trade-offs.

What are the trade-offs of K Nearest Neighbor Classification?

• Low bias: Converges to the correct decision boundaries as the amount of training data increases.
• High variance: The decision boundaries vary a lot when the amount of training data is low.
• Curse of dimensionality: In high dimensions, points become scattered far apart and the distances to the nearest and farthest points become similar, so the distance function $d$ becomes less informative and KNN's performance degrades.
• Space and Time Complexity: KNN must store the entire training set and search it at prediction time; a brute-force query computes $N$ distances in $D$ dimensions, i.e. $O(ND)$ work per prediction. This can make the algorithm use a lot of memory and take a lot of time.
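The curse of dimensionality can be observed numerically. This sketch assumes points drawn uniformly from the unit hypercube and measures, for one random query, the ratio of its nearest to its farthest distance; the helper name `distance_contrast` and the sample sizes are illustrative. As the dimension grows the ratio approaches 1, meaning "nearest" and "farthest" become hard to tell apart.

```python
import numpy as np

def distance_contrast(dim, n_points=500, seed=0):
    """Ratio of nearest to farthest Euclidean distance from one random query to random points."""
    rng = np.random.default_rng(seed)
    X = rng.uniform(size=(n_points, dim))   # points uniform in [0, 1]^dim
    query = rng.uniform(size=dim)
    d = np.linalg.norm(X - query, axis=1)
    return float(d.min() / d.max())

for dim in [2, 10, 100, 1000]:
    print(dim, round(distance_contrast(dim), 3))
```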

Write K Nearest Neighbor Classification code from scratch.

To begin, we will use basic NumPy functions for array manipulation and SciPy for computing the mode.