Search code examples
c++templatesvariadic-templatesc++17static-assert

C++17 invoke_result and static_assert for template func param


This is purely for gaining more knowledge when doing generic programming. How can I ensure the return type of a function passed in as a template argument to another function, which can take a different number of params (0 to N).

EDIT: I am trying to use std::invoke_result and static_assert() to ensure the proper return type for registered factory methods. I've used a better example than the first one posted to give more clarity.

#include <string>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>

using namespace std;

class Factory final
{
public:
    template<typename My_Type, typename Func>
    static bool Register(Func func)
    {
        // issue is these two lines of code, when trying to register Derived1 and Derived2 create functions
        typename invoke_result<Func>::type result;
        static_assert(is_same<decltype(result), unique_ptr<My_Type>>::value, "Not a unique pointer to type 'My_Type'");

        bool isRegistered = false;

        if (GetCreateFunctions().end() == GetCreateFunctions().find(typeid(My_Type)))
        {
            GetCreateFunctions()[typeid(My_Type)] = reinterpret_cast<void*>(func);
            isRegistered = true;
        }

        return isRegistered;
    }

    template<typename My_Type, typename... Args>
    static unique_ptr<My_Type> Create(Args&&... args)
    {
        unique_ptr<My_Type> type = nullptr;
        auto iter = GetCreateFunctions().find(typeid(My_Type));

        if (GetCreateFunctions().end() != iter)
        {
            typedef unique_ptr<My_Type>(*create_func)(Args&&...);
            auto create = reinterpret_cast<create_func>(iter->second);
            type = create(forward<Args>(args)...);
        }

        return type;
    }

private:
    static unordered_map<type_index, void*>& GetCreateFunctions()
    {
        static unordered_map<type_index, void*> map;
        return map;
    }
};

class Base
{
public:
    Base(unique_ptr<string>&& moveString)
        :
        _moveString(move(moveString))
    {
    }

    virtual ~Base() = default;
    virtual void DoSomething() const = 0;

protected:
    unique_ptr<string> _moveString;
};

class Derived1 final : public Base
{
public:
    Derived1(unique_ptr<string>&& moveString)
        :
        Base(move(moveString))
    {
    }

    ~Derived1() = default;

    void DoSomething() const override
    {
        if (_moveString)
        {
            // do something...
        }
    }

private:
    static const bool _isRegistered;

    static unique_ptr<Derived1> Create(unique_ptr<string>&& moveString)
    {
        return make_unique<Derived1>(move(moveString));
    }
};

const bool Derived1::_isRegistered = Factory::template Register<Derived1>(&Derived1::Create);

class Derived2 final : public Base
{
public:
    Derived2()
        :
        Base(make_unique<string>("Default"))
    {
    }

    ~Derived2() = default;

    void DoSomething() const override
    {
        if (_moveString)
        {
            // do something...
        }
    }

private:
    static const bool _isRegistered;

    static unique_ptr<Derived2> Create()
    {
        return make_unique<Derived2>();
    }
};

const bool Derived2::_isRegistered = Factory::template Register<Derived2>(&Derived2::Create);


int main(int argc, char** argv)
{
    string moveString = "moveString";
    unique_ptr<Base> myBase_Derived1 = Factory::template Create<Derived1>(make_unique<string>(move(moveString)));
    unique_ptr<Base> myBase_Derived2 = Factory::template Create<Derived2>();

    if (myBase_Derived1)
        printf("Success\n");

    if (myBase_Derived2)
        printf("Success\n");

    return 0;
}

Solution

  • This doesn't work in all functional cases but, with function pointers, should work.

    If you define a custom template as follows

    template <typename>
    struct baz;
    
    template <typename R, typename ... Args>
    struct baz<R(*)(Args...)>
     { using retType = R; };
    

    your example become

    template <typename T, typename Func>
    void Example(Func func)
     {
       static_assert(std::is_same<typename baz<Func>::retType,
                                  std::unique_ptr<T>>::value,  "!");
     }
    

    The following is a full compiling example

    #include <type_traits>
    #include <memory>
    
    template <typename>
    struct baz;
    
    template <typename R, typename ... Args>
    struct baz<R(*)(Args...)>
     { using retType = R; };
    
    template <typename T, typename Func>
    void Example(Func func)
     {
       static_assert(std::is_same<typename baz<Func>::retType,
                                  std::unique_ptr<T>>::value,  "!");
     }
    
    struct Foo
     {
       Foo (int i) : My_I{i}
        { }
    
       Foo () : My_I{99}
        { }
    
       ~Foo() = default;
    
       int My_I;
     };
    
    std::unique_ptr<Foo> bar0 ()
    { return std::make_unique<Foo>(11); }
    
    std::unique_ptr<Foo> bar1 (int)
    { return std::make_unique<Foo>(11); }
    
    std::unique_ptr<Foo> bar2 (int, long)
    { return std::make_unique<Foo>(11); }
    
    int main ()
     {
       Example<Foo>(&bar0);
       Example<Foo>(&bar1);
       Example<Foo>(&bar2);
     }
    

    Unfortunately this solution works only with function pointers and not with (by example) lambda function (when not convertible to function pointers) or classes with operator().

    If you can use C++17 -- and I suppose is your case, if you use std::invoke_result -- you can also use automatic deduction guides for std::function.

    So, in C++17, you can forget the baz struct and simply write

    template <typename T, typename Func>
    void Example(Func func)
     {        
       static_assert(std::is_same<
           typename decltype(std::function{func})::result_type,
           std::unique_ptr<T>>::value,  "!");
     }
    

    or, in your Register() method

    template<typename My_Type, typename Func>
    static bool Register(Func func)
    {
        static_assert(std::is_same<
           typename decltype(std::function{func})::result_type,
           std::unique_ptr<My_Type>>::value,  "!");
    
        bool isRegistered = false;
    
        if (GetCreateFunctions().end() == GetCreateFunctions().find(typeid(My_Type)))
        {
            GetCreateFunctions()[typeid(My_Type)] = reinterpret_cast<void*>(func);
            isRegistered = true;
        }
    
        return isRegistered;
    }