Tags: algorithm, scala, functional-programming, string-matching, aho-corasick

How can I speed up my Aho-Corasick Algorithm?


I am trying to solve the HackerRank problem "Determining DNA Health". After reading some of the discussions, I decided that the Aho-Corasick algorithm would be the best choice. The problem involves searching a string for various sequences, each with an associated value. The task is to take a subsection of these sequence-value pairs from the given list and find the value associated with an input string. This has to be done 44850 times with a list of 100000 sequence-value pairs. I have implemented the algorithm, and while it is a whole lot faster than my first attempt, it still isn't fast enough to pass this test case. Here's my implementation:
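When speeding this up, it helps to have a slow but obviously correct reference to compare against on small inputs. The sketch below assumes (this is not stated above) that a query selects the pairs at indices first to last inclusive, and that the string's value is the sum of each selected pair's value over all of its occurrences, overlapping ones included. NaiveHealth, occurrences and health are hypothetical names.

```scala
// Hypothetical brute-force reference for testing faster versions on small inputs.
// Assumes a query covers pairs first..last (inclusive) and that every
// occurrence of a sequence counts, including overlapping ones.
object NaiveHealth {
  // Number of (possibly overlapping) occurrences of pattern in text.
  def occurrences(text: String, pattern: String): Int =
    if (pattern.isEmpty) 0
    else (0 to text.length - pattern.length).count(i => text.startsWith(pattern, i))

  // Sum of value * occurrences over the selected pairs.
  def health(pairs: IndexedSeq[(String, Int)], first: Int, last: Int, text: String): Long =
    (first to last).iterator.map { i =>
      val (gene, value) = pairs(i)
      occurrences(text, gene).toLong * value
    }.sum
}
```

Being quadratic, this is only meant for checking the output of a faster version on small cases.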

Building the trie:

def createValueTrie(gs: Array[(String, Int)]): TrieNodeWithVal = {
  def recurse(genes: Array[(String, Int)]): Map[Char, TrieNodeWithVal] = {
    genes
      .groupBy(_._1.head)                                 // one child per first character
      .map(x => (x._1, x._2.map(y => (y._1.tail, y._2)))) // drop that character
      .map {
        case (c, arr: Array[(String, Int)]) => {
          // Sequences that end here contribute their values to this node
          val value = arr.filter(_._1.length == 0).foldLeft(0)(_ + _._2)
          val filtered = arr.filter(_._1.length > 0)
          val recursed = recurse(filtered)
          (c, new TrieNodeWithVal(arr.exists(_._1.length == 0), recursed, value))
        }
      }
  }
  new TrieNodeWithVal(false, recurse(gs), 0)
}

Searching through the trie:

def findValueMatches(trie: TrieNodeWithVal, sequence: String): Iterator[(String, Long)] = {
  // Every suffix of sequence, longest first (each one is a new String)
  sequence.scanRight("")(_ + _).dropRight(1).iterator.flatMap(s => {
    // Walk down the trie one character at a time until a lookup fails
    Iterator.iterate[(Iterator[Char], Option[TrieNodeWithVal])]((s.iterator, Some(trie))) {
      case (it: Iterator[Char], Some(node)) => if (it.hasNext) (it, node(it.next())) else (it, None)
      case (it: Iterator[Char], None) => (it, None)
    }.takeWhile {
      case (_, Some(_)) => true
      case _ => false
    }.map {
      case (_, Some(node)) => node
    }.zipWithIndex.withFilter {
      // Index i is the depth, i.e. the length of the prefix matched so far
      case (node, _) => node.isWord
    }.map {
      case (node, i) => (s.slice(0, i), node.value)
    }
  })
}
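One likely hotspot in findValueMatches is sequence.scanRight("")(_ + _): it materialises every suffix of sequence as a separate String, so a sequence of length L costs on the order of L² copied characters before any trie lookup happens, and s.slice(0, i) allocates again for every match. Below is a sketch that walks the trie from each start offset by index instead, assuming only the summed value is needed (not the matched strings). Node is a hypothetical stand-in for TrieNodeWithVal, and build mirrors createValueTrie. This is still a plain trie walk without failure links, so each start offset can cost up to the depth of the trie.

```scala
object IndexWalk {
  // Hypothetical stand-in for TrieNodeWithVal with the same three fields.
  final case class Node(isWord: Boolean, value: Long, children: Map[Char, Node])

  // Build a trie; values of identical sequences are summed at the same node.
  def build(genes: Seq[(String, Long)]): Node = {
    def go(gs: Seq[(String, Long)]): Node = {
      val (done, rest) = gs.partition(_._1.isEmpty)
      val children = rest.groupBy(_._1.head).map { case (c, grp) =>
        c -> go(grp.map { case (s, v) => (s.tail, v) })
      }
      Node(done.nonEmpty, done.map(_._2).sum, children)
    }
    go(genes)
  }

  // Sum of values of every match starting where the walk began, moving by
  // index so no substring is allocated.
  @annotation.tailrec
  private def walk(node: Node, seq: String, i: Int, acc: Long): Long =
    if (i >= seq.length) acc
    else node.children.get(seq.charAt(i)) match {
      case Some(next) => walk(next, seq, i + 1, if (next.isWord) acc + next.value else acc)
      case None       => acc
    }

  // Total over all start offsets: at most trie-depth lookups per offset.
  def totalValue(trie: Node, seq: String): Long =
    (0 until seq.length).iterator.map(start => walk(trie, seq, start, 0L)).sum
}
```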

Trie node classes:

class TrieNode(isAWord: Boolean, childs: Map[Char, TrieNode]) {
  val isWord = isAWord
  val children: Map[Char, TrieNode] = childs

  def apply(c: Char): Option[TrieNode] = children.get(c)

  override def toString(): String = "(" + children.map(x => (if (x._2.isWord) x._1.toUpper else x._1) + ": " + x._2.toString()).mkString(", ") + ")"
}

class TrieNodeWithVal(isAWord: Boolean, childs: Map[Char, TrieNodeWithVal], valu: Long) extends TrieNode(isAWord, childs) {
  val value = valu
  override val children: Map[Char, TrieNodeWithVal] = childs

  override def toString(): String = "(" + children.map(x => (if (x._2.isWord) x._1.toUpper + "[" + x._2.value + "]" else x._1) + ": " + x._2.toString()).mkString(", ") + ")"

  override def apply(c: Char): Option[TrieNodeWithVal] = children.get(c)
}

I know there is some more edge-building that could be done here for the failure cases, but several people in the discussion said that it would be slower, as the trie needs to be rebuilt for each query. Are there more efficient collections I should be using for a problem like this? How can I speed it up while maintaining a purely functional style?
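The rebuild cost of the failure edges only applies if the automaton is built per query. A common alternative is to build it once over all pairs and let each query select its index range at match time: every state keeps the ascending indices of the pairs ending there plus prefix sums of their values, so a range total is two binary searches. The sketch below is an immutable version of that idea, assuming non-empty sequences, values widened to Long, an inclusive index range and overlapping matches counted. AhoCorasick, Automaton, build and query are hypothetical names, and the output-link walk per character can still be long when many sequences are suffixes of each other.

```scala
object AhoCorasick {
  // Hypothetical immutable automaton. States are Ints; state 0 is the root.
  final case class Automaton(
      edges: Vector[Map[Char, Int]], // trie edges per state
      fail: Vector[Int],             // longest proper suffix that is also a trie path
      dict: Vector[Int],             // nearest state on the fail chain ending a pair, or -1
      ids: Vector[Vector[Int]],      // indices of pairs ending exactly here, ascending
      prefix: Vector[Vector[Long]]   // prefix sums of those pairs' values, starting at 0
  )

  // Build once over all pairs (assumes every sequence is non-empty).
  def build(genes: IndexedSeq[String], values: IndexedSeq[Long]): Automaton = {
    // 1. Insert each sequence, appending its index at the state where it ends.
    val (edges, ids) =
      genes.indices.foldLeft((Vector(Map.empty[Char, Int]), Vector(Vector.empty[Int]))) {
        case ((es, ix), gi) =>
          val (es2, ix2, endState) = genes(gi).foldLeft((es, ix, 0)) { case ((e, x, s), c) =>
            e(s).get(c) match {
              case Some(t) => (e, x, t)
              case None =>
                val t = e.length
                (e.updated(s, e(s) + (c -> t)) :+ Map.empty[Char, Int], x :+ Vector.empty[Int], t)
            }
          }
          (es2, ix2.updated(endState, ix2(endState) :+ gi))
      }
    // 2. Breadth-first pass: a state's links depend only on shallower states.
    @annotation.tailrec
    def bfs(queue: Vector[Int], fail: Vector[Int], dict: Vector[Int]): (Vector[Int], Vector[Int]) =
      if (queue.isEmpty) (fail, dict)
      else {
        val s = queue.head
        val (f2, d2) = edges(s).foldLeft((fail, dict)) { case ((f, d), (c, t)) =>
          @annotation.tailrec
          def follow(u: Int): Int = edges(u).get(c) match {
            case Some(v) => v
            case None    => if (u == 0) 0 else follow(f(u))
          }
          val ft = if (s == 0) 0 else follow(f(s))
          (f.updated(t, ft), d.updated(t, if (ids(ft).nonEmpty) ft else d(ft)))
        }
        bfs(queue.tail ++ edges(s).values, f2, d2)
      }
    val n = edges.length
    val (failLinks, dictLinks) = bfs(Vector(0), Vector.fill(n)(0), Vector.fill(n)(-1))
    // 3. Prefix sums per state so an index range can be totalled quickly.
    Automaton(edges, failLinks, dictLinks, ids, ids.map(_.scanLeft(0L)((acc, gi) => acc + values(gi))))
  }

  // Number of elements of the ascending vector xs that are < key.
  private def lowerBound(xs: Vector[Int], key: Int): Int = {
    @annotation.tailrec
    def go(lo: Int, hi: Int): Int =
      if (lo >= hi) lo
      else {
        val mid = (lo + hi) >>> 1
        if (xs(mid) < key) go(mid + 1, hi) else go(lo, mid)
      }
    go(0, xs.length)
  }

  // Total value of the pairs ending at state s whose index is in first..last.
  private def rangeSum(a: Automaton, s: Int, first: Int, last: Int): Long =
    a.prefix(s)(lowerBound(a.ids(s), last + 1)) - a.prefix(s)(lowerBound(a.ids(s), first))

  // Value of text using only pairs first..last: a single pass over the text.
  def query(a: Automaton, first: Int, last: Int, text: String): Long = {
    @annotation.tailrec
    def outputs(s: Int, acc: Long): Long =
      if (s < 0) acc else outputs(a.dict(s), acc + rangeSum(a, s, first, last))
    val (_, total) = text.foldLeft((0, 0L)) { case ((state, acc), c) =>
      @annotation.tailrec
      def next(u: Int): Int = a.edges(u).get(c) match {
        case Some(v) => v
        case None    => if (u == 0) 0 else next(a.fail(u))
      }
      val s = next(state)
      (s, outputs(if (a.ids(s).nonEmpty) s else a.dict(s), acc))
    }
    total
  }
}
```

The automaton is built once; each query is then one pass over its string, with binary searches only at states where some pair ends.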


Solution

  • There are various changes; some might affect performance and others are just cosmetic.

    In recurse you can combine the two map calls and use partition to reduce the number of times you test the array:

    def recurse(genes: Array[(String, Int)]): Map[Char, TrieNodeWithVal] = {
      genes
        .groupBy(_._1.head)
        .map { x =>
          val c = x._1
          val arr = x._2.map(y => (y._1.tail, y._2))
    
          val (filtered, nonFiltered) = arr.partition(_._1.nonEmpty)
          val value = nonFiltered.foldLeft(0)(_ + _._2)
          val recursed = recurse(filtered)
          (c, new TrieNodeWithVal(nonFiltered.nonEmpty, recursed, value))
        }
    }
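One level of this recursion can be tried in isolation. The sketch below uses splitLevel, a hypothetical helper that is not part of the answer's code. It accumulates with foldLeft(0L) because foldLeft(0)(_ + _._2) sums in Int and is only widened to Long when passed to TrieNodeWithVal, so a node shared by many pairs with large values could overflow.

```scala
// Hypothetical helper showing one level of the answer's recurse.
object PartitionDemo {
  // Strip the grouped first character, then split into pairs that continue
  // below this node and pairs that end here; the latter's values are summed.
  def splitLevel(group: Seq[(String, Int)]): (Seq[(String, Int)], Long) = {
    val stripped = group.map { case (s, v) => (s.tail, v) }
    val (continuing, ending) = stripped.partition(_._1.nonEmpty)
    // foldLeft(0L) accumulates in Long; foldLeft(0) would add in Int and can
    // overflow before the result is widened to Long.
    (continuing, ending.foldLeft(0L)(_ + _._2))
  }
}
```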
    

    You can simplify findValueMatches by using guards on the case clauses and combining some operations:

    def findValueMatches(trie: TrieNodeWithVal, sequence: String): Iterator[(String, Long)] = {
      sequence.scanRight("")(_ + _).dropRight(1).iterator.flatMap(s => {
        Iterator.iterate[(Iterator[Char], Option[TrieNodeWithVal])]((s.iterator, Some(trie))) {
          case (it: Iterator[Char], Some(node)) if it.hasNext => (it, node(it.next()))
          case (it: Iterator[Char], _) => (it, None)
        }.takeWhile {
          _._2.nonEmpty
        }.zipWithIndex.collect {
          case ((_, Some(node)), i) if node.isWord =>
            (s.slice(0, i), node.value)
        }
      })
    }
    

    Finally, you can simplify your classes by using val parameters:

    class TrieNode(val isWord: Boolean, val children: Map[Char, TrieNode]) {
      def apply(c: Char): Option[TrieNode] = children.get(c)
    
      override def toString(): String = "(" + children.map(x => (if (x._2.isWord) x._1.toUpper else x._1) + ": " + x._2.toString()).mkString(", ") + ")"
    }
    
    class TrieNodeWithVal(isAWord: Boolean, childs: Map[Char, TrieNodeWithVal], val value: Long) extends TrieNode(isAWord, childs) {
      override val children: Map[Char, TrieNodeWithVal] = childs
    
      override def toString(): String = "(" + children.map(x => (if (x._2.isWord) x._1.toUpper + "[" + x._2.value + "]" else x._1) + ": " + x._2.toString()).mkString(", ") + ")"
    
      override def apply(c: Char): Option[TrieNodeWithVal] = children.get(c)
    }
    

    This all compiles but has not been tested, so apologies if I have inadvertently changed the algorithm.
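If the plain TrieNode class isn't needed on its own, the val-parameter version of the classes could be collapsed into a single case class. Below is a sketch under that assumption; ValueNode is a hypothetical name. Both branches of the label now produce a String, whereas in the original TrieNodeWithVal.toString the if expression mixes a String and a Char and so has type Any. Note that case-class equality and hashCode are structural, so comparing or hashing large tries traverses them completely.

```scala
// Hypothetical single-class replacement for TrieNode/TrieNodeWithVal.
final case class ValueNode(isWord: Boolean, children: Map[Char, ValueNode], value: Long) {
  def apply(c: Char): Option[ValueNode] = children.get(c)

  // Same rendering as the original: word-ending edges are upper-cased
  // and followed by their value in brackets.
  override def toString: String =
    children.map { case (c, n) =>
      (if (n.isWord) s"${c.toUpper}[${n.value}]" else c.toString) + ": " + n.toString
    }.mkString("(", ", ", ")")
}
```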