Search code examples
pythonlistrecursionbinary-search-tree

How do I recursively find nodes in a BST within a range, and return a full list?


##############################
class Node:
    def __init__(self,value):
        self.left = None
        self.right = None
        self.val = value

###############################
class BinarySearchTree:
    def __init__(self):
        self.root = None

def print_tree(node): 
    if node == None:
        return
    print_tree(node.left)
    print_tree(node.right) 
    print(node.val)


#################################################
# Task 1: get_nodes_in_range function
#################################################  
def get_nodes_in_range(node,min,max):
    if node == None:
        return
    get_nodes_in_range(node.left, min, max)
    get_nodes_in_range(node.right, min, max)
    if min <= node.val <= max:
        nodelist.append(node.val)
    return nodelist
    


if __name__ == '__main__':
    BST = BinarySearchTree()
    BST.root = Node(10)
    BST.root.left = Node(5)
    BST.root.right = Node(15)
    BST.root.left.left = Node(2)
    BST.root.left.right = Node(8)
    BST.root.right.left = Node(12)
    BST.root.right.right = Node(20)
    BST.root.right.right.right = Node(25)
    nodelist = []
    print(get_nodes_in_range(BST.root, 6, 20))

my get_nodes_in_range function requires a list to be appended to. Is there a way to make this function work without creating a list outside the function? ie. directly returning a list generated recursively?

Asking as this is part of an assignment for school, and although it returns the correct output, it fails the unit test: Unexpected error: name 'nodelist' is not defined.


Solution

  • The source of the problem is you trying to use global variable nodelist inside get_nodes_in_range function. Your code works when you run it and fails when you import it. Because __name__ is not equal to '__main__' when you import the code. That's how Python works. It is expected and correct behavior.

    I suggest the following approach to fix the problem.

    # just for convinience
    def create_bst():
        ans = BinarySearchTree()
        ans.root = Node(10)
        ans.root.left = Node(5)
        ans.root.right = Node(15)
        ans.root.left.left = Node(2)
        ans.root.left.right = Node(8)
        ans.root.right.left = Node(12)
        ans.root.right.right = Node(20)
        ans.root.right.right.right = Node(25)
        return ans
    
    # fix - create a list to store results inside the function
    def get_nodes_in_range_1(node, min_value, max_value):
    
        ans = list()
        if node is None:
            return ans
        if min_value <= node.val <= max_value:
            ans.append(node.val)
        ans.extend(get_nodes_in_range_1(node.left, min_value, max_value))
        ans.extend(get_nodes_in_range_1(node.right, min_value, max_value))
    
        return ans
    
    
    if __name__ == '__main__':
    
        bst = create_bst()
    
        nodes_in_range_1 = get_nodes_in_range_1(bst.root, 6, 20)
        print(nodes_in_range_1)