Tags: java, algorithm, rust, binary-search-tree

Implementing the recursive insert method of a binary search tree in Rust


I'm learning Rust and trying to implement a simple binary search tree (actually, I'm rewriting the Java implementation below). Here is what I have so far:

use std::cmp::Ordering;

// Node of this BST, the two generic types are key and value
struct Node<K: Ord, V> {
    key: K,
    value: V,
    left: Option<Box<Node<K, V>>>,
    right: Option<Box<Node<K, V>>>,
    number_of_nodes: i32,
}

impl<K: Ord, V> Node<K, V> {
    // Create a new node
    fn new(key: K, value: V, number_of_nodes: i32) -> Node<K, V>{
        Node {
            key,
            value,
            left: None,
            right: None,
            number_of_nodes,
        }
    }
}

struct BST<K: Ord, V> {
    root: Option<Box<Node<K, V>>>,
}

impl<K: Ord, V> BST<K, V> {
    // Get the size of this BST
    fn size(&self) -> i32 {
        size(&self.root)
    }

    // Search for key. Update value if found, otherwise insert the new node
    fn put(&self, key: K, value: V) {
        self.root = put(&self.root, key, value)
    }
}

// Get the size of a subtree (0 for an empty subtree)
fn size<K: Ord, V>(node: &Option<Box<Node<K, V>>>) -> i32 {
    match node {
        Some(real_node) => real_node.number_of_nodes,
        None => 0,
    }
}

// Recursively put a new node into this BST
fn put<K: Ord, V>(node: &Option<Box<Node<K, V>>>, key: K, value: V) -> &Option<Box<Node<K, V>>>{
    match node {
        None => {
            let new_node = Some(Box::new(Node::new(key, value, 1)));
            return &new_node;
        },
        Some(real_node) => {
            match key.cmp(&real_node.key) {
                Ordering::Less => real_node.left = *put(&real_node.left, key, value),
                Ordering::Greater => real_node.right = *put(&real_node.right, key, value), 
                Ordering::Equal => real_node.value = value,
            }
            real_node.number_of_nodes = size(&real_node.right) + size(&real_node.left) + 1;
            node
        },
    }
}

But this code won't compile. At the line self.root = put(&self.root, key, value), I get this error:

mismatched types
expected enum 'Option<Box<Node<K, V>>>' found reference '&Option<Box<Node<K, V>>>'

I don't know how to fix that. I tried changing the &self parameter to self, or self.root to *self.root, but I got more errors. I'm confused about references in Rust; all I want to do is rewrite the following Java code in Rust.

public class BST<Key extends Comparable<Key>, Value>
{
    private Node root;              // root of BST

    private class Node
    {
        private Key key;            // key
        private Value val;          // associated value
        private Node right, left;   // left and right subtrees
        private int N;              // number of nodes in subtree

        public Node(Key key, Value val, int N)
        {
            this.key = key;
            this.val = val;
            this.N = N;
        }
    }

    // Returns the number of key-value pairs in this symbol table.
    public int size()
    {
        return size(root);
    }

    // Return number of key-value pairs in BST rooted at x
    private int size(Node x)
    {
        if (x == null) return 0;
        else return x.N;
    }

    public void put(Key key, Value val)
    {
        root = put(root, key, val);
    }

    private Node put(Node x, Key key, Value val)
    {
        if (x == null) return new Node(key, val, 1);
        int cmp = key.compareTo(x.key);
        if (cmp < 0) x.left = put(x.left, key, val);
        else if (cmp > 0) x.right = put(x.right, key, val);
        else x.val = val;
        x.N = size(x.left) + size(x.right) + 1;
        return x;
    }
} 
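Java's root = put(root, key, val) passes the old subtree in and assigns whatever comes back. A Rust version of that exact shape needs two things: &mut self, because self.root is written to, and a way to move the old value out of the field without leaving it empty. Option::take does this by swapping in None. The sketch below shows the pattern on a stripped-down stand-in type; Holder, add_one, and bump are hypothetical names chosen for illustration, not part of the question's code.

```rust
// Hypothetical stand-in for BST: the "tree" is just an optional boxed number.
struct Holder {
    root: Option<Box<i32>>,
}

// Takes the old value by value and returns the new one, the same shape
// as Java's private Node put(Node x, ...) returning x.
fn add_one(node: Option<Box<i32>>) -> Option<Box<i32>> {
    match node {
        None => Some(Box::new(1)),
        Some(mut b) => {
            *b += 1;
            Some(b)
        }
    }
}

impl Holder {
    // &mut self is required because self.root is reassigned.
    fn bump(&mut self) {
        // take() moves the value out of self.root and leaves None behind,
        // so the field is never left in a moved-from state.
        self.root = add_one(self.root.take());
    }
}

fn main() {
    let mut h = Holder { root: None };
    h.bump(); // None -> Some(1)
    h.bump(); // Some(1) -> Some(2)
    assert_eq!(h.root.as_deref(), Some(&2));
    println!("{:?}", h.root); // prints "Some(2)"
}
```

Without take(), writing add_one(self.root) would try to move the field out from behind &mut self, which the compiler rejects (error E0507).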

It's dead simple in Java because I don't have to deal with references. So here are my questions:

  1. How can I fix that mismatched types error?
  2. What is the proper return type of the recursive function put: &Option<Box<Node<K, V>>> or Option<Box<Node<K, V>>>? What's the difference?
  3. Am I on the right track with this rewrite of the Java code? rust-analyzer only reports the mismatched types error, but I don't know whether the code will work as I expect. And honestly, I don't fully understand what I'm doing when I handle references in Rust, especially references to a struct or enum.
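Regarding question 2: in the None branch, put builds new_node as a local variable, and that local is dropped when the function returns, so a &Option pointing at it would dangle; the borrow checker refuses to compile that. The Java method returns an owned Node, and the closest Rust equivalents are either returning an owned Option<Box<...>> or taking a &mut Option<Box<...>> and modifying it in place. The sketch below contrasts the three kinds of parameter on a plain Option<Box<i32>>; the function names peek, fill_in_place, and fill_by_value are made up for illustration.

```rust
// & : shared borrow, read-only; the caller keeps ownership.
fn peek(node: &Option<Box<i32>>) -> i32 {
    match node {
        Some(b) => **b,
        None => 0,
    }
}

// &mut : exclusive borrow; the function can change the caller's value in place.
fn fill_in_place(node: &mut Option<Box<i32>>, v: i32) {
    if node.is_none() {
        *node = Some(Box::new(v));
    }
}

// By value: the function owns the input and returns what the caller should keep.
fn fill_by_value(node: Option<Box<i32>>, v: i32) -> Option<Box<i32>> {
    match node {
        None => Some(Box::new(v)),
        existing => existing,
    }
}

fn main() {
    let mut a: Option<Box<i32>> = None;
    fill_in_place(&mut a, 7); // a is modified through the &mut borrow
    let b = fill_by_value(None, 9); // ownership of the result moves to b
    assert_eq!(peek(&a), 7);
    assert_eq!(peek(&b), 9);
    println!("{} {}", peek(&a), peek(&b)); // prints "7 9"
}
```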

Learning Rust is hard for me because I don't have much experience with systems programming languages. I appreciate your help :)


Solution

  • The simplest option is to take a mutable reference to the node:

    impl<K: Ord, V> BST<K, V> {
        // ...
    
        fn put(&mut self, key: K, value: V) {
            put(&mut self.root, key, value)
        }
    }
    
    fn put<K: Ord, V>(node: &mut Option<Box<Node<K, V>>>, key: K, value: V) {
        match node {
            None => {
                *node = Some(Box::new(Node::new(key, value, 1)));
            }
            Some(real_node) => {
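                if key < real_node.key {
                    put(&mut real_node.left, key, value);
                } else if key > real_node.key {
                    put(&mut real_node.right, key, value);
                } else {
                    // Equal keys: update the value instead of inserting a duplicate
                    real_node.value = value;
                }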
    
                real_node.number_of_nodes = size(&real_node.right) + size(&real_node.left) + 1;
            }
        }
    }
    

    Alternatively, you could take the node by value and return the updated node, which mirrors the Java version more closely:

    impl<K: Ord, V> BST<K, V> {
        // ...
    
        pub fn put(&mut self, key: K, value: V) {
            self.root = put(self.root.take(), key, value);
        }
    }
    
    fn put<K: Ord, V>(node: Option<Box<Node<K, V>>>, key: K, value: V) -> Option<Box<Node<K, V>>> {
        match node {
            None => Some(Box::new(Node::new(key, value, 1))),
            Some(mut real_node) => {
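                if key < real_node.key {
                    real_node.left = put(real_node.left, key, value);
                } else if key > real_node.key {
                    real_node.right = put(real_node.right, key, value);
                } else {
                    // Equal keys: update the value instead of inserting a duplicate
                    real_node.value = value;
                }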
    
                real_node.number_of_nodes = size(&real_node.right) + size(&real_node.left) + 1;
                Some(real_node)
            }
        }
    }
    

    I suggest reading the ownership and references chapter of the Rust book if you haven't already.
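For reference, here is the first (&mut) approach assembled into one self-contained program, reusing the question's Node, size, and BST definitions. It is a sketch rather than the only way to write it: the comparison uses match on Ordering so that an equal key updates the stored value, as the Java put does, and BST::new and the get lookup are hypothetical helpers added only so that main can exercise the tree.

```rust
use std::cmp::Ordering;

struct Node<K: Ord, V> {
    key: K,
    value: V,
    left: Option<Box<Node<K, V>>>,
    right: Option<Box<Node<K, V>>>,
    number_of_nodes: i32,
}

impl<K: Ord, V> Node<K, V> {
    fn new(key: K, value: V, number_of_nodes: i32) -> Node<K, V> {
        Node { key, value, left: None, right: None, number_of_nodes }
    }
}

struct BST<K: Ord, V> {
    root: Option<Box<Node<K, V>>>,
}

impl<K: Ord, V> BST<K, V> {
    // Hypothetical constructor, not in the question.
    fn new() -> Self {
        BST { root: None }
    }

    fn size(&self) -> i32 {
        size(&self.root)
    }

    fn put(&mut self, key: K, value: V) {
        put(&mut self.root, key, value)
    }

    // Hypothetical lookup helper (not in the question), used only to check updates.
    fn get(&self, key: &K) -> Option<&V> {
        let mut current = self.root.as_deref();
        while let Some(node) = current {
            match key.cmp(&node.key) {
                Ordering::Less => current = node.left.as_deref(),
                Ordering::Greater => current = node.right.as_deref(),
                Ordering::Equal => return Some(&node.value),
            }
        }
        None
    }
}

// Size of a subtree, read from the cached count (0 for an empty subtree).
fn size<K: Ord, V>(node: &Option<Box<Node<K, V>>>) -> i32 {
    match node {
        Some(real_node) => real_node.number_of_nodes,
        None => 0,
    }
}

// Insert or update key, modifying the tree in place through &mut.
fn put<K: Ord, V>(node: &mut Option<Box<Node<K, V>>>, key: K, value: V) {
    match node {
        None => *node = Some(Box::new(Node::new(key, value, 1))),
        Some(real_node) => {
            match key.cmp(&real_node.key) {
                Ordering::Less => put(&mut real_node.left, key, value),
                Ordering::Greater => put(&mut real_node.right, key, value),
                // Equal key: update the value instead of inserting a duplicate
                Ordering::Equal => real_node.value = value,
            }
            real_node.number_of_nodes = size(&real_node.left) + size(&real_node.right) + 1;
        }
    }
}

fn main() {
    let mut bst = BST::new();
    for (k, v) in [(5, "five"), (2, "two"), (8, "eight"), (2, "TWO")] {
        bst.put(k, v);
    }
    assert_eq!(bst.size(), 3); // the second put of key 2 only updated its value
    assert_eq!(bst.get(&2), Some(&"TWO"));
    assert_eq!(bst.get(&7), None);
    println!("size = {}", bst.size()); // prints "size = 3"
}
```

Passing &mut Option<Box<Node<K, V>>> lets each call edit the link it was given directly, so nothing needs to be returned; the by-value version instead moves every subtree on the search path out and back in, which is closer to the Java code.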