I have a very basic implementation of k-means in JavaScript (I know, but it needs to run in the browser). What I would like to understand is how one could make it more functional.
It is currently full of loops and extremely difficult to follow and reason about. The code is below:
// bestGuessK, randomCentroids and euclidean are helpers not shown here.
export default class KMeans {
  constructor(vectors, k) {
    this.vectors = vectors;
    this.numOfVectors = vectors.length;
    this.k = k || bestGuessK(this.numOfVectors);
    this.centroids = randomCentroids(this.vectors, this.k);
  }
  classify(vector, distance) {
    let min = Infinity;
    let index = 0;
    for (let i = 0; i < this.centroids.length; i++) {
      const dist = distance(vector, this.centroids[i]);
      if (dist < min) {
        min = dist;
        index = i;
      }
    }
    return index;
  }
  cluster() {
    const assignment = new Array(this.numOfVectors);
    const clusters = new Array(this.k);
    let movement = true;
    while (movement) {
      // update vector to centroid assignments
      for (let i = 0; i < this.numOfVectors; i++) {
        assignment[i] = this.classify(this.vectors[i], euclidean);
      }
      // update location of each centroid
      movement = false;
      for (let j = 0; j < this.k; j++) {
        const assigned = [];
        for (let i = 0; i < assignment.length; i++) {
          if (assignment[i] === j) assigned.push(this.vectors[i]);
        }
        // record the cluster even when it is empty, so members from an
        // earlier pass don't linger in clusters[j]
        clusters[j] = assigned;
        if (!assigned.length) continue;
        const centroid = this.centroids[j];
        const newCentroid = new Array(centroid.length);
        for (let g = 0; g < centroid.length; g++) {
          let sum = 0;
          for (let i = 0; i < assigned.length; i++) {
            sum += assigned[i][g];
          }
          newCentroid[g] = sum / assigned.length;
          if (newCentroid[g] !== centroid[g]) {
            movement = true;
          }
        }
        this.centroids[j] = newCentroid;
      }
    }
    return clusters;
  }
}
It certainly can.
You could start with this:
classify(vector, distance) {
  let min = Infinity;
  let index = 0;
  for (let i = 0; i < this.centroids.length; i++) {
    const dist = distance(vector, this.centroids[i]);
    if (dist < min) {
      min = dist;
      index = i;
    }
  }
  return index;
}
Why is this a member function? Wouldn't a pure function const classify = (centroids, vector, distance) => {...} be cleaner?
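Taken to its conclusion, a pure classify receives the centroids as an argument instead of reading this.centroids. One possible body is sketched below with reduce, so the smallest distance and its index are tracked together in a single pass; squaredEuclidean is a hypothetical metric used only for the demo, not the question's euclidean helper.

```javascript
// A pure classify: everything it needs is passed in, nothing is read from `this`.
// Like the original loop, ties keep the earliest index and the default is 0.
const classify = (centroids, vector, distance) =>
  centroids.reduce(
    (best, centroid, i) => {
      const d = distance(vector, centroid);
      return d < best.dist ? { index: i, dist: d } : best;
    },
    { index: 0, dist: Infinity }
  ).index;

// Hypothetical squared-Euclidean metric, for illustration only.
const squaredEuclidean = (a, b) =>
  a.reduce((t, x, i) => t + (x - b[i]) ** 2, 0);

const centroids = [[0, 0], [10, 10]];
console.log(classify(centroids, [1, 2], squaredEuclidean)); // 0
console.log(classify(centroids, [9, 8], squaredEuclidean)); // 1
```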
Then, for an implementation, let's change the distance signature a bit. If we curry it to const distance = (vector) => (centroid) => {...}, we can then write
const classify = (centroids, vector, distance) =>
  minIndex (centroids .map (distance (vector)))
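For concreteness, here is what a curried Euclidean metric might look like. The question's euclidean helper isn't shown, so this version is an assumption, and minIndex uses the one-line definition that appears a little further down so the snippet runs on its own.

```javascript
const sum = (ns) => ns.reduce((t, n) => t + n, 0);

// Hypothetical curried Euclidean distance: fix the vector first,
// get back a function that only needs the centroid.
const euclidean = (vector) => (centroid) =>
  Math.sqrt(sum(vector.map((x, i) => (x - centroid[i]) ** 2)));

const minIndex = (xs) => xs.indexOf(Math.min(...xs));

const classify = (centroids, vector, distance) =>
  minIndex(centroids.map(distance(vector)));

const centroids = [[0, 0], [10, 10], [0, 10]];
console.log(classify(centroids, [1, 9], euclidean)); // 2
```

Because euclidean(vector) returns a one-argument function, it can be handed straight to map with no wrapper lambda.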
And if that distance API is out of our control, it's not much harder:
const classify = (centroids, vector, distance) =>
  minIndex (centroids .map (centroid => distance (vector, centroid)))
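Instead of writing that lambda inline, another option is a small adapter that curries a two-argument function once, so the curried form of classify can be kept. curry2 and euclidean2 are hypothetical names introduced only for this sketch.

```javascript
// Hypothetical helper: turn a two-argument function into a curried one.
const curry2 = (f) => (a) => (b) => f(a, b);

// An uncurried distance, standing in for one whose API we can't change.
const euclidean2 = (a, b) =>
  Math.sqrt(a.reduce((t, x, i) => t + (x - b[i]) ** 2, 0));

const minIndex = (xs) => xs.indexOf(Math.min(...xs));

// Written against the curried signature...
const classify = (centroids, vector, distance) =>
  minIndex(centroids.map(distance(vector)));

// ...but still usable with the uncurried metric through the adapter.
const centroids = [[0, 0], [4, 4]];
console.log(classify(centroids, [3, 3], curry2(euclidean2))); // 1
```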
Granted, we haven't written minIndex yet, but we've already broken the problem down to use a more meaningful abstraction. And minIndex isn't hard to write. You can do it imperatively, as the original classify function did, or with something like this:
const minIndex = (xs) => xs.indexOf (Math.min (...xs))
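The one-liner scans the array twice and passes every element to Math.min as a separate argument. Engines cap how many arguments a call can receive, so spreading a very large array can throw; with only k centroids that won't matter, but a reduce-based version avoids both points. A sketch:

```javascript
// Single-pass alternative to xs.indexOf(Math.min(...xs)).
// Returns -1 for an empty array, as the spread version does
// (Math.min() is Infinity, and [].indexOf(Infinity) is -1).
const minIndex = (xs) =>
  xs.length === 0
    ? -1
    : xs.reduce((best, x, i) => (x < xs[best] ? i : best), 0);

console.log(minIndex([5, 2, 8, 2])); // 1 (the first of the tied minimums)
console.log(minIndex([]));           // -1
```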
Note that distance is a slightly misleading name here. I had to read it more carefully because I assumed a name like that would represent... well, a distance. Instead, it's a function used to calculate distance. Perhaps the name metric, or something like distanceFunction, distanceFn, or distanceImpl, would be more obvious.
Now let's move on to this bit:
const newCentroid = new Array(centroid.length);
for (let g = 0; g < centroid.length; g++) {
  let sum = 0;
  for (let i = 0; i < assigned.length; i++) {
    sum += assigned[i][g];
  }
  newCentroid[g] = sum / assigned.length;
  if (newCentroid[g] !== centroid[g]) {
    movement = true;
  }
}
This code has two responsibilities: creating the newCentroid array, and updating the value of movement if any value has changed.
Let's separate those two.
First, creating the new centroid. We can clean up that nested for-loop to something like this:
const makeNewCentroid = (centroid, assigned) =>
  centroid .map ((c, g) => mean (assigned .map ((a) => a[g])))
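The inner assigned.map((a) => a[g]) pulls out column g of the assigned vectors, so the same step can also be read as "transpose, then average each row". A sketch with a hypothetical transpose helper; unlike the version above, it takes the number of dimensions from the first assigned vector rather than from the old centroid, and it defines its own sum and mean so it runs standalone.

```javascript
// Hypothetical helper: swap rows and columns of a non-empty, rectangular matrix.
const transpose = (rows) => rows[0].map((_, g) => rows.map((row) => row[g]));

const sum = (ns) => ns.reduce((t, n) => t + n, 0);
const mean = (xs) => sum(xs) / xs.length;

// Each row of the transposed matrix holds one dimension's values.
// Assumes `assigned` is non-empty (the original skips empty clusters first).
const makeNewCentroid = (assigned) => transpose(assigned).map(mean);

console.log(makeNewCentroid([[1, 2], [3, 4], [5, 9]])); // [ 3, 5 ]
```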
This depends on a mean function, which we'll write along with its required sum function like this:
const sum = (ns) => ns .reduce ((t, n) => t + n, 0)
const mean = xs => sum (xs) / xs.length
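A small worked example of these pieces together. Note that the old centroid only supplies the number of dimensions, and that the mean of an empty array is 0 / 0, i.e. NaN, so the original empty-cluster check still has to happen before makeNewCentroid is called.

```javascript
const sum = (ns) => ns.reduce((t, n) => t + n, 0);
const mean = (xs) => sum(xs) / xs.length;

// The old centroid is used only for its length: one output per dimension g.
const makeNewCentroid = (centroid, assigned) =>
  centroid.map((c, g) => mean(assigned.map((a) => a[g])));

const oldCentroid = [0, 0];
const assigned = [[2, 4], [4, 6], [6, 2]];
console.log(makeNewCentroid(oldCentroid, assigned)); // [ 4, 4 ]
console.log(mean([]));                               // NaN
```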
Then we need to update movement. We can do that easily by comparing centroids with newCentroids, the array of all the recomputed centroids:
// each centroid is itself an array, so compare coordinate by coordinate
movement = centroids.some((c, i) => c.some((x, g) => x !== newCentroids[i][g]))
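Because every centroid is an array, comparing two centroids with !== only checks whether they are the same object; a freshly built array is never === an older one, even with identical contents. A sketch of an element-wise check, using a hypothetical sameVector helper:

```javascript
// Hypothetical helper: true when two vectors have identical coordinates.
const sameVector = (a, b) =>
  a.length === b.length && a.every((x, g) => x === b[g]);

// Movement happened if any centroid changed in any coordinate.
const hasMoved = (centroids, newCentroids) =>
  centroids.some((c, i) => !sameVector(c, newCentroids[i]));

const old = [[1, 1], [5, 5]];
console.log(hasMoved(old, [[1, 1], [5, 5]])); // false
console.log(hasMoved(old, [[1, 1], [5, 6]])); // true
console.log([1, 1] !== [1, 1]);               // true: two distinct array objects
```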
Obviously, you can continue in this manner. Each for loop should have a fundamental purpose: find that purpose and see whether one of the Array.prototype methods could express it better. In the second section we worked on above, we found two purposes and simply split them into two separate blocks.
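Continuing in that direction, the whole cluster method can become a pair of pure functions. The sketch below assembles the pieces from this answer; it is one possible arrangement, not the only one. It assumes a hypothetical curried euclidean metric, takes the initial centroids as an argument instead of calling the question's randomCentroids (which keeps it deterministic and easy to test), and adds an iteration cap that the original does not have.

```javascript
// Building blocks from this answer; euclidean is a hypothetical curried metric.
const sum = (ns) => ns.reduce((t, n) => t + n, 0);
const mean = (xs) => sum(xs) / xs.length;
const minIndex = (xs) => xs.indexOf(Math.min(...xs));
const euclidean = (vector) => (centroid) =>
  Math.sqrt(sum(vector.map((x, i) => (x - centroid[i]) ** 2)));

const classify = (centroids, vector, distance) =>
  minIndex(centroids.map(distance(vector)));

const makeNewCentroid = (centroid, assigned) =>
  centroid.map((c, g) => mean(assigned.map((a) => a[g])));

const sameVector = (a, b) =>
  a.length === b.length && a.every((x, g) => x === b[g]);

// One k-means pass: assign every vector, group by assignment, recompute centroids.
// A cluster with no members keeps its old centroid, as the original code does.
const step = (vectors, centroids, distance) => {
  const assignment = vectors.map((v) => classify(centroids, v, distance));
  const clusters = centroids.map((_c, j) =>
    vectors.filter((_v, i) => assignment[i] === j));
  const newCentroids = centroids.map((c, j) =>
    clusters[j].length ? makeNewCentroid(c, clusters[j]) : c);
  return { clusters, centroids: newCentroids };
};

// Repeat until no centroid moves; iterationsLeft (an addition) bounds the recursion.
const kmeans = (vectors, centroids, distance = euclidean, iterationsLeft = 100) => {
  const next = step(vectors, centroids, distance);
  const moved = centroids.some((c, j) => !sameVector(c, next.centroids[j]));
  return moved && iterationsLeft > 1
    ? kmeans(vectors, next.centroids, distance, iterationsLeft - 1)
    : next;
};

const points = [[1, 1], [1, 2], [2, 1], [8, 8], [8, 9], [9, 8]];
const result = kmeans(points, [[1, 1], [8, 8]]);
// two clusters: the three points near (1, 1), then the three near (8, 8)
console.log(result.clusters);
```

Nothing here mutates its inputs: each pass builds new assignment, cluster and centroid arrays, and the only state carried between passes is the centroids argument of the recursive call.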
This should give you a good start on making this more functional. There is no magic bullet, but if you think in terms of pure functions on immutable data and a strong separation of concerns, you can usually move in a functional direction.