K-Nearest Neighbors (KNN) Algorithm in Machine Learning

K-Nearest Neighbor (KNN) is a supervised Machine Learning algorithm that can solve classification and regression problems. It is one of the oldest ML algorithms and is still widely used due to its simplicity. KNN is unique in that it does not explicitly learn a mapping from input variables to target variables during training, which makes it even more interesting. It’s a non-parametric algorithm, and in this blog we will walk through how KNN works end to end.

Key Takeaways from this blog

In this article, we will cover the following topics related to the KNN algorithm in machine learning:

  • What is KNN, and how does it work?
  • Why is KNN considered an instance-based or lazy learning algorithm and a non-parametric algorithm?
  • How does the value of K impact the performance of the KNN algorithm?
  • How does feature scaling affect KNN?
  • Implementing KNN in Python.

So let’s start by defining KNN formally.

What is K-Nearest Neighbors Algorithm in ML?

K-Nearest Neighbors, popularly known as KNN, is a supervised learning algorithm capable of solving classification and regression problems. It is a unique type of algorithm because it does not follow the traditional approach of fitting a mapping function between input and output data. Instead, it simply memorizes the training data and uses this information to classify new test samples based on their similarity to the stored training samples.

KNN can be considered an algorithm that relies on memorization rather than on a more traditional learned model, yet it is very effective and widely used among ML professionals. One of its key characteristics is explainability: the reason the algorithm predicted any particular value can be traced back to the neighboring samples that produced it, which makes the algorithm even more useful.

Why is KNN instance-based learning?

Instance-based learning, also known as memory-based learning, means that the KNN algorithm compares new data samples directly to the training samples stored in its memory rather than building an explicit generalization.

It is sometimes called a “lazy” algorithm because it only performs computation after receiving new observations. This means that KNN stores all the training data in its memory and defers calculations until a new test sample is given for classification or regression.

Why is KNN a non-parametric algorithm?

KNN is classified as a non-parametric algorithm because it does not learn a fixed set of parameter values that map inputs to outputs. Instead, it keeps the entire training set, so the learned model itself changes whenever new instances are added in the future. This is a characteristic of non-parametric algorithms.

What are the common assumptions in KNN?

The KNN algorithm makes two main assumptions:

  1. Every sample in the training data is mapped to a real n-dimensional space, where each sample has the same number of attributes or dimensions.
  2. The “nearest neighbors” are defined using a specific distance measure, such as Euclidean distance, Manhattan distance, or Hamming distance. The choice of distance measure can significantly impact the predictions made by the KNN algorithm.

What are the various distances we can use to find neighbors in KNN algorithms?
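
The most common choices are Euclidean distance (the straight-line distance between two points), Manhattan distance (the sum of absolute coordinate differences), and Hamming distance (the number of positions where two categorical or binary vectors differ). Here is a minimal sketch of these three metrics using NumPy; the vectors a and b below are made up purely for illustration:

import numpy as np

a = np.array([1.0, 2.0, 3.0])  # hypothetical sample 1
b = np.array([4.0, 0.0, 3.0])  # hypothetical sample 2

euclidean = np.sqrt(np.sum((a - b) ** 2))  # straight-line distance
manhattan = np.sum(np.abs(a - b))          # sum of absolute differences
hamming = np.sum(a != b)                   # count of mismatched positions (for categorical data)

print(euclidean, manhattan, hamming)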

Working of K-Nearest Neighbor

To understand how the KNN algorithm works, let’s consider the steps involved in using KNN for classification:

Step 1: We must first select the number of neighbors we want to consider. The value of K in the KNN algorithm strongly affects the prediction.

Step 2: We need to find the K nearest neighbors based on a distance metric. It can be Euclidean, Manhattan, or a custom distance metric. Given the “test sample” for which we want the prediction, the K training samples closest to this test sample become its K neighbors.

Step 3: Among the selected K neighbors, we count how many belong to each class. It is possible that all K neighbors come from a single class.

Step 4: The objective of sending the test sample to KNN was to assign a class to it, and this is decided by the majority class among the neighbors. For example, if out of K neighbors more than K/2 samples are from class-1 and the rest are from class-2, we assign class-1 to the test sample.

We performed the prediction in these four simple steps. In summary, the KNN algorithm stores the entire training dataset during the training phase, and when it receives a new query, it assigns the class shared by the majority of that query's nearest neighbors.
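
To make these steps concrete, here is a minimal from-scratch sketch of KNN classification (the function and variable names are illustrative only, not part of any library):

import numpy as np
from collections import Counter

def knn_classify(X_train, y_train, x_test, k=3):
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    # Step 2: Euclidean distance from the test sample to every training sample
    distances = np.sqrt(np.sum((X_train - x_test) ** 2, axis=1))
    # Step 2: indices of the K closest training samples
    nearest = np.argsort(distances)[:k]
    # Steps 3 and 4: majority vote among the K neighbors' labels
    votes = Counter(y_train[nearest])
    return votes.most_common(1)[0][0]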

How does the value of K affect the KNN algorithm?

In the KNN algorithm, the value of K can range from 1 to the total number of training samples. A small value of K makes the model prone to overfitting and vulnerable to outliers; such a model has high variance and low bias. On the other hand, a model with a high value of K has low variance and high bias, resulting in underfitting. As the value of K increases from 1 towards the number of training samples, the decision boundary becomes smoother.

K = 1: A model with K = 1 will have zero training error and hard decision boundaries for determining the class of a test query.

K = len(sample data): This model will be heavily biased towards the majority class (the class with the most samples) and less accurate.

Note: It is advisable to keep K odd to reduce the chance of ties. For example, if we choose K = 6 and 3 neighbors are from class-1 while 3 are from class-2, the algorithm has to handle the tie explicitly. For binary classification, an odd K can never produce such a tie (with more than two classes, ties are still possible, so some tie-breaking rule may be needed).

The effect of k in KNN classification algorithm

How does feature scaling affect KNN?

K-Nearest Neighbor depends heavily on the distance between data samples, so scaling plays a vital role here. Suppose we train the KNN algorithm on unscaled data where different attributes lie on different scales. The distance will then be dominated by the features with larger magnitude values, biasing our model towards those features.

To avoid that, it is always advisable to standardize the attributes before applying the KNN algorithm. For a better understanding, please look at this blog to visualize how scaling can affect distance calculation.

How scaling features affect the KNN prediction?
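
As a rough illustration with made-up two-feature samples (one feature in the thousands, one between 0 and 1), the unscaled Euclidean distance is driven almost entirely by the large-magnitude feature, while standardizing first removes that dominance:

import numpy as np
from sklearn.preprocessing import StandardScaler

X = np.array([[1500.0, 0.2],   # hypothetical samples: feature 1 is in the thousands,
              [1520.0, 0.9],   # feature 2 lies between 0 and 1
              [3000.0, 0.3]])

X_std = StandardScaler().fit_transform(X)  # standardize each column to mean 0, std 1

d_raw = np.linalg.norm(X[0] - X[1])          # dominated by the first (large-magnitude) feature
d_std = np.linalg.norm(X_std[0] - X_std[1])  # no longer dominated by the raw magnitude of feature 1
print(d_raw, d_std)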

K-Nearest Neighbor for Regression problems

So far, we have discussed how we could use the KNN algorithm to solve the classification tasks, but this machine learning algorithm can also solve regression problems. We need to tweak the approach slightly. Instead of counting the K nearest neighbor class labels, what if we average the output labels for K neighbors?

Yes! It will act as the regression model in such a scenario.

How does KNN work as regression algorithm?

For example, let’s say we have a test sample X for which we want to predict the continuous variable Y. We have fixed K = 3, and the three nearest neighbors are X1 → Y1, X2 → Y2, and X3 → Y3.

KNN is a supervised learning algorithm; hence, we always have the corresponding labels for the input variables while training. At prediction time, we can average the three neighbors' labels to obtain the label of the test data, which is nothing but the predicted value for that test sample. Here, it will be Y = (Y1 + Y2 + Y3)/3. This averaging can be replaced with the median, the mode, or a custom aggregation technique.
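
In scikit-learn, this idea is available through KNeighborsRegressor, which averages the neighbors' targets by default. A minimal sketch with made-up one-dimensional data:

import numpy as np
from sklearn.neighbors import KNeighborsRegressor

# made-up training data: y is roughly 2 * x with a little noise
X_train = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]])
y_train = np.array([2.1, 3.9, 6.2, 8.1, 9.8])

reg = KNeighborsRegressor(n_neighbors=3)  # K = 3, uniform averaging of neighbor targets
reg.fit(X_train, y_train)

print(reg.predict([[3.2]]))  # average of the 3 nearest targets: (3.9 + 6.2 + 8.1) / 3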

Advantages of the KNN algorithm

The KNN algorithm is well-known for its simplicity and has several key strengths, including:

  • Zero training time: KNN has essentially no training phase; it only stores the training data, so its training cost is negligible compared to other machine learning algorithms.
  • Sample efficiency: KNN does not require a large training sample size.
  • Explainability: It is easy to understand the reasoning behind KNN’s predictions at each step.
  • Ease of adding and removing data: With KNN, you can easily update the memory and perform inference without retraining the model, as other machine learning algorithms require.
  • Insensitivity to class imbalance (for small K): With a suitably small K, each prediction depends only on the local neighborhood of the query, so a global imbalance between classes affects KNN less than algorithms that fit a single model to the whole dataset. With a large K, however, the majority class starts to dominate.

Disadvantages of the KNN algorithm

No doubt, KNN is explainable, but the algorithm has some limitations. It is often not the first choice among Machine Learning experts, for the following reasons:

  • Needs a lot of storage: K-Nearest Neighbor stores all the training data in its memory and performs inference based on it. This makes the algorithm impractical to deploy on edge devices. Think of a case where the dataset contains millions of samples.
  • Predictions are slow: The time complexity of a single KNN prediction is O(dN), where d is the number of features and N is the total number of training samples. The more data, the longer each prediction takes.
  • Irrelevant features can fool the nearest-neighbor search, because every feature contributes equally to the distance.

KNN Implementation in Python using sklearn

Let’s implement the KNN algorithm in Python to solve a classification problem.

Step 1: Import the necessary dataset and required libraries.

We will use the famous Iris dataset here to perform the classification. It is imported from the Scikit-learn datasets using load_iris. The other libraries are imported for training, preprocessing, and evaluation.

import matplotlib.pyplot as plt                       # plotting
import numpy as np                                    # arrays
import pandas as pd                                   # dataframes

from sklearn import datasets                          # load the data
from sklearn.model_selection import train_test_split  # split the data
from sklearn.preprocessing import StandardScaler      # scale the data
from sklearn.neighbors import KNeighborsClassifier    # the algorithm
from sklearn.metrics import accuracy_score            # grade the results

iris = datasets.load_iris()  # read the data

X = iris.data    # select the features to use
y = iris.target  # select the classes

iris_dataframe = pd.DataFrame(data=np.c_[iris['data'], iris['target']],
                              columns=iris['feature_names'] + ['target'])

grr = pd.plotting.scatter_matrix(iris_dataframe,
                                 c=iris["target"],
                                 figsize=(15, 15),
                                 marker='o',
                                 s=60,
                                 alpha=.8)
plt.show()

Step 2: Understanding the data

This dataset contains four variables: sepal length, sepal width, petal length, and petal width, which describe iris plants of three types: Setosa, Versicolour, and Virginica. There are 150 observations in the dataset, each labeled with the actual type of flower.
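
A quick way to confirm these numbers after loading the data (continuing from the iris object and the imports above):

print(iris.data.shape)           # (150, 4): 150 observations, 4 features
print(iris.feature_names)        # sepal/petal length and width (in cm)
print(iris.target_names)         # ['setosa' 'versicolor' 'virginica']
print(np.bincount(iris.target))  # 50 samples per class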

Step 3: Visualization

The features of this dataset span four dimensions, so a pairwise scatter-plot matrix is a convenient way to visualize how the classes separate. This plot shows the relationship between every pair of variables in the dataset. In the image below, violet represents the Setosa class, green represents the Versicolour class, and yellow represents the Virginica class.

How to draw the pair plot of IRIS dataset?

Step 4: Data Preprocessing

The entire dataset is first split into training and testing parts using the train_test_split function of Scikit-learn. The StandardScaler() class is then used to standardize the data (column-wise). When fit to a dataset, this scaler transforms every feature to mean μ = 0 and standard deviation σ = 1.

For a dataset having N samples and m features, every value is then updated as:

X(i, j) = (X(i, j) - μ(i)) / σ(i),  for every feature i in 1..m and every sample j in 1..N,

where μ(i) and σ(i) are the mean and standard deviation of feature i computed on the training data.

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)


sc = StandardScaler()                 # create the standard scaler
sc.fit(X_train)                       # fit to the training data
X_train_std = sc.transform(X_train)   # transform the training data
X_test_std = sc.transform(X_test)     # same transformation on the test data

Step 5: KNN Model Fitting and Performance Evaluation

We will fit the KNN model for different K values ranging from 1 to the number of samples in the testing dataset. The metric 'minkowski' with p = 2 corresponds to the Euclidean distance. The model is fitted for each value of K and then used to predict the outputs for the test set.

accuracyTest = {}
accuracyTrain = {}

for k in range(len(y_test)):

    knn = KNeighborsClassifier(n_neighbors=k + 1, p=2, metric='minkowski')
    knn.fit(X_train_std, y_train)
    y_pred = knn.predict(X_test_std)
    y_train_pred = knn.predict(X_train_std)

    if (k + 1) % 10 == 0:
        print(10 * '-')
        print("For k = %s" % (k + 1))
        print('Number in test ', len(y_test))
        print('Misclassified samples: %d' % (y_test != y_pred).sum())

    accuracyTrain[k + 1] = accuracy_score(y_train, y_train_pred)
    accuracyTest[k + 1] = accuracy_score(y_test, y_pred)

for accuracy in [accuracyTrain, accuracyTest]:
    lists = sorted(accuracy.items())  # sorted by key, returns a list of (k, accuracy) tuples
    ks, accs = zip(*lists)            # unpack the pairs into two tuples
    plt.plot(ks, accs)
    plt.show()

training and testing accuracy of the KNN algorithm on Iris dataset

If we prioritize the testing accuracy, values of K above roughly 18 decrease it sharply, so the optimal number of neighbors lies somewhere around 15 to 18.
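
Continuing from the loop above, one simple illustrative way to read off the best K is to pick the entry of the recorded test-accuracy dictionary with the highest value:

best_k = max(accuracyTest, key=accuracyTest.get)  # K with the highest test accuracy
print("Best K:", best_k, "with test accuracy:", accuracyTest[best_k])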

Decision Boundaries for KNN

The two datasets (training and testing) are combined to show the effect of varying K in the KNN algorithm. Only two features (petal length and petal width) are used so the decision regions can be visualized. The values of K considered are [1, 25, 50, 75, 100, 112], where 112 is the training sample size. At K = 112, the decision boundary simply predicts the majority class of the entire training set everywhere, which is shown in red.

X = iris.data[:, [2, 3]]  # select petal length and petal width only
y = iris.target           # select the classes

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)


sc = StandardScaler()                 # create the standard scaler
sc.fit(X_train)                       # fit to the training data
X_train_std = sc.transform(X_train)   # transform the training data
X_test_std = sc.transform(X_test)     # same transformation on the test data

X_combined_std = np.vstack((X_train_std, X_test_std))
y_combined = np.hstack((y_train, y_test))

print('Number in combined ', len(y_combined))

# refit KNN on the two selected features before checking the combined data
# (K = 15 is one reasonable choice based on the accuracy analysis above)
knn = KNeighborsClassifier(n_neighbors=15, p=2, metric='minkowski')
knn.fit(X_train_std, y_train)

# check results on combined data
y_combined_pred = knn.predict(X_combined_std)

print('Misclassified combined samples: %d' % (y_combined != y_combined_pred).sum())
print('Combined Accuracy: %.2f' % accuracy_score(y_combined, y_combined_pred))
# visualize the results 

for k in [1, 25, 50, 75, 100, len(X_train)]:

    knn = KNeighborsClassifier(n_neighbors=k, p=2, metric='minkowski')
    knn.fit(X_train_std, y_train)

    # plot_decision_regions is a plotting helper, not part of scikit-learn (see the sketch below)
    plot_decision_regions(X=X_combined_std, y=y_combined, classifier=knn,
                          test_idx=range(112, 150))  # the last 38 entries are the test samples

    plt.xlabel('petal length [standardized]')
    plt.ylabel('petal width [standardized]')
    plt.title('k=%s' % k)
    plt.legend(loc='upper left')
    plt.show()
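
Note that plot_decision_regions is not a scikit-learn function. The snippet above assumes a small plotting helper like the sketch below (the names and styling here are illustrative; it reuses the np and plt imports from earlier):

from matplotlib.colors import ListedColormap

def plot_decision_regions(X, y, classifier, test_idx=None, resolution=0.02):
    # markers and colors for up to three classes (enough for the Iris dataset)
    markers = ('s', 'x', 'o')
    colors = ('red', 'blue', 'lightgreen')
    cmap = ListedColormap(colors[:len(np.unique(y))])

    # evaluate the classifier on a grid covering the two-feature space
    x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),
                           np.arange(x2_min, x2_max, resolution))
    Z = classifier.predict(np.array([xx1.ravel(), xx2.ravel()]).T).reshape(xx1.shape)
    plt.contourf(xx1, xx2, Z, alpha=0.3, cmap=cmap)

    # overlay the actual samples, class by class
    for idx, cl in enumerate(np.unique(y)):
        plt.scatter(X[y == cl, 0], X[y == cl, 1], alpha=0.8,
                    c=colors[idx], marker=markers[idx], label=cl)

    # highlight the test samples, if their indices are given
    if test_idx is not None:
        plt.scatter(X[test_idx, 0], X[test_idx, 1], facecolors='none',
                    edgecolors='black', s=60, label='test set')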

Decision boundary formation for KNN algorithm with respect to K

Industrial Applications of KNN

Although it has certain limitations, this algorithm is widely used in industry because of its simplicity. Some of these applications are:

  • Email spam filtering: For detecting the trivial and fixed types of spam emails, KNN can perform well. The implementation steps of this algorithm can be found here.
  • Wine Quality prediction: Wine quality prediction is a regression task and can be solved using the KNN algorithm. The implementation can be found here.
  • Recommendation system: KNN is used to build recommendation engines that recommend some products/movies/songs to the users based on their likings or dislikes.

Bonus Section: Voronoi cell and Voronoi diagrams

Other ML algorithms like linear regression, logistic regression, and SVMs try to fit a mapping function from input to output; this mapping function is also known as the hypothesis function. KNN is different: it does not form an explicit hypothesis function but instead partitions the training space. For a dataset in R², this partition consists of polygonal cells formed around the training samples. Let’s first understand what a Voronoi cell is.

What is Voronoi Cell?

Suppose the training set is T and its elements are x. Then the Voronoi cell of a sample xi is a polytope (a geometric shape with “flat” sides) consisting of all points that are closer to xi than to any other point in T.

What is voronoi diagram and its prototype in KNN algorithm?

If we observe the image above, initially every cell contains a single training sample, which corresponds to K = 1. As we increase the value of K, neighboring cells merge and form a new polytope containing K samples. The Voronoi cells cover the entire training space of T, and combined together they form the Voronoi diagram.
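
To see such a diagram for your own two-dimensional data, SciPy can compute and draw it directly. A small illustrative sketch with randomly generated points standing in for training samples:

import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import Voronoi, voronoi_plot_2d

rng = np.random.default_rng(0)
points = rng.random((15, 2))  # 15 random 2-D "training samples"

vor = Voronoi(points)         # each cell is the region closest to one sample
voronoi_plot_2d(vor)
plt.scatter(points[:, 0], points[:, 1], c='red')
plt.show()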

Possible Interview Questions

The KNN algorithm is known for its explainability, making it a popular interview topic. Some potential questions that may be asked about KNN include:

  • How does KNN differ from other machine learning algorithms?
  • Can changing the distance metric affect the classification accuracy of KNN?
  • Is the KNN algorithm sensitive to data normalization?
  • Why is KNN classified as a non-parametric algorithm?
  • What are the main drawbacks of the KNN algorithm?

Conclusion

In this article, we covered the step-by-step working of the K-Nearest Neighbor (KNN) algorithm, one of the first machine learning algorithms ever developed. We discussed how KNN defines instances as neighbors and how the value of K impacts the predictions. We also emphasized the importance of feature scaling and showed how KNN can be used for regression tasks. Finally, we implemented the algorithm on the well-known Iris dataset in Python. We hope you found this article informative and enjoyable.

References

  1. Scikit-learn: Machine Learning in Python, Pedregosa, et al., JMLR 12, pp. 2825–2830, 2011
  2. Mitchell, T. M. (1997). Machine learning., McGraw Hill series in computer science New York: McGraw-Hill.
  3. UCI Machine Learning Repository: Iris Data Set.
  4. J. D. Hunter, “Matplotlib: A 2D Graphics Environment”, Computing in Science & Engineering, vol. 9, no. 3, pp. 90–95, 2007.

Enjoy Learning! Enjoy Algorithms!
