
Why can Rust use == directly to compare two trees?


Problem: check whether two binary trees are the same.

My solution uses DFS.
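For comparison, a typical hand-written recursive DFS for this problem might look like the sketch below. It is hypothetical (the asker's own code is not shown), and the helper names node and same_tree_dfs are made up for illustration. TreeNode is deliberately declared without #[derive(PartialEq)], so the comparison has to be spelled out.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Same shape as LeetCode's TreeNode, but deliberately without
// #[derive(PartialEq)], so the comparison must be written by hand.
#[derive(Debug)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

type Node = Option<Rc<RefCell<TreeNode>>>;

// Hypothetical helper: builds a node wrapped the way LeetCode expects.
fn node(val: i32, left: Node, right: Node) -> Node {
    Some(Rc::new(RefCell::new(TreeNode { val, left, right })))
}

// Recursive DFS over both trees in lockstep: compare the current values,
// then the left subtrees, then the right subtrees.
fn same_tree_dfs(p: &Node, q: &Node) -> bool {
    match (p, q) {
        (None, None) => true,
        (Some(a), Some(b)) => {
            // Borrow both cells (immutably) to read the nodes.
            let (a, b) = (a.borrow(), b.borrow());
            a.val == b.val
                && same_tree_dfs(&a.left, &b.left)
                && same_tree_dfs(&a.right, &b.right)
        }
        // One side is empty and the other is not.
        _ => false,
    }
}

fn main() {
    let p = node(1, node(2, None, None), node(3, None, None));
    let q = node(1, node(2, None, None), node(3, None, None));
    let r = node(1, node(2, None, None), None);
    println!("{}", same_tree_dfs(&p, &q)); // true
    println!("{}", same_tree_dfs(&p, &r)); // false
}
```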

But this solution:

https://leetcode.com/problems/same-tree/discuss/301998/Rust-One-Line-Solution

simply compares the two trees with ==:

// Definition for a binary tree node.
// #[derive(Debug, PartialEq, Eq)]
// pub struct TreeNode {
//   pub val: i32,
//   pub left: Option<Rc<RefCell<TreeNode>>>,
//   pub right: Option<Rc<RefCell<TreeNode>>>,
// }
// 
// impl TreeNode {
//   #[inline]
//   pub fn new(val: i32) -> Self {
//     TreeNode {
//       val,
//       left: None,
//       right: None
//     }
//   }
// }
use std::rc::Rc;
use std::cell::RefCell;
impl Solution {
    pub fn is_same_tree(p: Option<Rc<RefCell<TreeNode>>>,
                        q: Option<Rc<RefCell<TreeNode>>>) -> bool {
        p == q
    }
}
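To run the one-liner outside LeetCode, the sketch below uncomments the derive on TreeNode and declares the Solution struct that LeetCode's harness normally supplies (assumed here to be an empty unit struct). The tree helper is a hypothetical name used only to build inputs.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// LeetCode's TreeNode with the derive uncommented, so == is available.
#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

impl TreeNode {
    #[inline]
    pub fn new(val: i32) -> Self {
        TreeNode { val, left: None, right: None }
    }
}

// Assumption: LeetCode provides this struct; declared here so the file compiles.
pub struct Solution;

impl Solution {
    pub fn is_same_tree(
        p: Option<Rc<RefCell<TreeNode>>>,
        q: Option<Rc<RefCell<TreeNode>>>,
    ) -> bool {
        // Option -> Rc -> RefCell -> TreeNode: each layer forwards == to its contents.
        p == q
    }
}

// Hypothetical helper for building test trees.
fn tree(
    val: i32,
    left: Option<Rc<RefCell<TreeNode>>>,
    right: Option<Rc<RefCell<TreeNode>>>,
) -> Option<Rc<RefCell<TreeNode>>> {
    let mut node = TreeNode::new(val);
    node.left = left;
    node.right = right;
    Some(Rc::new(RefCell::new(node)))
}

fn main() {
    // Two separately allocated but identical trees: [1, 2, 3].
    let p = tree(1, tree(2, None, None), tree(3, None, None));
    let q = tree(1, tree(2, None, None), tree(3, None, None));
    // Same values, different shape: [1, 2] vs [1, null, 2].
    let a = tree(1, tree(2, None, None), None);
    let b = tree(1, None, tree(2, None, None));
    println!("{}", Solution::is_same_tree(p, q)); // true
    println!("{}", Solution::is_same_tree(a, b)); // false
}
```

Note that p and q are distinct allocations, yet they compare equal: == looks at the values stored in the tree, not at pointer identity.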

How does Rust generate an equality implementation that compares whole trees like this?


Solution

  • It uses an implicit recursive DFS. You can inspect the output of macros (including #[derive()] macros) with cargo-expand, or with Tools -> Expand Macros in the Rust Playground. The #[derive(PartialEq)] in your example expands to:

    impl ::core::marker::StructuralPartialEq for TreeNode {}
    #[automatically_derived]
    #[allow(unused_qualifications)]
    impl ::core::cmp::PartialEq for TreeNode {
        #[inline]
        fn eq(&self, other: &TreeNode) -> bool {
            match *other {
                Self {
                    val: ref __self_1_0,
                    left: ref __self_1_1,
                    right: ref __self_1_2,
                } => match *self {
                    Self {
                        val: ref __self_0_0,
                        left: ref __self_0_1,
                        right: ref __self_0_2,
                    } => {
                        (*__self_0_0) == (*__self_1_0)
                            && (*__self_0_1) == (*__self_1_1)
                            && (*__self_0_2) == (*__self_1_2)
                    }
                },
            }
        }
        #[inline]
        fn ne(&self, other: &TreeNode) -> bool {
            match *other {
                Self {
                    val: ref __self_1_0,
                    left: ref __self_1_1,
                    right: ref __self_1_2,
                } => match *self {
                    Self {
                        val: ref __self_0_0,
                        left: ref __self_0_1,
                        right: ref __self_0_2,
                    } => {
                        (*__self_0_0) != (*__self_1_0)
                            || (*__self_0_1) != (*__self_1_1)
                            || (*__self_0_2) != (*__self_1_2)
                    }
                },
            }
        }
    }
    

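    That might look daunting, but it is equivalent to the following (newer compilers also no longer emit the ne method, leaving the default one, which returns !self.eq(other)):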

    impl PartialEq for TreeNode {
        fn eq(&self, other: &TreeNode) -> bool {
            self.val == other.val && self.left == other.left && self.right == other.right
        }
    }
    

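    So it first compares the values, then recurses into left, which again compares its value, and so on until the left subtree is finished and it moves on to the right: a depth-first search.

    The recursion reaches the children because every wrapper in Option<Rc<RefCell<TreeNode>>> forwards == to its contents: None equals None and two Somes compare their inner values; Rc compares the values it points to, so two separate allocations with equal contents are equal; and RefCell borrows both cells and compares what is inside (it panics if either cell is mutably borrowed at that moment). Each step therefore ends up calling TreeNode::eq on the child nodes.

    It may differ from what you wrote by hand, and because it is recursive it can overflow the stack on a very deep tree, but it works.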
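If recursion depth is a concern, the same DFS can be driven by an explicit stack instead of the call stack. The sketch below is one way to do that; same_tree_iter and node are hypothetical names, and the derived == is kept only to cross-check the result.

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

type Node = Option<Rc<RefCell<TreeNode>>>;

// Hypothetical helper: builds a node wrapped the way LeetCode expects.
fn node(val: i32, left: Node, right: Node) -> Node {
    Some(Rc::new(RefCell::new(TreeNode { val, left, right })))
}

// Iterative DFS: pairs of subtrees still to compare live on a Vec used as a stack,
// so the depth of the tree does not grow the call stack.
fn same_tree_iter(p: &Node, q: &Node) -> bool {
    // Cloning a Node only clones the Rc handle (a refcount bump), not the subtree.
    let mut stack: Vec<(Node, Node)> = vec![(p.clone(), q.clone())];
    while let Some((a, b)) = stack.pop() {
        match (a, b) {
            (None, None) => {}
            (Some(a), Some(b)) => {
                let (a, b) = (a.borrow(), b.borrow());
                if a.val != b.val {
                    return false;
                }
                // Push right before left so the left pair is popped first,
                // matching the derived order: value, left, right.
                stack.push((a.right.clone(), b.right.clone()));
                stack.push((a.left.clone(), b.left.clone()));
            }
            // One side is empty and the other is not.
            _ => return false,
        }
    }
    true
}

fn main() {
    let p = node(1, node(2, None, None), node(3, None, None));
    let q = node(1, node(2, None, None), node(3, None, None));
    let r = node(1, node(2, None, None), node(4, None, None));
    // Cross-check against the derived, recursive ==.
    println!("{} {}", same_tree_iter(&p, &q), p == q); // true true
    println!("{} {}", same_tree_iter(&p, &r), p == r); // false false
}
```

One caveat when testing this on a very deep tree: the compiler-generated drop code for nested Rc values is itself recursive, so such a tree can still overflow the stack when it goes out of scope.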