Search code examples
Tags: data-structures, rust, graph, binary-tree

Building clean and flexible binary trees in Rust


I'm using binary trees to build a simple computation graph. I understand that linked data structures such as linked lists are a pain in Rust, but a tree is a very convenient data structure for what I'm doing. I tried using `Box` and `Rc<RefCell<_>>` for the child nodes, but it didn't work out the way I wanted, so I used `unsafe`:

use std::ops::{Add, Mul};

/// A node of a binary computation tree whose children are linked through raw
/// pointers rather than owned boxes.
///
/// Being `Copy`, a node can be used as an operand more than once; copies share
/// the same child pointers.
#[derive(Clone, Copy, Debug)]
struct MyStruct {
    /// The node's value; for nodes built with `+`/`*`, the combined result.
    value: i32,
    /// Raw pointer to the left operand node, or `None` for no child.
    lchild: Option<*mut Self>,
    /// Raw pointer to the right operand node, or `None` for no child.
    rchild: Option<*mut Self>,
}

impl MyStruct {
    unsafe fn print_tree(&mut self, set_to_zero: bool) {
        if set_to_zero {
            self.value = 0;
        }
        println!("{:?}", self);
    
        let mut nodes = vec![self.lchild, self.rchild];
        while nodes.len() > 0 {
            let child;
            match nodes.pop() {
                Some(popped_child) => child = popped_child.unwrap(),
                None => continue,
            }

            if set_to_zero {
                (*child).value = 0;
            }
            println!("{:?}", *child);
            
            if !(*child).lchild.is_none() {
                nodes.push((*child).lchild);
            }
            if !(*child).rchild.is_none() {
                nodes.push((*child).rchild);
            }
        }
        
        println!("");
    }
}

impl Add for MyStruct {
    type Output = Self;

    /// Builds a parent node whose value is `self.value + other.value` and whose
    /// children point to heap copies of the two operands.
    ///
    /// The operands must be moved to the heap. `self` and `other` are by-value
    /// parameters living in this call's stack frame, so `&self`/`&other` would
    /// dangle as soon as `add` returns. The boxes are handed out with
    /// `Box::into_raw` and are never freed by this impl: a `Copy` type cannot
    /// implement `Drop`. Callers that want the memory back must rebuild each box
    /// with `Box::from_raw` exactly once.
    fn add(self, other: Self) -> MyStruct {
        MyStruct {
            value: self.value + other.value,
            lchild: Some(Box::into_raw(Box::new(self))),
            rchild: Some(Box::into_raw(Box::new(other))),
        }
    }
}

impl Mul for MyStruct {
    type Output = Self;

    /// Builds a parent node whose value is `self.value * other.value` and whose
    /// children point to heap copies of the two operands.
    ///
    /// Pointing at `self`/`other` directly would dangle: they are by-value
    /// parameters that live only until `mul` returns. The operands are therefore
    /// moved into boxes and handed out with `Box::into_raw`. The allocations are
    /// not freed by this impl; callers that want the memory back must rebuild
    /// each box with `Box::from_raw` exactly once.
    fn mul(self, other: Self) -> Self {
        MyStruct {
            value: self.value * other.value,
            lchild: Some(Box::into_raw(Box::new(self))),
            rchild: Some(Box::into_raw(Box::new(other))),
        }
    }
}

fn main() {
    let mut tree: MyStruct;
    
    {
        let a = MyStruct{ value: 10, lchild: None, rchild: None };
        let b = MyStruct{ value: 20, lchild: None, rchild: None };
        
        let c = a + b;
        println!("c.value: {}", c.value); // 30
        
        let mut d = a + b;
        println!("d.value: {}", d.value); // 30
        
        d.value = 40;
        println!("d.value: {}", d.value); // 40
        
        let mut e = c * d;
        println!("e.value: {}", e.value); // 1200
        
        unsafe {
            e.print_tree(false); // correct values
            e.print_tree(true); // all zeros
            e.print_tree(false); // all zeros, everything is set correctly
        }
        
        tree = e;
    }
    
    unsafe { tree.print_tree(false); } // same here, only zeros
}

Link to the playground

I honestly don't mind using `unsafe` that much, but is there a safe way of doing it? How bad is the use of `unsafe` here?


