Consider the following words:
'a', 'ab', 'abcd', 'b', 'bcd'
Adding them to a Trie results in the following representation, where a star means that the node marks the end of a word:
        root
        /  \
      *a    *b
      /       \
    *b         c
    /           \
   c             *d
  /
*d
In this example we have two paths, and the maximum number of end words on any path is 3 (a, ab, abcd). How would you perform the DFS to get the max?
Here is my code for a Trie:
class TrieNode:
    """A single node of a trie."""

    def __init__(self):
        # Child nodes, keyed by the character that leads to them.
        self.children = {}
        # Counter of words ending at this node; a fresh node has none.
        self.end_word = 0
class Trie:
    """Prefix tree whose nodes are ``TrieNode`` objects."""

    def __init__(self):
        # The root stands for the empty prefix.
        self.root = TrieNode()

    def insert(self, key):
        """Add ``key`` to the trie, creating any missing nodes on the way.

        The ``end_word`` counter of the node reached after the last
        character is incremented, so inserting the same key twice counts
        twice. For an empty key that node is the root itself.
        """
        node = self.root
        for letter in key:
            child = node.children.get(letter)
            if child is None:
                child = TrieNode()
                node.children[letter] = child
            node = child
        node.end_word += 1
You should add a method to your TrieNode class.
If I understood your question correctly, you want this trie:
          root
         /    \
       *a      *b
       /         \
     *b           c
    /  \           \
   c    *d          *d
  /     /
*d    *e
To return 4 (a, ab, abd, abde)
You can do it recursively:
class TrieNode:
    """A trie node that can report the best end-word total beneath it."""

    def __init__(self):
        # Child nodes, keyed by the character that leads to them.
        self.children = {}
        # Counter of words ending at this node; a fresh node has none.
        self.end_word = 0

    def count_end_words(self):
        """Return the largest sum of ``end_word`` along any downward path.

        The path starts at this node (inclusive) and follows, at every
        level, the child whose own best total is the highest.
        """
        # Leaf: nothing below, so the path consists of this node only.
        if not self.children:
            return self.end_word

        totals_below = [node.count_end_words() for node in self.children.values()]
        return self.end_word + max(totals_below)
class Trie:
    """Prefix tree whose nodes are ``TrieNode`` objects."""

    def __init__(self):
        # The root stands for the empty prefix.
        self.root = TrieNode()

    def insert(self, key):
        """Add ``key`` to the trie, creating any missing nodes on the way.

        The ``end_word`` counter of the node reached after the last
        character is incremented once per call.
        """
        cursor = self.root
        for symbol in key:
            if symbol in cursor.children:
                cursor = cursor.children[symbol]
                continue
            # First time this prefix is seen: extend the trie by one node.
            fresh = TrieNode()
            cursor.children[symbol] = fresh
            cursor = fresh
        cursor.end_word += 1

    def max_path_count_end_words(self):
        """Return the highest end-word total found on any root-to-leaf path."""
        # The whole computation is delegated to the root node.
        top = self.root
        return top.count_end_words()
# Build the example trie, one word at a time.
root = Trie()
for word in (
    "a",
    "ab",
    "abcd",
    "b",
    "bcd",
    "abd",
    "abde",
):
    root.insert(word)

# Best path is a -> ab -> abd -> abde, so this prints 4.
print(root.max_path_count_end_words())
As mentioned in the comment, you can avoid creating a separate TrieNode class.
Here is one way to do it:
class Trie:
    """Trie in which every node is itself a ``Trie`` (no separate node class).

    The instance you create is the root and stands for the empty prefix.
    Each sub-trie stored in ``children`` stands for the prefix spelled by
    the characters on the way down to it.
    """

    def __init__(self):
        # Sub-tries, keyed by the single character that leads to them.
        self.children = dict()
        # True when the characters leading to this node spell an inserted word.
        self.is_end_word = False

    def insert(self, key):
        """Insert ``key`` below this node.

        Missing nodes are created on the way down, and the node reached
        after the LAST character is flagged as an end word. An empty key
        is ignored, so the node ``insert`` is called on is never flagged.

        The walk is iterative, so long keys neither hit the recursion
        limit nor pay for repeated ``key[1:]`` slicing.
        """
        if not key:
            return
        current = self
        for char in key:
            if char not in current.children:
                current.children[char] = Trie()
            current = current.children[char]
        # Flag the node that represents the complete key, not its parent;
        # flagging the parent would merge unrelated words (e.g. 'a' and
        # 'bc') onto one path and inflate the count.
        current.is_end_word = True

    def max_path_count_end_words(self):
        """Return the maximum number of end words on any downward path.

        The path starts at this node (inclusive). The result is always an
        ``int``, and is 0 for a trie that holds no words.
        """
        best_below = max(
            (child.max_path_count_end_words() for child in self.children.values()),
            default=0,  # leaf node: nothing to add from below
        )
        return int(self.is_end_word) + best_below
# Build the example trie, one word at a time.
root = Trie()
for word in (
    "a",
    "ab",
    "abcd",
    "b",
    "bcd",
    "abd",
    "abde",
):
    root.insert(word)

# Best path is a -> ab -> abd -> abde, so this prints 4.
print(root.max_path_count_end_words())