Passing a list of strings to be inserted into a trie


I have code that builds a trie data structure when it is given a single string. When I try to pass a list of strings instead, it combines the words into one.

class TrieNode:
    def __init__(self):
        self.end = False
        self.children = {}

    def all_words(self, prefix):
        if self.end:
            yield prefix

        for letter, child in self.children.items():
            yield from child.all_words(prefix + letter)

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, words):
        curr = self.root
        # the line I added to read the words from a list is below
        for word in words:
            for letter in word:
                node = curr.children.get(letter)
                if not node:
                    node = TrieNode()
                    curr.children[letter] = node
                curr = node
            curr.end = True


    def all_words_beginning_with_prefix(self, prefix):
        cur = self.root
        for c in prefix:
            cur = cur.children.get(c)
            if cur is None:
                return  # No words with given prefix

        yield from cur.all_words(prefix)

This is the code I use to insert everything into the tree:

lst = ['foo', 'foob', 'foobar', 'foof']
trie = Trie()
trie.insert(lst)

The output I get is:

['foo', 'foofoob', 'foofoobfoobar', 'foofoobfoobarfoof']

The output I would like to get is:

['foo', 'foob', 'foobar', 'foof']

This is the line I used to produce that output, in case you need to run the code yourself. It prints all the words that start with a given prefix:

print(list(trie.all_words_beginning_with_prefix('foo')))

How do I fix it?
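A self-contained version of the question's code, for anyone who wants to run it (the duplicated __init__ is dropped). The extra look at the node reached by the path f, o, o is an illustrative addition, not part of the original post: it shows that this node's only child is 'f', meaning 'foob' was stored below 'foo' rather than beside it.

```python
class TrieNode:
    def __init__(self):
        self.end = False    # True if a word ends at this node
        self.children = {}  # maps a letter to the child TrieNode

    def all_words(self, prefix):
        # Depth-first walk yielding every complete word at or below this node
        if self.end:
            yield prefix
        for letter, child in self.children.items():
            yield from child.all_words(prefix + letter)


class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, words):
        # Version from the question: curr is set once, before the loop over words
        curr = self.root
        for word in words:
            for letter in word:
                node = curr.children.get(letter)
                if not node:
                    node = TrieNode()
                    curr.children[letter] = node
                curr = node
            curr.end = True
            # curr is never reset, so the next word continues from this node

    def all_words_beginning_with_prefix(self, prefix):
        cur = self.root
        for c in prefix:
            cur = cur.children.get(c)
            if cur is None:
                return  # No words with given prefix
        yield from cur.all_words(prefix)


trie = Trie()
trie.insert(['foo', 'foob', 'foobar', 'foof'])
result = list(trie.all_words_beginning_with_prefix('foo'))
print(result)  # ['foo', 'foofoob', 'foofoobfoobar', 'foofoobfoobarfoof']

# Inspect the node reached by the path f -> o -> o
foo_node = trie.root.children['f'].children['o'].children['o']
print(list(foo_node.children))  # ['f']: 'foob' was inserted below 'foo'
```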


Solution

  • You never reset curr back to the root after inserting each word, so each word is inserted starting from the node where the previous word ended. You'd want something like:

    def insert(self, words):
        curr = self.root
        for word in words:
            for letter in word:
                node = curr.children.get(letter)
                if not node:
                    node = TrieNode()
                    curr.children[letter] = node
                curr = node
            curr.end = True
            curr = self.root  # Reset back to the root
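To check the fix end to end, here is a minimal self-contained sketch combining the question's TrieNode and prefix search with the corrected insert above. Since dicts preserve insertion order (guaranteed from Python 3.7), all_words visits children in the order they were added, which is why 'foob' and 'foobar' come out before 'foof'.

```python
class TrieNode:
    def __init__(self):
        self.end = False    # True if a word ends at this node
        self.children = {}  # maps a letter to the child TrieNode

    def all_words(self, prefix):
        # Depth-first walk yielding every complete word at or below this node
        if self.end:
            yield prefix
        for letter, child in self.children.items():
            yield from child.all_words(prefix + letter)


class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, words):
        curr = self.root
        for word in words:
            for letter in word:
                node = curr.children.get(letter)
                if not node:
                    node = TrieNode()
                    curr.children[letter] = node
                curr = node
            curr.end = True
            curr = self.root  # Reset back to the root before the next word

    def all_words_beginning_with_prefix(self, prefix):
        cur = self.root
        for c in prefix:
            cur = cur.children.get(c)
            if cur is None:
                return  # No words with given prefix
        yield from cur.all_words(prefix)


trie = Trie()
trie.insert(['foo', 'foob', 'foobar', 'foof'])
result = list(trie.all_words_beginning_with_prefix('foo'))
print(result)  # ['foo', 'foob', 'foobar', 'foof']
```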
    

    I'd split this up, though. I think your insert method is doing too much and shouldn't be dealing with multiple strings at all. I'd change it to something like:

    def insert(self, word):
        curr = self.root
        for letter in word:
            node = curr.children.get(letter)
            if not node:
                node = TrieNode()
                curr.children[letter] = node
            curr = node
        curr.end = True
    
    def insert_many(self, words):
        for word in words:
            self.insert(word)  # Just loop over self.insert
    

    Now the problem can't happen: each insert call starts with its own fresh curr = self.root, so there is no reset to forget.
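A usage sketch of the split version, assuming the question's TrieNode and prefix search are kept unchanged: insert_many builds the initial trie, and insert can still be called on its own afterwards. The extra word 'bar' and the 'ba' and 'x' prefix queries are illustrative additions, not from the question.

```python
class TrieNode:
    def __init__(self):
        self.end = False    # True if a word ends at this node
        self.children = {}  # maps a letter to the child TrieNode

    def all_words(self, prefix):
        # Depth-first walk yielding every complete word at or below this node
        if self.end:
            yield prefix
        for letter, child in self.children.items():
            yield from child.all_words(prefix + letter)


class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        # Each call starts from the root, so words never build on each other
        curr = self.root
        for letter in word:
            node = curr.children.get(letter)
            if not node:
                node = TrieNode()
                curr.children[letter] = node
            curr = node
        curr.end = True

    def insert_many(self, words):
        for word in words:
            self.insert(word)  # Just loop over self.insert

    def all_words_beginning_with_prefix(self, prefix):
        cur = self.root
        for c in prefix:
            cur = cur.children.get(c)
            if cur is None:
                return  # No words with given prefix
        yield from cur.all_words(prefix)


trie = Trie()
trie.insert_many(['foo', 'foob', 'foobar', 'foof'])
trie.insert('bar')  # a single word can still be added on its own

foo_words = list(trie.all_words_beginning_with_prefix('foo'))
ba_words = list(trie.all_words_beginning_with_prefix('ba'))
missing = list(trie.all_words_beginning_with_prefix('x'))
print(foo_words)  # ['foo', 'foob', 'foobar', 'foof']
print(ba_words)   # ['bar']
print(missing)    # []
```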