Search code examples
python · unit-testing · property-based-testing · python-hypothesis

Generate a valid binary search tree with Python hypothesis by parametrizing recursive calls


How do you parametrize recursive strategies in the Python hypothesis library?

I'd like to test that the is_valid_bst function works by generating valid BSTs with a recursive strategy.

import hypothesis as hp
from hypothesis import strategies as hps


class TreeNode:
  """Binary tree node holding a value and optional left/right children."""

  def __init__(self, x):
    """Create a leaf node with value `x`; both children start as None."""
    self.val = x
    self.left = None
    self.right = None

  def __repr__(self):
    """Return a constructor-like string; leaves omit the child fields."""
    if self.left is None and self.right is None:
      return f'TreeNode({self.val})'
    # The trailing ')' keeps nested reprs balanced when children are printed
    # recursively through their own __repr__.
    return f'TreeNode({self.val}, left={self.left}, right={self.right})'


def is_valid_bst(node):
  """Return True if the tree rooted at `node` is a strict binary search tree.

  Every value in a node's left subtree must be strictly smaller than the
  node's value and every value in its right subtree strictly larger, so
  duplicate values make the tree invalid. An empty tree (None) is valid.

  Args:
    node: Root of the tree, or None. Nodes expose `val`, `left` and `right`.

  Returns:
    True when the whole tree satisfies the BST ordering, False otherwise.
  """
  # Each entry is a subtree plus the exclusive bounds inherited from its
  # ancestors. Comparing only parent/child pairs is not enough: in
  # 5 -> left 3 -> right 7 every pair looks ordered, yet 7 sits left of 5.
  pending = [(node, None, None)]
  while pending:
    current, low, high = pending.pop()
    if current is None:
      continue
    # Explicit None checks because a bound of 0 is a legitimate value.
    if low is not None and current.val <= low:
      return False
    if high is not None and current.val >= high:
      return False
    pending.append((current.left, low, current.val))
    pending.append((current.right, current.val, high))
  return True


@hps.composite
def valid_bst_trees(draw, strategy=None, min_value=None, max_value=None):
  """Draw a tree node with a bounded integer value and two drawn children.

  The root value is an integer drawn with the given `min_value` and
  `max_value`. Both children are drawn from `strategy` unchanged: the root
  value is not passed on, so nothing here constrains the children relative
  to it.

  Args:
    draw: Draw callable injected by `hps.composite`.
    strategy: Strategy the left and right children are drawn from.
    min_value: Lower bound handed to `hps.integers`, or None.
    max_value: Upper bound handed to `hps.integers`, or None.

  Returns:
    The newly built TreeNode.
  """
  root = TreeNode(draw(hps.integers(min_value=min_value, max_value=max_value)))
  root.left = draw(strategy)
  root.right = draw(strategy)
  return root


def gen_bst(tree_strategy, min_value=None, max_value=None):
  """Build the `extend` step handed to `hps.recursive`.

  An integer is drawn first and flat-mapped to `valid_bst_trees`, but the
  drawn integer itself is ignored: the resulting strategy receives only the
  original `min_value` and `max_value`.

  Args:
    tree_strategy: Strategy for the child trees.
    min_value: Lower bound for node values, or None.
    max_value: Upper bound for node values, or None.

  Returns:
    A strategy producing tree nodes.
  """
  def _to_tree(_drawn_val):
    # The drawn value is deliberately unused, mirroring the bounds-only call.
    return valid_bst_trees(
        strategy=tree_strategy, min_value=min_value, max_value=max_value)

  value_strategy = hps.integers(min_value=min_value, max_value=max_value)
  return value_strategy.flatmap(_to_tree)


@hp.given(hps.recursive(base=hps.just(None), extend=gen_bst))
def test_is_valid_bst_works(node):
  """Property test: every generated tree must be accepted by is_valid_bst."""
  assert is_valid_bst(node)

