Tags: javascript, algorithm, data-structures, pagination, trie

How to build a trie for finding exact phonetic matches, sorted globally by weight, and paginated? (Building off this example)


Goal

I got pretty far working with AI to build a trie for finding rhyming words. Say you have 10 million English words written using the CMU pronunciation system, where each phoneme is one node in the trie (ARPABET, as in this image).

[Image: ARPABET phoneme chart]

Here are the key features:

  1. Trie nodes are keyed by phonemes of one or more ASCII symbols, such as B for the "b" sound and CH for the "ch" sound.
  2. The trie can paginate and jump directly to a particular page, so the search takes a limit and a page property.
  3. Pagination can optionally be restricted to matches with a particular phoneme sequence length. For example, the phonemes g l a d have length 4.
  4. The search input (provided by the user) is an array of phonemes. Phonemes are the units of sound, written in ASCII (like the ARPABET above).
  5. The input is expanded into all possible rhymes using a map from each phoneme to an array of substitute phonemes. This map is only partially implemented in the function below; properly fine-tuning its values will take weeks, and it is a TODO for soon.
  6. We can call the expanded input the "rhymingPhonemeArray".
  7. Each phoneme sequence in the rhymingPhonemeArray is weighted. The weight measures how close the substituted phonemes are to the original ones, so that a sequence with a lower cumulative weight counts as a better rhyme than a sequence with a higher one.
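The weighted expansion in items 5 to 7 can be sketched as follows. This is a minimal sketch under a few assumptions. It uses a hypothetical substitution map, SUBSTITUTIONS, in which each alternative phoneme carries its own cost, and those costs are made up for illustration. Keeping the original phoneme always costs 0, and a sequence's cumulative weight is the sum of its per-position costs. Sorting the variants by that sum puts the closest rhymes first.

```javascript
// Hypothetical substitution costs (made up): lower cost = sounds closer to the original.
const SUBSTITUTIONS = {
  g: [["k", 2]],
  l: [["r", 1]],
  d: [["t", 1]],
};

// Expand a phoneme array into every weighted variant, lowest cumulative weight first.
function expandWeighted(phonemes) {
  let variants = [{ phonemes: [], weight: 0 }];
  for (const p of phonemes) {
    // The original phoneme is always an option at cost 0.
    const options = [[p, 0], ...(SUBSTITUTIONS[p] || [])];
    const next = [];
    for (const v of variants) {
      for (const [sub, cost] of options) {
        next.push({ phonemes: [...v.phonemes, sub], weight: v.weight + cost });
      }
    }
    variants = next;
  }
  // Array.prototype.sort is stable, so equal weights keep their generation order.
  return variants.sort((a, b) => a.weight - b.weight);
}

const variants = expandWeighted(["g", "l", "a", "d"]);
console.log(variants.map((v) => `${v.phonemes.join("-")} (${v.weight})`));
```

Here g-l-a-d (weight 0) comes first and k-r-a-t (weight 4) last. The number of variants grows multiplicatively with input length (2 × 2 × 1 × 2 = 8 in this case), which is why capping the input length matters.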

Problems

The problems I'm facing now with the solution we landed on (shared below in JavaScript) are:

  1. It finds every possible match up front, which is unoptimized. Ideally it would only do the work needed for the requested page (the limit and offset).
  2. It sorts the entire match set afterwards, instead of producing just the requested page (limit items) already in sorted order. This is the key problem. I'm not sure whether this can be done efficiently, or whether it's possible at all.
  3. The rhymingPhonemeArray is iterated through sequentially 😔. So if we are paginating and the rhymingPhonemeArray is something like [G-A-D (cumulative weight: 10), G-L-A-D (cumulative weight: 24 or whatever), G-R-A-D (cumulative weight: 29), etc.], you find everything that rhymes with G-A-D first. Once that is paginated through, you paginate through G-L-A-D, and so on. I would like to avoid this grouping. Instead, results need to be sorted by cumulative weight and paginated across the "global set" of all 10 million words.

So for (3), it should find (something like this):

input: G-A-D
matches:
  G-A-D
  G-L-A-D
  A-G-A-D
  A-R-G-A-D
  G-R-A-D
  A-G-R-A-D
  A-R-G-L-A-D
  A-R-G-R-A-D
  ...

By "something like this", I mean it is NOT like the following, where all the G-A-D matches come first, then all the G-L-A-D matches, and so on:

input: G-A-D
matches:
  G-A-D
  A-G-A-D
  A-R-G-A-D
  G-L-A-D
  A-R-G-L-A-D
  G-R-A-D
  A-G-R-A-D
  A-R-G-R-A-D
  ...

In the first listing, the matches are interleaved, ordered by each word's global cumulative weight.
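One way to make "globally sorted" concrete is to give every match a combined score and order all matches by it. The score is the weight of the rhyme variant that produced the match plus the word's own stored weight. The sketch below uses hypothetical words and weights chosen only for illustration. It contrasts the grouped order (all matches for one variant, then the next) with the global order. Note that it still collects every match before sorting, so it illustrates the target ordering, not an efficient way to reach it.

```javascript
// Hypothetical matches: each comes from one rhyme variant (with its variant weight)
// and carries the word's own stored weight. Lower combined weight = better rhyme.
const matches = [
  { word: "gad", variant: "G-A-D", variantWeight: 0, wordWeight: 1 },
  { word: "agad", variant: "G-A-D", variantWeight: 0, wordWeight: 7 },
  { word: "glad", variant: "G-L-A-D", variantWeight: 2, wordWeight: 3 },
  { word: "grad", variant: "G-R-A-D", variantWeight: 1, wordWeight: 2 },
];

const combined = (m) => m.variantWeight + m.wordWeight;

// Grouped order: by variant weight first, then by word weight within each variant.
const groupedOrder = [...matches].sort(
  (a, b) => a.variantWeight - b.variantWeight || a.wordWeight - b.wordWeight
);

// Global order: one combined key across all variants.
const globalOrder = [...matches].sort((a, b) => combined(a) - combined(b));

console.log(groupedOrder.map((m) => m.word)); // ["gad", "agad", "grad", "glad"]
console.log(globalOrder.map((m) => m.word)); // ["gad", "grad", "glad", "agad"]
```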

Question

How can the following "Trie Implementation" be modified to solve problems 1, 2, and 3 above? Note: it doesn't need to be 100% exact. I'm just looking for the key insight on how to solve this pagination problem properly (even at a high-level or pseudocode level). They are all aspects of the same underlying problem, which I stated above but will state again:

The trie implementation below does not correctly and reasonably efficiently paginate through the words that exactly match an entry of the rhymingPhonemeArray (which is generated from the input phonemes), sorted by global cumulative weight.

Is it even possible to solve this problem without iterating over the entire trie? Remember that the input is expanded into the rhymingPhonemeArray, which could contain many possibilities, although in practice we will probably limit the input to about 3 syllables. If it is not possible, can you explain why?

If it is possible, how would you modify this trie to support pagination, and jumping to a specific page, without having to traverse everything, while at the same time the paginated results are globally sorted by the cumulative weight for each word?

Trie Implementation

The Trie that we landed on was this:

class TrieNode {
  constructor() {
    this.children = {}; // Store child nodes
    this.isWord = false; // Flag to check if node marks the end of a word
    this.word = null; // Store the word if this is an end node
    this.cumulativeWeight = 0; // Store the cumulative weight for sorting
    this.phonemeLength = 0; // Length of the phoneme sequence
  }
}

class PhoneticTrie {
  constructor() {
    this.root = new TrieNode(); // Root node of the trie
  }

  // Insert words and phoneme clusters into the Trie
  insert(word, phonemes, weight) {
    let node = this.root;
    for (let phoneme of phonemes) {
      if (!node.children[phoneme]) {
        node.children[phoneme] = new TrieNode();
      }
      node = node.children[phoneme];
    }
    node.isWord = true;
    node.word = word;
    node.cumulativeWeight = weight; // Store the cumulative weight at this node
    node.phonemeLength = phonemes.length; // Store the length of the phoneme sequence
  }

  // Global sorting of rhyming phoneme matches using a min-heap
  searchWithPagination({ phonemes, limit, page, length }) {
    const minHeap = new MinHeap(); // Min-Heap to store sorted results

    // Generate rhyming phoneme variations
    const rhymingPhonemesArray = expandToRhymingPhonemes(phonemes);

    // Search the Trie for all rhyming phoneme sequences and insert results into the global heap
    for (let rhyme of rhymingPhonemesArray) {
      this.searchAndInsertToHeap(rhyme.phonemes, this.root, rhyme.weight, minHeap, length);
    }

    // Paginate results directly from the globally sorted heap
    const paginatedResults = [];
    const startIndex = (page - 1) * limit;
    let index = 0;

    while (minHeap.size() > 0 && paginatedResults.length < limit) {
      const wordData = minHeap.extractMin();
      if (index >= startIndex) {
        paginatedResults.push(wordData.word);
      }
      index++;
    }

    return paginatedResults;
  }

  // Follow one exact phoneme sequence down the Trie and insert the word at its end (if any) into the heap
  searchAndInsertToHeap(phonemes, node, rhymeWeight, heap, targetLength) {
    for (const phoneme of phonemes) {
      node = node.children[phoneme];
      if (!node) return; // No match along this path
    }
    // If length filtering is specified, ensure the word matches the phoneme length
    if (node.isWord && node.word && (!targetLength || node.phonemeLength === targetLength)) {
      heap.insert({
        word: node.word,
        cumulativeWeight: node.cumulativeWeight + rhymeWeight, // Word weight plus rhyme weight
        phonemeLength: node.phonemeLength, // Secondary sort key in the heap
      });
    }
  }
}


class MinHeap {
  constructor() {
    this.heap = []; // Store the heap elements
  }

  // Insert an item into the heap based on its weight and phoneme length
  insert({ word, cumulativeWeight, phonemeLength }) {
    this.heap.push({ word, cumulativeWeight, phonemeLength });
    this.bubbleUp(this.heap.length - 1);
  }

  bubbleUp(index) {
    let currentIndex = index;
    while (currentIndex > 0) {
      const parentIndex = Math.floor((currentIndex - 1) / 2);
      // Sort primarily by cumulative weight, secondarily by phoneme length
      if (
        this.heap[currentIndex].cumulativeWeight > this.heap[parentIndex].cumulativeWeight ||
        (this.heap[currentIndex].cumulativeWeight === this.heap[parentIndex].cumulativeWeight &&
          this.heap[currentIndex].phonemeLength > this.heap[parentIndex].phonemeLength)
      ) {
        break;
      }
      [this.heap[currentIndex], this.heap[parentIndex]] = [this.heap[parentIndex], this.heap[currentIndex]];
      currentIndex = parentIndex;
    }
  }

  // Extract the item with the lowest weight
  extractMin() {
    if (this.heap.length === 0) return null;
    if (this.heap.length === 1) return this.heap.pop();
    const min = this.heap[0];
    this.heap[0] = this.heap.pop();
    this.bubbleDown(0);
    return min;
  }

  bubbleDown(index) {
    const lastIndex = this.heap.length - 1;
    while (true) {
      let leftChildIdx = 2 * index + 1;
      let rightChildIdx = 2 * index + 2;
      let smallestIdx = index;

      if (
        leftChildIdx <= lastIndex &&
        (this.heap[leftChildIdx].cumulativeWeight < this.heap[smallestIdx].cumulativeWeight ||
          (this.heap[leftChildIdx].cumulativeWeight === this.heap[smallestIdx].cumulativeWeight &&
            this.heap[leftChildIdx].phonemeLength < this.heap[smallestIdx].phonemeLength))
      ) {
        smallestIdx = leftChildIdx;
      }
      if (
        rightChildIdx <= lastIndex &&
        (this.heap[rightChildIdx].cumulativeWeight < this.heap[smallestIdx].cumulativeWeight ||
          (this.heap[rightChildIdx].cumulativeWeight === this.heap[smallestIdx].cumulativeWeight &&
            this.heap[rightChildIdx].phonemeLength < this.heap[smallestIdx].phonemeLength))
      ) {
        smallestIdx = rightChildIdx;
      }
      if (smallestIdx === index) break;

      [this.heap[index], this.heap[smallestIdx]] = [this.heap[smallestIdx], this.heap[index]];
      index = smallestIdx;
    }
  }

  size() {
    return this.heap.length;
  }
}

function expandToRhymingPhonemes(phonemes) {
  // todo: this function would have a huge map of phoneme substitutions...
  const phonemeSubstitutions = {
    "k": ["g", "p"], // Example of substitutions for the phoneme "k"
    "æ": ["a", "e"], // Example for "æ"
    "t": ["d", "s"], // Example for "t"
  };

  const rhymingPhonemesArray = [];

  function generateRhymes(sequence, depth = 0, currentWeight = 0) {
    if (depth === sequence.length) {
      rhymingPhonemesArray.push({ phonemes: sequence.slice(), weight: currentWeight });
      return;
    }

    const phoneme = sequence[depth];
    // Always keep the original phoneme (adds 0 weight), then try each substitution
    const substitutions = [phoneme, ...(phonemeSubstitutions[phoneme] || [])];

    for (const sub of substitutions) {
      generateRhymes(
        [...sequence.slice(0, depth), sub, ...sequence.slice(depth + 1)],
        depth + 1,
        currentWeight + (sub === phoneme ? 0 : 1)  // Substitution adds to weight
      );
    }
  }

  generateRhymes(phonemes);
  return rhymingPhonemesArray;
}

const trie = new PhoneticTrie();
trie.insert("glad", ["g", "l", "a", "d"], 3);
trie.insert("grad", ["g", "r", "a", "d"], 2);
trie.insert("blad", ["b", "l", "a", "d"], 4);
trie.insert("grin", ["g", "r", "i", "n"], 5);

// Search for similar words to "g-l-a-d" and paginate the results, with optional length filtering
const resultsPage1 = trie.searchWithPagination({
  phonemes: ["g", "l", "a", "d"],
  limit: 2,
  page: 1,
  length: 4, // Only consider words with exactly 4 phonemes
});

const resultsPage2 = trie.searchWithPagination({
  phonemes: ["g", "l", "a", "d"],
  limit: 2,
  page: 2,
  length: 4,
});

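// Only "glad" matches: the example substitution map has no entries for g, l, a, or d,
// so "grad", "blad", and "grin" are never generated as candidate sequences
console.log(resultsPage1); // Output: ["glad"]
console.log(resultsPage2); // Output: []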

Updates

  1. Here is a gist with further expansion/attempts at the Trie to solve this key problem.

Solution

  • The fundamental problem you have is that the trie implies a global lexical ordering, but you want your output in terms of a global weight ordering.

    The trie is not helping you.

    An appropriate data structure for this problem is an array of words sorted by weight, combined with an inverted index that maps each phonetic suffix you might want to rhyme with to the sorted list of indexes of the words that end with it.

    You can probably use an open-source inverted index implementation.
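Here is a minimal sketch of that idea. It assumes a small hypothetical word list and rhyme variants in the { phonemes, weight } shape that the question's expandToRhymingPhonemes produces. The names buildRhymeIndex and searchPage are made up for illustration. Words are sorted by weight once, so a smaller index always means a lower weight. The index maps every phonetic suffix (joined with "-") to the ascending list of word indexes ending with it. A page is produced by merging the variants' posting lists in order of variant weight plus word weight, stopping as soon as the page is full. At query time there is no trie traversal and no sort of the full match set. Jumping to page N still steps over the earlier (N - 1) × limit results, but only those.

```javascript
// Build the sorted word array and suffix index from hypothetical { word, phonemes, weight } entries.
function buildRhymeIndex(entries) {
  // Sort once by weight: a smaller array index now always means a lower (better) weight.
  const words = [...entries].sort((a, b) => a.weight - b.weight);
  const index = new Map(); // suffix key such as "G-A-D" -> ascending word indexes
  words.forEach((entry, i) => {
    for (let start = 0; start < entry.phonemes.length; start++) {
      const key = entry.phonemes.slice(start).join("-");
      if (!index.has(key)) index.set(key, []);
      index.get(key).push(i); // i only grows, so every posting list stays sorted
    }
  });
  return { words, index };
}

// Return one page of words ending in any variant, ordered by variant weight + word weight.
function searchPage({ words, index }, variants, page, limit) {
  // One cursor per variant; scores along each list never decrease.
  const cursors = variants
    .map((v) => ({ list: index.get(v.phonemes.join("-")) || [], pos: 0, offset: v.weight }))
    .filter((c) => c.list.length > 0);
  const score = (c) => c.offset + words[c.list[c.pos]].weight;
  const skip = (page - 1) * limit;
  const seen = new Set(); // a word can end with two variants, e.g. "A-D" and "G-A-D"
  const results = [];
  while (results.length < limit) {
    // Linear scan for the cursor with the lowest current score.
    let best = null;
    for (const c of cursors) {
      if (c.pos < c.list.length && (best === null || score(c) < score(best))) best = c;
    }
    if (best === null) break; // every posting list is exhausted
    const wordIndex = best.list[best.pos++];
    if (seen.has(wordIndex)) continue; // already emitted at a lower or equal score
    seen.add(wordIndex);
    if (seen.size > skip) results.push(words[wordIndex].word);
  }
  return results;
}

const data = buildRhymeIndex([
  { word: "gad", phonemes: ["G", "A", "D"], weight: 1 },
  { word: "grad", phonemes: ["G", "R", "A", "D"], weight: 2 },
  { word: "glad", phonemes: ["G", "L", "A", "D"], weight: 3 },
  { word: "blad", phonemes: ["B", "L", "A", "D"], weight: 4 },
  { word: "agad", phonemes: ["A", "G", "A", "D"], weight: 7 },
]);
const variants = [
  { phonemes: ["G", "A", "D"], weight: 0 },
  { phonemes: ["G", "R", "A", "D"], weight: 1 },
  { phonemes: ["G", "L", "A", "D"], weight: 2 },
];
console.log(searchPage(data, variants, 1, 2)); // ["gad", "grad"]
console.log(searchPage(data, variants, 2, 2)); // ["glad", "agad"]
```

The results interleave across variants ("agad" from G-A-D comes after "grad" and "glad") because only the combined score decides the order. The best cursor is picked by a linear scan, which is fine for a handful of variants. With many variants, a min-heap over the cursors (like the question's MinHeap) would do the same job in O(log k) per result. Indexing every suffix of 10 million words costs memory, and capping the indexed suffix length is one way to bound it.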