K-Nearest Neighbors explained
I often see students and developers who are trying to get into Machine Learning confused by the complicated topics they face at the very beginning of their journey. I want to give a deep yet understandable introduction to an algorithm that is so simple and elegant that you will like it. If you are a Machine Learning engineer but have only a limited understanding of this one, it could be useful reading for you too.
I worked as a software developer for years, and everyone around me was talking about this brand new data science and machine learning thing (only later did I understand that there is nothing new on this planet), so I decided to enroll in a master's program at the university to get acquainted with it.
Our first module was a general introductory course in Data Science, and I remember sitting there trying to understand what was going on. I knew a little about the field itself, its history and potential, but I found it hard to understand how it actually works.
So I was wondering: how could a machine tell whether there is a kitten or a dog in a picture? My instruments as a programmer had been variables, functions, conditions, loops and abstractions for years. I was very comfortable with them, but how could I apply that knowledge to this kind of problem? All the scenarios were pre-defined by me as a programmer; I knew all the outputs. How do I tell a kitten from a dog?
Well, I've seen a lot of them in my life. I probably cannot tell one fish from another, but I'm pretty good at classifying domestic pets. I've seen hundreds of cats and dogs, and they look different to me - different eyes, different ears, paws and tails, different sounds. Those things are called features. The machine should rely on some features too. One dog is similar to another - they have features in common. Spam letters have certain words in common.
Ok, closer to the algorithm. The intuition is that, if plotted properly, elements of one class will form clusters on the plot.
Hence, if an element is of unknown class - say, from a testing dataset - we can say with some level of certainty that it will have the same class as the majority of its __k nearest neighbors__ on the plot. This is KNN in a nutshell. Nothing more.
If 4 people around me during lunch are Data Science students and one is a Psychology student, then I, most probably, am a Data Science student too.
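To make the voting step concrete, here is a minimal sketch in Python (the function name, the toy points and the choice of k are mine, purely for illustration):

```python
import math
from collections import Counter

def knn_classify(train_points, train_labels, query, k=3):
    """Classify `query` as the majority class among its k nearest training points."""
    # rank every known point by its Euclidean distance to the query
    ranked = sorted(zip(train_points, train_labels),
                    key=lambda pair: math.dist(pair[0], query))
    # take the labels of the k closest points and let them vote
    top_k_labels = [label for _, label in ranked[:k]]
    return Counter(top_k_labels).most_common(1)[0][0]

# toy data: two features per point, two classes
points = [(1.0, 1.2), (0.8, 1.0), (5.0, 5.5), (5.2, 4.9), (4.8, 5.1)]
labels = ["cat", "cat", "dog", "dog", "dog"]
print(knn_classify(points, labels, query=(5.0, 5.0), k=3))  # -> dog
```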
Easy so far, so let's make things more complicated: KNN can be used for regression problems too. Getting back to the kittens: this is not about predicting the class an animal falls into, but rather about predicting, say, the animal's mass from the depth and size of its trace in the snow. In the case of regression, the estimate of the element's value is the mean of its k nearest neighbors' values.
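The regression variant only swaps the vote for an average. A sketch under the same assumptions (made-up footprint data, hypothetical function name):

```python
import math

def knn_regress(train_points, train_values, query, k=3):
    """Predict a numeric value as the mean over the k nearest training points."""
    ranked = sorted(zip(train_points, train_values),
                    key=lambda pair: math.dist(pair[0], query))
    nearest = [value for _, value in ranked[:k]]
    return sum(nearest) / len(nearest)

# made-up data: (trace depth in cm, trace size in cm) -> animal mass in kg
traces = [(1.0, 6.0), (1.2, 7.0), (3.5, 12.0), (4.0, 13.0)]
masses = [4.0, 5.0, 25.0, 30.0]
print(knn_regress(traces, masses, query=(3.8, 12.5), k=2))  # -> 27.5
```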
What are those "neighbors"? Easy! They are points from the training dataset with a known target variable value!
You feel it right - the KNN algorithm does not require training. It performs all the calculations during the classification or regression process.
Some of you may already sense something wrong with this "majority voting" thing. If there are many more entries of one class in the population than of another, the voting can indeed be compromised: the denser class can outvote the others and make KNN misclassify elements - the same story as the misclassification of Pluto as a planet back when we had very limited knowledge of the Kuiper belt's existence.
Also, you may be wondering about that word distance. What kind of distance are we talking about?
You are right - there are many ways to calculate distance. For simplicity, we will talk about plain old Euclidean distance (L2 distance) in this post, but you need to know (this was mind-blowing for me) that other methods exist (refer to this article and Taxicab Geometry). The way you choose to calculate distance can change the set of nearest neighbors selected for a given data point. More details can be found in those slides - the image below is taken from them.
By definition, Euclidean distance between two points (in Euclidean space) is the length of the straight line that connects those two points.
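To see the difference, here is a tiny sketch comparing Euclidean (L2) and taxicab/Manhattan (L1) distances in Python; the function names and the sample points are mine, purely for illustration:

```python
import math

def euclidean(p, q):
    # straight-line (L2) distance
    return math.sqrt(sum((qi - pi) ** 2 for pi, qi in zip(p, q)))

def manhattan(p, q):
    # taxicab (L1) distance: sum of absolute per-coordinate differences
    return sum(abs(qi - pi) for pi, qi in zip(p, q))

a, b = (0.0, 0.0), (3.0, 4.0)
print(euclidean(a, b))  # 5.0
print(manhattan(a, b))  # 7.0 - a different metric can rank neighbors differently
```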
Let's move straight to practice to make this clearer.
Consider the Iris dataset, which consists of 50 samples from each of three species of Iris (Iris setosa, Iris virginica, and Iris versicolor), with measurements of the length and the width of the sepals and petals.
The goal is to classify an Iris flower according to those measurements. Here is a sample of this dataset:
5.5,3.5,1.3,0.2,Iris-setosa
7.7,2.6,6.9,2.3,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
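If you want to play with this dataset yourself, here is a minimal sketch of loading it with scikit-learn (assuming scikit-learn is installed; its bundled copy carries the same four measurements as the CSV sample above):

```python
from sklearn.datasets import load_iris

iris = load_iris()
X, y = iris.data, iris.target        # 150 rows, 4 measurements each
print(X.shape)                       # (150, 4)
print(iris.feature_names)            # sepal/petal length and width, in cm
print(iris.target_names)             # ['setosa' 'versicolor' 'virginica']
print(X[0], iris.target_names[y[0]]) # the first sample and its species
```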
It's always smart to look at a visualization of the data first, to check whether the intuition that there is a significant difference between the flowers of the three classes holds. Look at these plots from Wikipedia's page about the Iris dataset.
The flowers form three distinctive groups - in other words, the colored dots don't look like a chaotic mess. That's a good sign that we will succeed.
We are dealing with a 4-dimensional Euclidean space. Let us denote a flower from the known dataset as $p$ and the flower being classified as $q$. Their positions in this Euclidean space are described by the coordinates $(p_1, p_2, p_3, p_4)$ and $(q_1, q_2, q_3, q_4)$ (basically, our features).
Using the Euclidean distance formula, we can calculate the distance between those points in space (I promise this is the last math formula here):

$$d(p, q) = \sqrt{\sum_{i=1}^{n} (q_i - p_i)^2}$$
In our case we have 4 features ($n = 4$ in the formula above). Removing the summation operator, the equation in terms of the coordinates described above becomes:

$$d(p, q) = \sqrt{(q_1 - p_1)^2 + (q_2 - p_2)^2 + (q_3 - p_3)^2 + (q_4 - p_4)^2}$$
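For instance, plugging the setosa row and the first virginica row from the sample above into this expanded formula (a hand-rolled sketch, not a library call):

```python
import math

# the setosa row and the first virginica row from the sample above (features only)
setosa    = (5.5, 3.5, 1.3, 0.2)
virginica = (7.7, 2.6, 6.9, 2.3)

d = math.sqrt(sum((q - p) ** 2 for p, q in zip(setosa, virginica)))
print(round(d, 2))  # about 6.44 - these two flowers lie far apart in feature space
```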
Ok, the last unanswered part is: how do we choose that magical k? The honest answer is: there is no universal recipe. Smaller k values are, in most cases, highly affected by noise in the dataset - this is called a model with high variance, or simply an overfitted model. Bigger k values lead to a bigger bias of the model, meaning it pays less and less attention to the structure of the training data. The general approach is to use $k = \sqrt{n}$, where $n$ is the size of the training dataset. It's also useful to keep this number odd - to ensure a clear majority during voting for classification, at least with two classes (note that this is not actually required for regression problems - at least under the general approach described above). k in KNN is a hyperparameter, and you need to choose it yourself as the designer of the system. You can use Random Search, Cross-Validation or some of the fancier hyperparameter optimization techniques, but they are subjects for other topics.
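As an illustration of one such approach, here is a sketch of scoring a few odd values of k with 5-fold cross-validation on the Iris data (assuming scikit-learn; the range of k values is my arbitrary choice):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

# score a handful of odd k values with 5-fold cross-validation
for k in range(1, 16, 2):
    scores = cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y, cv=5)
    print(k, round(scores.mean(), 3))
# pick the k with the best (and most stable) mean score
```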
Depending on the random split (not the actual split size, but the random ordering of the samples), this simple algorithm yields 98% to 100% accuracy on the testing set.
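Here is roughly what that experiment looks like with scikit-learn's built-in KNN classifier (a sketch; the test size and random_state are my arbitrary choices):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)    # the split drives the exact score

model = KNeighborsClassifier(n_neighbors=5)
model.fit(X_train, y_train)                   # "fit" only stores the training data
print(model.score(X_test, y_test))            # usually somewhere around 0.97-1.00
```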
Ok, that was cool, but what would happen if we used a bigger dataset? The answer is sad: we would need to calculate every distance again and again, and it would be much slower. This is the strongest limitation of the KNN algorithm - it is simply not efficient on big datasets. It does not "learn" anything from the data. From this problem comes another - it does not generalize well. Also, the choice of distance measure and of k can affect accuracy significantly. And do not forget that noisy data will also make it less accurate, or simply non-functional (as with tons of other, more sophisticated algorithms, though).
What can we do if the border between two classes is messy, with elements of different classes mixed together? Well, one standard way is to apply weights to the "votes", so that closer points become more significant.
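In scikit-learn this corresponds to the weights="distance" option of KNeighborsClassifier; a short sketch:

```python
from sklearn.neighbors import KNeighborsClassifier

# closer neighbors get larger weights (inverse of their distance),
# so a distant crowd is less likely to outvote a few nearby points
weighted_knn = KNeighborsClassifier(n_neighbors=5, weights="distance")
```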
Still, this algorithm has a lot of practical uses - from the oldest spam filters to transaction-scrutinizing software, where KNN is used to analyze transaction register data to spot and flag suspicious activity.
Thank you.