k-means clustering simplified
Lion Ralfs
I've come across the term "k-means clustering" multiple times now, always either reading past it or brushing it off as some complicated machine learning technique.
Well, last week I read about someone using it for a problem similar to the one I was having, so it was time to finally look up what all of this "k-means clustering" was about.
When confronted with an unfamiliar concept, I've recently been building a habit of learning about it by implementing a simple, naive version of it myself. So this is my attempt at explaining the concept, along with a very simple implementation in everyone's favorite programming language, JavaScript.
What
Imagine you have a 2D grid with some points scattered across it:
In essence, clustering is just a fancy word for grouping things together, and the "means" in k-means describes the arithmetic mean, also known as the average. When we do k-means clustering, we group our data into k groups (=clusters) by using the average somewhere in the process.
The first thing we need to establish is the term vector. A vector is a list of numbers describing a data point. In our example above, each vector has two elements, representing the x and y coordinates, i.e. $(x, y)$. The dimension of a vector is the number of elements it has (in our case, all of our vectors are 2-dimensional). Our data points are also known as observations.
class Vector {
  /**
   * @param {number} dimension
   * @param {Array<number>} values
   */
  constructor(dimension, values) {
    this.dimension = dimension;
    this.values = values;
  }
}
Secondly, a cluster is a data structure containing vectors. A cluster also has a centroid, which is also just a complicated word for center.
class Cluster {
  /**
   * @param {number} dimension
   * @param {Array<Vector>} vectors
   */
  constructor(dimension, vectors) {
    this.dimension = dimension;
    this.vectors = vectors;
    // the center is assigned later, when the clusters are initialized
    /** @type {Vector | undefined} */
    this.center = undefined;
  }
}
Let's go over how the center of a cluster is calculated. In short, the center is the mean of all the data points in the cluster. The mean of vectors is calculated component-wise: for a cluster with three elements $a$, $b$ and $c$, the center is $\left(\frac{a_x + b_x + c_x}{3}, \frac{a_y + b_y + c_y}{3}\right)$. Let's add a method to the Cluster class to calculate the center:
class Cluster {
  // ... constructor here ...

  // component-wise mean of all vectors in this cluster
  getCenter() {
    // an empty cluster has no mean: keep the previous center instead of dividing by 0
    if (this.vectors.length === 0) {
      return this.center;
    }
    let result = new Vector(this.dimension, new Array(this.dimension).fill(0));
    for (let observation of this.vectors) {
      for (let dimension = 0; dimension < this.dimension; dimension++) {
        result.values[dimension] += observation.values[dimension];
      }
    }
    result.values = result.values.map((val) => val / this.vectors.length);
    return result;
  }
}
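To see the component-wise mean in action, here is a small standalone sketch that works on plain arrays instead of the `Vector` class. The helper name `componentMean` and the three points are made up for illustration.

```javascript
/**
 * Hypothetical helper: component-wise mean of equally sized number arrays.
 * Assumes `points` contains at least one point.
 * @param {Array<Array<number>>} points
 * @returns {Array<number>}
 */
function componentMean(points) {
  const dimension = points[0].length;
  const sums = new Array(dimension).fill(0);
  for (const point of points) {
    for (let d = 0; d < dimension; d++) {
      sums[d] += point[d];
    }
  }
  // divide every component sum by the number of points
  return sums.map((sum) => sum / points.length);
}

// x: (1 + 2 + 6) / 3 = 3, y: (0 + 5 + 1) / 3 = 2
console.log(componentMean([[1, 0], [2, 5], [6, 1]])); // [ 3, 2 ]
```

Each coordinate is averaged on its own, which is what the nested loop in the method above does for a cluster's vectors.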
How
As mentioned earlier, when we do clustering, we assign each vector to a cluster. Let's say we want to cluster our dataset into k = 3 clusters. We create 3 clusters, as such:
- cluster 1
- cluster 2
- cluster 3
Now we could randomly assign each observation (=vector) to a cluster:
Obviously, this is not ideal, as we want something like this:
The idea is to repeatedly assign each observation to the cluster with the nearest center. The algorithm achieves this by alternating between the following two steps: assign and update.
Assign
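In the assignment step, each observation is assigned to the cluster whose center is closest to it. As the center is also just a vector, we can calculate the distance between an observation and a cluster center using the Euclidean distance, which is the length of the straight line between the two points.
That line, together with the horizontal and vertical differences between the two points, forms a right triangle. As it contains a 90° angle, we can calculate the length of its hypotenuse using the Pythagorean theorem:
$$d(p, q) = \sqrt{(p_x - q_x)^2 + (p_y - q_y)^2}$$
For vectors with more dimensions, we simply add one squared difference per component. The final square root is not calculated in the code below, but that doesn't change how we find the closest cluster: the square root never changes which of two non-negative numbers is larger, so comparing squared distances picks the same nearest center as comparing the actual distances: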
/**
 * Squared Euclidean distance (the square root is skipped on purpose).
 * @param {Vector} vector1
 * @param {Vector} vector2
 * @returns {number}
 */
function distanceBetween(vector1, vector2) {
  let result = 0;
  // assuming both vectors have the same size
  let vectorSize = vector1.dimension;
  let vector1Values = vector1.values;
  let vector2Values = vector2.values;
  for (let dimension = 0; dimension < vectorSize; dimension++) {
    result += Math.pow(vector1Values[dimension] - vector2Values[dimension], 2);
  }
  return result;
}
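To check that leaving out the square root doesn't change which center ends up closest, here is a small standalone sketch on plain arrays. `squaredDistance`, `euclideanDistance` and `nearestCenterIndex` are hypothetical names chosen for illustration, not part of the implementation above.

```javascript
// Squared Euclidean distance (no square root), like distanceBetween above.
function squaredDistance(a, b) {
  let sum = 0;
  for (let i = 0; i < a.length; i++) {
    sum += (a[i] - b[i]) ** 2;
  }
  return sum;
}

// The actual Euclidean distance: the square root of the squared distance.
function euclideanDistance(a, b) {
  return Math.sqrt(squaredDistance(a, b));
}

// Index of the center closest to `point`, measured with `distanceFn`.
function nearestCenterIndex(point, centers, distanceFn) {
  let bestIndex = -1;
  let bestDistance = Infinity;
  centers.forEach((center, index) => {
    const distance = distanceFn(point, center);
    if (distance < bestDistance) {
      bestDistance = distance;
      bestIndex = index;
    }
  });
  return bestIndex;
}

const centers = [[0, 0], [3, 4], [6, 1]];
const point = [4, 3];
// squared distances: 25, 2, 8 -> actual distances: 5, ~1.41, ~2.83
console.log(nearestCenterIndex(point, centers, squaredDistance)); // 1
console.log(nearestCenterIndex(point, centers, euclideanDistance)); // 1
```

Both variants agree because the square root preserves the order of non-negative numbers; skipping it just saves one `Math.sqrt` call per comparison.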
To summarize, when we assign an observation $o$ to a cluster, we follow these steps:
- for each cluster $c$:
  - measure the distance between $o$ and $c$'s center
- assign $o$ to the closest cluster
In our JavaScript code, this would be:
for (let observation of observations) {
  let bestDistance = Infinity;
  let bestCluster = undefined;
  for (let cluster of clusters) {
    let distance = distanceBetween(cluster.center, observation);
    if (distance < bestDistance) {
      bestDistance = distance;
      bestCluster = cluster;
    }
  }
  bestCluster.vectors.push(observation);
}
Update
Now that we've assigned each observation to a cluster, we need to recalculate each cluster's center so that it reflects all of the observations the cluster now contains. As we already have a getCenter method on the Cluster class, there's not much work to do:
clusters.forEach((cluster) => {
  let newCenter = cluster.getCenter();
  cluster.center = newCenter;
});
Repeat
All that is left to do is put the previous two steps in a loop. Take this pseudocode for example:
while (!done) {
  assign();
  update();
}
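Each round can only keep or lower the total squared distance between the observations and their cluster centers, so the results improve until the assignments stop changing (although where they settle depends on the starting centers). But how do we determine when we're done? We could: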
- do a fixed number of iterations
- measure how much the clusters' centers shift in the update step and stop iterating as soon as none of them moves further than a certain threshold
- track the assignments of observations to clusters, and if they don't change between iterations, we're done
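The third option can be sketched as a standalone variant on plain arrays rather than the `Vector`/`Cluster` classes. The name `kMeansUntilStable`, the explicit `initialCenters` argument and the `maxIterations` safety cap are assumptions made for this sketch.

```javascript
// Squared Euclidean distance between two equally sized arrays.
function squaredDist(a, b) {
  return a.reduce((sum, value, i) => sum + (value - b[i]) ** 2, 0);
}

// Hypothetical helper: k-means that stops as soon as no observation changes
// its cluster between two iterations (or after maxIterations rounds).
function kMeansUntilStable(points, initialCenters, maxIterations = 100) {
  let centers = initialCenters.map((c) => [...c]); // copy, don't mutate input
  let assignments = new Array(points.length).fill(-1);
  let iterations = 0;

  while (iterations < maxIterations) {
    iterations++;
    // assign: index of the nearest center for every point
    const next = points.map((p) => {
      let best = 0;
      let bestDistance = squaredDist(p, centers[0]);
      for (let c = 1; c < centers.length; c++) {
        const distance = squaredDist(p, centers[c]);
        if (distance < bestDistance) {
          best = c;
          bestDistance = distance;
        }
      }
      return best;
    });

    // stop if no assignment changed since the previous iteration
    const changed = next.some((clusterIndex, i) => clusterIndex !== assignments[i]);
    assignments = next;
    if (!changed) break;

    // update: move each center to the mean of its points (keep it if empty)
    centers = centers.map((center, c) => {
      const members = points.filter((_, i) => assignments[i] === c);
      if (members.length === 0) return center;
      return center.map((_, d) => members.reduce((s, m) => s + m[d], 0) / members.length);
    });
  }
  return { centers, assignments, iterations };
}

// the observations from the example further down, three of them as starting centers
const data = [
  [0, 2], [0, 5], [1, 0], [1, 6], [2, 0], [2, 5], [2, 6],
  [3, 1], [3, 5], [3, 6], [4, 3], [6, 3], [6, 2],
];
const result = kMeansUntilStable(data, [[0, 2], [2, 6], [6, 3]]);
console.log(result.assignments.join(", ")); // 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 2, 2, 2
console.log(result.iterations); // 2
```

The cap keeps the loop bounded even if the assignments were to keep changing; with this data and these starting centers, the assignments are already stable in the second round.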
For simplicity, let's stick with the first option and run a fixed number of 10 iterations. Our function to do the clustering might look like this:
/**
 * Clusters `observations` of dimension `d` into `k` clusters
 * @param {number} k
 * @param {number} d
 * @param {Array<Vector>} observations
 * @returns {Array<Cluster>}
 */
function cluster(k, d, observations) {
  if (k > observations.length) {
    throw new Error('make sure k ≤ observations.length');
  }
  // generate k empty clusters
  let clusters = Array.from({ length: k }, () => new Cluster(d, []));
  // TODO: initialize the clusters
  let maxIterations = 10;
  while (maxIterations--) {
    // first, reset all observations for each cluster
    clusters.forEach((cluster) => {
      cluster.vectors = [];
    });
    // step 1 (assignment):
    // every iteration, assign each observation to the cluster
    // with the smallest distance to the center
    for (let observation of observations) {
      let bestDistance = Infinity;
      let bestCluster = undefined;
      for (let cluster of clusters) {
        let distance = distanceBetween(cluster.center, observation);
        if (distance < bestDistance) {
          bestDistance = distance;
          bestCluster = cluster;
        }
      }
      bestCluster.vectors.push(observation);
    }
    // step 2 (update):
    // recalculate center
    clusters.forEach((cluster) => {
      let newCenter = cluster.getCenter();
      cluster.center = newCenter;
    });
  }
  return clusters;
}
One thing left to do is to initialize the clusters. Two common options are:
- The Forgy method: for each cluster, pick a random observation as its initial center
- The random partition method: assign each observation to a random cluster and use the resulting cluster means as the starting centers
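For comparison, the random partition method could be sketched like this on plain arrays. `randomPartitionCenters` is a hypothetical name, the optional `random` parameter (defaulting to `Math.random`) is an assumption added so the randomness can be swapped out, and the fallback for clusters that receive no points is a choice of this sketch rather than part of the method.

```javascript
// Hypothetical sketch of the random partition method: put every point into a
// random cluster, then use each cluster's mean as its initial center.
function randomPartitionCenters(k, points, random = Math.random) {
  const dimension = points[0].length;
  const sums = Array.from({ length: k }, () => new Array(dimension).fill(0));
  const counts = new Array(k).fill(0);

  for (const point of points) {
    const c = Math.floor(random() * k); // random cluster index in [0, k)
    counts[c]++;
    for (let d = 0; d < dimension; d++) {
      sums[c][d] += point[d];
    }
  }

  // a cluster that received no points has no mean; this sketch falls back
  // to a copy of a randomly chosen point for it
  return sums.map((sum, c) =>
    counts[c] > 0
      ? sum.map((s) => s / counts[c])
      : [...points[Math.floor(random() * points.length)]]
  );
}
```

Because every starting center is an average over a random subset of the data, these centers tend to start close to the overall mean, whereas the Forgy method spreads them out over actual observations.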
Let's use the Forgy method for initialization, because it doesn't require any changes to the rest of our implementation:
/**
 * The Forgy method randomly chooses k observations from the dataset
 * and uses these as the initial means.
 * @param {number} k
 * @param {Array<Vector>} observations
 * @returns {Array<Vector>}
 */
function forgy(k, observations) {
  // collect k distinct random indices (a Set ignores duplicates)
  let indices = new Set();
  while (indices.size < k) {
    indices.add(Math.floor(Math.random() * observations.length));
  }
  return [...indices].map((index) => observations[index]);
}
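The `while` loop in `forgy` keeps drawing until it has found k distinct indices, which can take many retries when k is close to the number of observations. A partial Fisher-Yates shuffle is an alternative way to draw k distinct observations in exactly k steps. This is a swapped-in technique, not part of the implementation above; `pickDistinct` and its optional `random` parameter are hypothetical.

```javascript
// Hypothetical alternative to the retry loop in forgy(): a partial
// Fisher-Yates shuffle that picks k distinct items in exactly k steps.
function pickDistinct(k, items, random = Math.random) {
  if (k > items.length) {
    throw new Error('make sure k ≤ items.length');
  }
  const pool = [...items]; // copy so the caller's array is not reordered
  for (let i = 0; i < k; i++) {
    // choose a random index from the not-yet-picked part [i, pool.length)
    const j = i + Math.floor(random() * (pool.length - i));
    [pool[i], pool[j]] = [pool[j], pool[i]];
  }
  return pool.slice(0, k);
}
```

Called as `pickDistinct(k, observations)`, it returns the same kind of array as `forgy(k, observations)`: k distinct entries of the observation list.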
Plugging the Forgy method into our function gives us the completed implementation:
Full implementation
/**
 * Clusters `observations` of dimension `d` into `k` clusters
 * @param {number} k
 * @param {number} d
 * @param {Array<Vector>} observations
 * @returns {Array<Cluster>}
 */
function cluster(k, d, observations) {
  if (k > observations.length) {
    throw new Error('make sure k ≤ observations.length');
  }
  // generate k empty clusters
  let clusters = Array.from({ length: k }, () => new Cluster(d, []));
  // initialize clusters by picking k random observations as centers
  let initialCenters = forgy(k, observations);
  clusters.forEach((cluster, i) => {
    cluster.center = initialCenters[i];
  });
  let maxIterations = 10;
  while (maxIterations--) {
    // first, reset all observations for each cluster
    clusters.forEach((cluster) => {
      cluster.vectors = [];
    });
    // step 1 (assignment):
    // every iteration, assign each observation to the cluster
    // with the smallest distance to the center
    for (let observation of observations) {
      let bestDistance = Infinity;
      let bestCluster = undefined;
      for (let cluster of clusters) {
        let distance = distanceBetween(cluster.center, observation);
        if (distance < bestDistance) {
          bestDistance = distance;
          bestCluster = cluster;
        }
      }
      bestCluster.vectors.push(observation);
    }
    // step 2 (update):
    // recalculate center
    clusters.forEach((cluster) => {
      let newCenter = cluster.getCenter();
      cluster.center = newCenter;
    });
  }
  return clusters;
}
// example usage:
let observations = [
  new Vector(2, [0, 2]),
  new Vector(2, [0, 5]),
  new Vector(2, [1, 0]),
  new Vector(2, [1, 6]),
  new Vector(2, [2, 0]),
  new Vector(2, [2, 5]),
  new Vector(2, [2, 6]),
  new Vector(2, [3, 1]),
  new Vector(2, [3, 5]),
  new Vector(2, [3, 6]),
  new Vector(2, [4, 3]),
  new Vector(2, [6, 3]),
  new Vector(2, [6, 2]),
];
let clusters = cluster(3, 2, observations);
Unfortunately, the outcome of running the above example (and of k-means clustering in general) heavily depends on the random choices made during the cluster initialization. One possible result is the following three clusters, with the respective cluster centers highlighted as slightly larger dots:
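Because of this dependence on the random initialization, a common remedy is to run the algorithm several times and keep the run whose observations lie closest to their centers overall, measured as the sum of squared distances. The sketch below assumes the object shape that `cluster()` returns (`center.values` and `vectors[i].values`); `withinClusterSumOfSquares` and `bestOfRuns` are hypothetical helper names, and the number of runs is up to you.

```javascript
/**
 * Sum of squared distances between every observation and its cluster's center.
 * Works on any objects shaped like the post's clusters:
 * { center: { values: number[] }, vectors: [{ values: number[] }, ...] }
 */
function withinClusterSumOfSquares(clusters) {
  let total = 0;
  for (const c of clusters) {
    for (const v of c.vectors) {
      for (let d = 0; d < v.values.length; d++) {
        total += (v.values[d] - c.center.values[d]) ** 2;
      }
    }
  }
  return total;
}

/**
 * Hypothetical helper: calls `runOnce` `runs` times and returns the clustering
 * with the lowest within-cluster sum of squares.
 */
function bestOfRuns(runs, runOnce) {
  let best = null;
  let bestScore = Infinity;
  for (let i = 0; i < runs; i++) {
    const result = runOnce();
    const score = withinClusterSumOfSquares(result);
    if (score < bestScore) {
      bestScore = score;
      best = result;
    }
  }
  return best;
}

// With the post's code, this could be used as:
// let bestClusters = bestOfRuns(10, () => cluster(3, 2, observations));
```

A lower sum means tighter clusters, so keeping the lowest-scoring run makes a poor random start less likely to decide the final result.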
Conclusion
That's pretty much the gist of what k-means clustering is and how a naive algorithm operates. I've demonstrated the technique here using 2-dimensional vectors, but the same methods apply to vectors of higher dimensions. Also, please don't use my implementation for clustering in your own projects; pick a battle-tested library instead.