Solution

  • You can just box both children: your tree is unidirectional (no node points back to its parent), so every node has exactly one owner:

    use std::ops::{Add, Mul};
    use std::fmt;
    
    /// A node of a binary computation tree. Each node exclusively owns its
    /// optional left and right operand subtrees on the heap, so cloning a node
    /// deep-copies its whole subtree.
    #[derive(Clone)]
    struct MyStruct {
        /// The node's value; for `+`/`*` nodes, the combined result.
        value: i32,
        /// Left operand subtree, if any.
        lchild: Option<Box<Self>>,
        /// Right operand subtree, if any.
        rchild: Option<Box<Self>>,
    }
    
    impl fmt::Debug for MyStruct {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
            f.debug_struct("MyStruct")
                .field("value", &self.value)
                .field("lchild", &self.lchild.as_deref())
                .field("rchild", &self.rchild.as_deref())
                .finish()
        }
    }
    
    impl MyStruct {
        fn print_tree(&mut self, set_to_zero: bool) {
            if set_to_zero {
                self.value = 0;
            }
    
            println!("MyStruct {{ value: {:?}, lchild: {:?}, rchild: {:?} }}", self.value, &self.lchild as *const _, &self.rchild as *const _);
    
            if let Some(child) = &mut self.lchild {
                child.print_tree(set_to_zero);
            }
    
            if let Some(child) = &mut self.rchild {
                child.print_tree(set_to_zero);
            }
        }
    }
    
    impl Add for MyStruct {
        type Output = Self;

        /// Consumes both operands and returns a new parent node holding their
        /// sum, with `self` boxed as the left child and `other` as the right.
        fn add(self, other: Self) -> MyStruct {
            let value = self.value + other.value;
            let lchild = Some(Box::new(self));
            let rchild = Some(Box::new(other));
            MyStruct { value, lchild, rchild }
        }
    }
    
    impl Mul for MyStruct {
        type Output = Self;

        /// Consumes both operands and returns a new parent node holding their
        /// product, with `self` boxed as the left child and `other` as the right.
        fn mul(self, other: Self) -> Self {
            let value = self.value * other.value;
            let lchild = Some(Box::new(self));
            let rchild = Some(Box::new(other));
            MyStruct { value, lchild, rchild }
        }
    }
    
    fn main() {
        // Build the tree in an inner block and move the root out of it.
        let tree = {
            let a = MyStruct {
                value: 10,
                lchild: None,
                rchild: None,
            };
            let b = MyStruct {
                value: 20,
                lchild: None,
                rchild: None,
            };

            // `+` takes its operands by value and boxes them as children. The
            // first sum therefore needs copies, but the second sum is the last
            // use of `a` and `b` and can simply move them.
            let c = a.clone() + b.clone();
            println!("c.value: {}", c.value); // 30

            let mut d = a + b;
            println!("d.value: {}", d.value); // 30

            d.value = 40;
            println!("d.value: {}", d.value); // 40

            let mut e = c * d;
            println!("e.value: {}", e.value); // 1200

            println!();

            e.print_tree(false); // correct values
            println!();
            e.print_tree(true); // all zeros
            println!();
            e.print_tree(false); // all zeros, everything is set correctly
            println!();

            e
        };

        dbg!(tree);
    }
    

    I implemented `Debug` manually and reimplemented `print_tree` recursively. I'm not sure whether a mutating `print_tree` like this can be written without recursion, but it's certainly possible if it takes `&self` instead (dropping the `set_to_zero` option).

    playground

    Edit: It turns out it is possible to iterate mutably over the tree's values without recursion. The following code is derived from the playground in this comment by @Shepmaster.

    impl MyStruct {
        /// Collects a mutable reference to every `value` in the tree, without
        /// recursion.
        ///
        /// The traversal keeps an explicit stack of `&mut MyStruct`. Destructuring
        /// each popped node splits that borrow into disjoint borrows of its fields,
        /// so the `value` reference can be kept while the children are pushed for
        /// later. The order is: the node itself, then its right subtree, then its
        /// left subtree, because the right child is pushed last and popped first.
        fn values_mut(&mut self) -> Vec<&mut i32> {
            let mut node_stack = vec![self];
            let mut values = Vec::new();

            while let Some(MyStruct { value, lchild, rchild }) = node_stack.pop() {
                values.push(value);

                // `&mut Box<MyStruct>` deref-coerces to `&mut MyStruct` on push.
                if let Some(child) = lchild {
                    node_stack.push(child);
                }
                if let Some(child) = rchild {
                    node_stack.push(child);
                }
            }

            values
        }

        /// Sets this node's `value` and every descendant's `value` to 0,
        /// iteratively rather than recursively.
        fn zero_tree(&mut self) {
            for value in self.values_mut() {
                *value = 0;
            }
        }
    }