javascript algorithm functional-programming k-means

How to make k-means algorithm functional


I have a very basic implementation of k-means in JavaScript (I know, but it needs to run in the browser). What I would like to understand is: how could one make this more functional?

It is currently full of loops and extremely difficult to follow and reason about. The code is below:

export default class KMeans {
  constructor(vectors, k) {
    this.vectors = vectors;
    this.numOfVectors = vectors.length;
    this.k = k || bestGuessK(this.numOfVectors);
    this.centroids = randomCentroids(this.vectors, this.k);
  }

  classify(vector, distance) {
    let min = Infinity;
    let index = 0;

    for (let i = 0; i < this.centroids.length; i++) {
      const dist = distance(vector, this.centroids[i]);
      if (dist < min) {
        min = dist;
        index = i;
      }
    }

    return index;
  }

  cluster() {
    const assignment = new Array(this.numOfVectors);
    const clusters = new Array(this.k);

    let movement = true;

    while (movement) {
      // update vector to centroid assignments
      for (let i = 0; i < this.numOfVectors; i++) {
        assignment[i] = this.classify(this.vectors[i], euclidean);
      }

      // update location of each centroid
      movement = false;
      for (let j = 0; j < this.k; j++) {
        const assigned = [];

        for (let i = 0; i < assignment.length; i++) {
          if (assignment[i] === j) assigned.push(this.vectors[i]);
        }
        }

        if (!assigned.length) continue;
        const centroid = this.centroids[j];
        const newCentroid = new Array(centroid.length);

        for (let g = 0; g < centroid.length; g++) {
          let sum = 0;
          for (let i = 0; i < assigned.length; i++) {
            sum += assigned[i][g];
          }
          newCentroid[g] = sum / assigned.length;

          if (newCentroid[g] !== centroid[g]) {
            movement = true;
          }
        }
        this.centroids[j] = newCentroid;
        clusters[j] = assigned;
      }
    }

    return clusters;
  }
}

Solution

  • It certainly can be made more functional.

    You could start with this:

      classify(vector, distance) {
        let min = Infinity;
        let index = 0;
    
        for (let i = 0; i < this.centroids.length; i++) {
          const dist = distance(vector, this.centroids[i]);
          if (dist < min) {
            min = dist;
            index = i;
          }
        }
    
        return index;
      }
    

    Why is this a member function? Wouldn't a pure function const classify = (centroids, vector, distance) => {...} be cleaner?

    Then for an implementation, let's change the distance signature a bit. If we curry it to const distance = (vector) => (centroid) => {...}, we can then write

    const classify = (centroids, vector, distance) =>
      minIndex (centroids .map (distance (vector)))
    

    And if that distance API is out of our control, it's not much harder:

    const classify = (centroids, vector, distance) =>
      minIndex (centroids .map (centroid => distance (vector, centroid)))
    

    Granted, we haven't written minIndex yet, but we've already broken the problem down to use a more meaningful abstraction. And minIndex isn't hard to write. You can do it imperatively as the original classify function did, or with something like this:

    const minIndex = (xs) => xs.indexOf (Math.min (...xs))
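
Putting those pieces together, here is a minimal runnable sketch of the pure classify. The euclidean metric and the sample centroids are assumptions for illustration; the question uses a euclidean function but never shows it.

```javascript
// Index of the smallest value in a non-empty array of numbers.
const minIndex = (xs) => xs.indexOf(Math.min(...xs));

// Assumed metric: plain Euclidean distance between equal-length vectors.
const euclidean = (a, b) =>
  Math.sqrt(a.reduce((t, x, i) => t + (x - b[i]) ** 2, 0));

// Pure classify: everything it needs arrives as an argument.
const classify = (centroids, vector, metric) =>
  minIndex(centroids.map((centroid) => metric(vector, centroid)));

const centroids = [[0, 0], [10, 10]];
console.log(classify(centroids, [1, 2], euclidean)); // 0
console.log(classify(centroids, [9, 8], euclidean)); // 1
```

With a curried metric of the shape (vector) => (centroid) => number, the inner arrow function disappears and the map becomes centroids.map(metric(vector)).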
    

    Note that distance is a slightly misleading name here. I had to read it more carefully because I assumed a name like that would represent, well, a distance. Instead it's a function used to calculate distance. Perhaps the name metric or something like distanceFunction, distanceFn, or distanceImpl would be more obvious.


    Now let's move on to this bit:

    const newCentroid = new Array(centroid.length);
    
    for (let g = 0; g < centroid.length; g++) {
      let sum = 0;
      for (let i = 0; i < assigned.length; i++) {
        sum += assigned[i][g];
      }
      newCentroid[g] = sum / assigned.length;
    
      if (newCentroid[g] !== centroid[g]) {
        movement = true;
      }
    }
    

    This code has two responsibilities: creating the newCentroid array, and updating the value of movement if any value has changed.

    Let's separate those two.

    First, creating the new centroid. We can clean up that nested for-loop to something like this:

    const makeNewCentroid = (centroid, assigned) =>
      centroid .map ((c, g) => mean (assigned .map ((a) => a[g])))
    

    This depends on a mean function, which we'll write along with its required sum function like this:

    const sum = (ns) => ns .reduce ((t, n) => t + n, 0)
    const mean = xs => sum (xs) / xs.length
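
As a runnable check of the centroid calculation, here is a small sketch with sample data. The guard for an empty cluster is an addition, not part of the one-liner above: it mirrors the continue in the original loop, because the mean of an empty list would be 0 / 0, which is NaN.

```javascript
const sum = (ns) => ns.reduce((t, n) => t + n, 0);
const mean = (xs) => sum(xs) / xs.length;

// Mean of the assigned vectors, one coordinate at a time.
// Assumption: an empty cluster keeps its old centroid,
// as the `continue` in the original loop does.
const makeNewCentroid = (centroid, assigned) =>
  assigned.length === 0
    ? centroid
    : centroid.map((_, g) => mean(assigned.map((a) => a[g])));

console.log(makeNewCentroid([0, 0], [[1, 2], [3, 6]])); // [ 2, 4 ]
console.log(makeNewCentroid([5, 5], []));               // [ 5, 5 ]
```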
    

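    Then we need to update movement. We can do that based on centroids and newCentroids. Since each centroid is itself an array, we compare coordinates rather than array references (two distinct arrays are never === to each other, so a reference check would report movement forever):

    movement = centroids.some((c, i) => c.some((x, g) => x !== newCentroids[i][g]))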
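
As a standalone sketch, the movement check can be pulled out into its own pure function that compares the centroids coordinate by coordinate. The name hasMoved and the epsilon tolerance are hypothetical additions; for finite numbers, an epsilon of 0 behaves like the exact !== comparison in the original loop.

```javascript
// True when any coordinate of any centroid differs between the two sets.
// `epsilon` is a hypothetical tolerance for floating-point noise.
const hasMoved = (centroids, newCentroids, epsilon = 0) =>
  centroids.some((c, i) =>
    c.some((x, g) => Math.abs(x - newCentroids[i][g]) > epsilon));

console.log(hasMoved([[0, 0], [1, 1]], [[0, 0], [1, 1]]));   // false
console.log(hasMoved([[0, 0], [1, 1]], [[0, 0], [1, 1.5]])); // true
console.log(hasMoved([[0, 0]], [[0, 0.001]], 0.01));         // false
```

A small non-zero tolerance may be worth considering if rounding errors keep the loop running longer than expected.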
    

    Obviously, you can continue in this manner. Each for loop should have a fundamental purpose. Find that purpose and see if one of the Array.prototype methods could better express it. For the second section we worked with above, we found two purposes, and just split them into two separate blocks.
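
For instance, the two loops in cluster that were not covered above each have a single purpose: "assign every vector to a centroid" is a map, and "collect the vectors assigned to cluster j" is a filter. The sketch below illustrates that; the helper names assign, membersOf, and nearest are hypothetical, and nearest is only a stand-in for classify that uses squared Euclidean distance.

```javascript
// Assumed stand-in for `classify`: nearest centroid by squared Euclidean distance.
const squaredDistance = (a, b) =>
  a.reduce((t, x, i) => t + (x - b[i]) ** 2, 0);

const nearest = (centroids) => (vector) => {
  const ds = centroids.map((c) => squaredDistance(vector, c));
  return ds.indexOf(Math.min(...ds));
};

// "Assign every vector to a centroid" is a map.
const assign = (vectors, centroids) => vectors.map(nearest(centroids));

// "Collect the vectors assigned to cluster j" is a filter.
const membersOf = (vectors, assignment, j) =>
  vectors.filter((_, i) => assignment[i] === j);

const vectors = [[1, 1], [9, 9], [2, 0]];
const assignment = assign(vectors, [[0, 0], [10, 10]]);
console.log(assignment);                        // [ 0, 1, 0 ]
console.log(membersOf(vectors, assignment, 0)); // [ [ 1, 1 ], [ 2, 0 ] ]
```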

    This should give you a good start on making this more functional. There is no magic bullet. But if you think in terms of pure functions on immutable data, and on strong separation of concerns, you can usually move in a functional direction.
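
To show where this can end up, here is one possible sketch of the whole algorithm as pure functions. It is an illustration under stated assumptions, not the only way to do it: the initial centroids are passed in rather than chosen randomly, euclidean is an assumed metric, empty clusters keep their previous centroid, and the names partition, hasMoved, and kMeans are hypothetical.

```javascript
const sum = (ns) => ns.reduce((t, n) => t + n, 0);
const mean = (xs) => sum(xs) / xs.length;
const minIndex = (xs) => xs.indexOf(Math.min(...xs));

// Assumed metric; the question's `euclidean` is not shown.
const euclidean = (a, b) =>
  Math.sqrt(sum(a.map((x, i) => (x - b[i]) ** 2)));

const classify = (centroids, vector, metric) =>
  minIndex(centroids.map((c) => metric(vector, c)));

// Group vectors by nearest centroid: one (possibly empty) array per centroid.
const partition = (vectors, centroids, metric) => {
  const assignment = vectors.map((v) => classify(centroids, v, metric));
  return centroids.map((_c, j) =>
    vectors.filter((_v, i) => assignment[i] === j));
};

// An empty cluster keeps its previous centroid.
const makeNewCentroid = (centroid, assigned) =>
  assigned.length === 0
    ? centroid
    : centroid.map((_, g) => mean(assigned.map((a) => a[g])));

// Compare coordinates, not array references.
const hasMoved = (oldCs, newCs) =>
  oldCs.some((c, i) => c.some((x, g) => x !== newCs[i][g]));

// One pure step per call; recursion replaces the `while (movement)` loop.
const kMeans = (vectors, centroids, metric = euclidean) => {
  const clusters = partition(vectors, centroids, metric);
  const next = centroids.map((c, j) => makeNewCentroid(c, clusters[j]));
  return hasMoved(centroids, next)
    ? kMeans(vectors, next, metric)
    : { centroids: next, clusters };
};

const result = kMeans([[1, 1], [1, 3], [9, 9], [11, 9]], [[0, 0], [10, 10]]);
console.log(result.centroids); // [ [ 1, 2 ], [ 10, 9 ] ]
```

Nothing here mutates its inputs, so each helper can be tested on its own. JavaScript engines generally do not optimize tail calls, so for data that needs many iterations a plain loop around the same pure step may be the safer choice.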