Search code examples
Tags: c++, templates, casting, return, return-type

How do I make a function return different results depending on the template type it's called with in C++?


I've been looking for quite a while for an answer to my question, but couldn't find anything that worked.

Basically, I have a binary search tree class with a Search function:

// The asker's original attempt (does not compile as intended).
// NOTE(review): there is no preceding `template<typename T>` header, so T is
// undeclared here. Also, because every T shares this one body, each `return`
// must convert to every requested T -- e.g. `true` cannot be returned as a
// Node* -- which is exactly the problem the question is about.
T BinarySearchTree::Search(Node *tree, int key) const {
    if (tree == nullptr) // if root is null, key is not in tree
        return false;
    if (key == tree->key) // key is found
        return true;
    else if (key < tree->key) // recursively look at left subtree if key < tree->key
        return Search<T>(tree->left, key);
    else // recursively look at right subtree if key > tree->key
        return Search<T>(tree->right, key);
} 

I want to return different things depending on the type it is called with. E.g. if I call the function as Search<bool>(), I want it to return true or false, but if I call Search<Node*>(), I want it to return a pointer to the node (and for Search<int>(), the key itself). It should look something like this:

// Pseudocode sketch of the desired behaviour -- not valid C++.
// `T = bool` here means "T is bool" (a compile-time type test), and the
// function's own closing brace is missing: the final `}` closes the
// `if (key == tree->key)` block.
T BinarySearchTree::Search(Node *tree, int key) const {
    if (key == tree->key){  // key is found
       if(T = bool)
            return true;
       else if(T = Node*)
           return tree;
       else if (T = int)
           return tree->key;
}

I'm not even sure if templates are the right way to go here, but any tips for implementation would be appreciated.


Solution

  • What you are asking for can be done using if constexpr in C++17 and later, e.g.:

    #include <type_traits>
    
    // Searches the subtree rooted at `tree` for `key`; what is returned
    // depends on the requested type T:
    //   T = bool  -> true if key is present, false otherwise
    //   T = Node* -> the node holding key, or nullptr if it is absent
    //   T = int   -> the key itself, or 0 if it is absent
    //                (so 0 is ambiguous when 0 can be a stored key)
    // Any other T is rejected at compile time by the static_assert.
    // `if constexpr` discards the non-matching branches, so each `return`
    // only has to convert to the one T it belongs to.
    template<typename T>
    T BinarySearchTree::Search(Node *tree, int key) const {

        static_assert(
            std::is_same_v<T, bool> ||
            std::is_same_v<T, Node*> ||
            std::is_same_v<T, int>,
            "Invalid type specified");

        if (tree == nullptr) {
            // Reached an empty subtree: key is not in the tree.
            if constexpr (std::is_same_v<T, bool>) {
                return false;
            }
            else if constexpr (std::is_same_v<T, Node*>) {
                return nullptr;
            }
            else {
                return 0; // or throw an exception
            }
        }
        else if (key == tree->key) {
            // key is found
            if constexpr (std::is_same_v<T, bool>) {
                return true;
            }
            else if constexpr (std::is_same_v<T, Node*>) {
                return tree;
            }
            else {
                return tree->key;
            }
        }
        else if (key < tree->key)
            return Search<T>(tree->left, key);
        else
            return Search<T>(tree->right, key);
    }
    

    Prior to C++17, a similar result can be accomplished using SFINAE (via std::enable_if, available since C++11), e.g.:

    #include <type_traits>
    
    // Non-template lookup helper: returns the node whose key equals `key`
    // in the subtree rooted at `tree`, or nullptr if there is no such node.
    // The type-specific Search<T> overloads are thin wrappers around it.
    // It recurses into itself rather than Search<T>: this function is not a
    // template, so there is no T in scope for a Search<T> call.
    Node* BinarySearchTree::InternalSearch(Node *tree, int key) const {
        if (tree == nullptr) {
            return nullptr; // empty subtree: key is not present
        }
        else if (key == tree->key) {
            return tree;
        }
        else if (key < tree->key)
            return InternalSearch(tree->left, key);
        else
            return InternalSearch(tree->right, key);
    }
    
    // Overload enabled only for T == bool: answers "is `key` in the subtree
    // rooted at `tree`?" by checking whether the lookup found a node.
    template<typename T>
    typename std::enable_if<std::is_same<T, bool>::value, T>::type
    BinarySearchTree::Search(Node *tree, int key) const {
        const Node *found = InternalSearch(tree, key);
        return found != nullptr;
    }
    
    // Overload enabled only for T == Node*: hands back the node holding
    // `key`, or nullptr when the lookup finds nothing.
    template<typename T>
    typename std::enable_if<std::is_same<T, Node*>::value, T>::type
    BinarySearchTree::Search(Node *tree, int key) const {
        Node *match = InternalSearch(tree, key);
        return match;
    }
    
    // Overload enabled only for T == int: yields the matching node's key,
    // or 0 when the key is absent (so 0 is ambiguous if 0 can be stored).
    template<typename T>
    typename std::enable_if<std::is_same<T, int>::value, T>::type
    BinarySearchTree::Search(Node *tree, int key) const {
        const Node *found = InternalSearch(tree, key);
        if (found == nullptr) {
            return 0; /* or throw an exception */
        }
        return found->key;
    }
    

    Or, using explicit template specialization, e.g.:

    // Non-template lookup helper: returns the node whose key equals `key`
    // in the subtree rooted at `tree`, or nullptr if there is no such node.
    // The Search<T> specializations are thin wrappers around it.
    // It recurses into itself rather than Search<T>: this function is not a
    // template, so there is no T in scope for a Search<T> call.
    Node* BinarySearchTree::InternalSearch(Node *tree, int key) const {
        if (tree == nullptr) {
            return nullptr; // empty subtree: key is not present
        }
        else if (key == tree->key) {
            return tree;
        }
        else if (key < tree->key)
            return InternalSearch(tree->left, key);
        else
            return InternalSearch(tree->right, key);
    }
    
    // Primary template: only the explicit specializations for bool, Node*
    // and int are meant to be used. Instantiating it for any other T is a
    // compile-time error. Previously the empty body fell off the end of a
    // non-void function, which is undefined behaviour.
    template<typename T>
    T BinarySearchTree::Search(Node *tree, int key) const {
        // The condition depends on T, so it is only evaluated (and fails)
        // when this primary template is actually instantiated.
        static_assert(sizeof(T) == 0, "Invalid type specified");
    }
    
    // Specialization for bool: true exactly when `key` occurs in the
    // subtree rooted at `tree`.
    template<>
    bool BinarySearchTree::Search<bool>(Node *tree, int key) const {
        const Node *found = InternalSearch(tree, key);
        return found != nullptr;
    }
    
    // Specialization for Node*: the node holding `key`, or nullptr when
    // the lookup finds nothing.
    template<>
    Node* BinarySearchTree::Search<Node*>(Node *tree, int key) const {
        Node *match = InternalSearch(tree, key);
        return match;
    }
    
    // Specialization for int: the matching node's key, or 0 when the key
    // is absent (so 0 is ambiguous if 0 can be a stored key).
    template<>
    int BinarySearchTree::Search<int>(Node *tree, int key) const {
        const Node *found = InternalSearch(tree, key);
        if (found == nullptr) {
            return 0; /* or throw an exception */
        }
        return found->key;
    }