Solution

  • I figured it out. My main misunderstanding was:

    • The tree_strategy created by the hypothesis.recursive strategy is safe to draw from multiple times and will generate appropriate recursion.

    A few other gotchas:

    • The base case needs both None and a singleton tree. With only None, you'll only generate None.
    • For the singleton tree, you must generate a new tree every time. Otherwise, you'll end up with cycles in the tree since each node is the same tree. Easiest way to accomplish this is hps.just(-111).map(TreeNode).
    • If the drawn base case is a singleton tree, you'll need to overwrite its value so that it respects min_value and max_value.

    Full working solution:

    import hypothesis as hp
    from hypothesis import strategies as hps
    
    
    class TreeNode:
      """Binary tree node holding a value and optional left/right children."""

      def __init__(self, x):
        """Create a leaf node with value `x`; both children start as None."""
        self.val = x
        self.left = None
        self.right = None

      def __repr__(self):
        """Return a constructor-like string; leaves omit the child fields."""
        if self.left is None and self.right is None:
          return f'TreeNode({self.val})'
        # The trailing ')' keeps nested reprs balanced when children are
        # printed recursively through their own __repr__.
        return f'TreeNode({self.val}, left={self.left}, right={self.right})'
    
    
    def is_valid_bst(node):
      """Return True if the tree rooted at `node` is a strict BST.

      Every value in a node's left subtree must be strictly smaller than the
      node's value and every value in its right subtree strictly larger, so
      duplicate values make the tree invalid. An empty tree (None) is valid.

      Args:
        node: Root of the tree, or None. Nodes expose `val`, `left`, `right`.

      Returns:
        True when the whole tree satisfies the BST ordering, False otherwise.
      """
      # Each entry is a subtree plus the exclusive bounds inherited from its
      # ancestors. Comparing only parent/child pairs is not enough: in
      # 5 -> left 3 -> right 7 every pair looks ordered, yet 7 sits left of 5.
      pending = [(node, None, None)]
      while pending:
        current, low, high = pending.pop()
        if current is None:
          continue
        # Explicit None checks because a bound of 0 is a legitimate value.
        if low is not None and current.val <= low:
          return False
        if high is not None and current.val >= high:
          return False
        pending.append((current.left, low, current.val))
        pending.append((current.right, current.val, high))
      return True
    
    
    @hps.composite
    def valid_bst_trees(
        draw, tree_strategy, min_value=None, max_value=None):
      """Returns a valid BST, or None for an empty tree.

      Idea is to pick an integer VAL in [min_value, max_value] for this tree
      and use it as a constraint for the children by recursing with narrowed
      bounds so that:

      1. The left child value is in [min_value, VAL).
      2. The right child value is in (VAL, max_value].

      Args:
        draw: Draw callable injected by `hps.composite`.
        tree_strategy: Strategy whose draw decides whether a node exists at
          this position (None means no node).
        min_value: Inclusive lower bound for node values, or None.
        max_value: Inclusive upper bound for node values, or None.

      Returns:
        The root node of the generated tree, or None.
      """
      # The drawn object only decides whether this position holds a node: its
      # val and both children are overwritten below.
      node = draw(tree_strategy)
      if node is None:
        return None
      # Can't use implicit boolean because the values might be falsey, e.g. 0.
      # Only a strictly inverted range is empty; min_value == max_value still
      # leaves exactly one legal value because the bounds are inclusive.
      if (min_value is not None and max_value is not None
          and min_value > max_value):
        return None

      # Overwrite tree.val with one that respects min and max value.
      val = draw(hps.integers(min_value=min_value, max_value=max_value))
      node.val = val

      node.left = draw(valid_bst_trees(
          tree_strategy=tree_strategy,
          min_value=min_value,
          max_value=node.val - 1))
      node.right = draw(valid_bst_trees(
          tree_strategy=tree_strategy,
          min_value=node.val + 1,
          max_value=max_value))
      return node
    
    
    def gen_bst(tree_strategy, min_value=None, max_value=None):
      """Build the `extend` step for `hps.recursive`.

      Forwards all three arguments unchanged to `valid_bst_trees`.

      Args:
        tree_strategy: Strategy passed through as `tree_strategy`.
        min_value: Lower bound for node values, or None.
        max_value: Upper bound for node values, or None.

      Returns:
        The strategy returned by `valid_bst_trees`.
      """
      bounds = {'min_value': min_value, 'max_value': max_value}
      return valid_bst_trees(tree_strategy=tree_strategy, **bounds)
    
    
    # Base-case strategy for a one-node tree. Mapping through TreeNode returns
    # a new, distinct tree node every time to avoid self referential trees.
    # The -111 value is only a placeholder: valid_bst_trees overwrites it.
    singleton_tree = hps.just(-111).map(TreeNode)
    
    
    @hp.given(hps.recursive(
        base=hps.just(None) | singleton_tree, extend=gen_bst))
    def test_is_valid_bst_works(node):
      """Property test: every generated tree must pass is_valid_bst."""
      assert is_valid_bst(node)
    
    
    # Simple tests to demonstrate how the TreeNode works
    def test_is_valid_bst():
      """Spot-check is_valid_bst on small hand-built trees."""
      def build(val, left=None, right=None):
        # Assemble a root with optional single-node children.
        root = TreeNode(val)
        if left is not None:
          root.left = TreeNode(left)
        if right is not None:
          root.right = TreeNode(right)
        return root

      # The empty tree and a lone node are both valid.
      assert is_valid_bst(None)
      assert is_valid_bst(build(1))

      # A strictly smaller left child is valid.
      assert is_valid_bst(build(1, left=0))

      # A left child equal to its parent is rejected.
      assert not is_valid_bst(build(1, left=1))

      # A right child equal to its parent is rejected.
      assert not is_valid_bst(build(1, left=0, right=1))