Tags: c++, templates, c++17, sfinae, template-argument-deduction

Class Template Derived from Another Class Template Type Deduction


I am trying to write a functor memoizer to save time on repeated expensive function calls. In my class design I am struggling to find a simple interface.

Using this Functor base class:

template <typename TOut, typename TIn>
class Functor {
public:
    virtual
    ~Functor() {
    }

    virtual
    TOut operator()(TIn input) = 0;
};

I now want to write a class that will encapsulate and memoize a functor. In addition to encapsulating a Functor, the MemoizedFunctor will itself be a Functor. This results in three template parameters: the wrapped functor type F, the output type TOut and the input type TIn.

Here is a working example:

#include <unordered_map>

template <typename F, typename TOut, typename TIn>
class MemoizedFunctor : public Functor<TOut, TIn> {
public:
    MemoizedFunctor(F f) : f_(f) {
    }

    virtual
    ~MemoizedFunctor() {
    }

    virtual
    TOut operator()(TIn input) override {
        if (cache_.count(input)) {
            return cache_.at(input);
        } else {
            TOut output = f_(input);
            cache_.insert({input, output});
            return output;
        }
    }

private:
    F f_;
    std::unordered_map<TIn, TOut> cache_;
};

class YEqualsX : public Functor<double, double> {
public:
    virtual
    ~YEqualsX() {
    }

    double operator()(double x) override {
        return x;
    }
};

int main() {
    MemoizedFunctor<YEqualsX, double, double> f((YEqualsX())); // extra parentheses avoid the most vexing parse

    f(0); // First call
    f(0); // Cached call

    return 0;
}
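A quick way to confirm that the second call really is served from the cache is to count how many times the wrapped functor runs. The sketch below is a self-contained variant of the example above. CountingSquare, its calls counter, the wrapped() accessor and evaluations_for_repeated_call() are hypothetical names added only for this check. The cache lookup also uses a single find() instead of count() followed by at().

```cpp
#include <cassert>        // for the assert() usage check below
#include <unordered_map>

template <typename TOut, typename TIn>
class Functor {
public:
    virtual ~Functor() = default;
    virtual TOut operator()(TIn input) = 0;
};

template <typename F, typename TOut, typename TIn>
class MemoizedFunctor : public Functor<TOut, TIn> {
public:
    explicit MemoizedFunctor(F f) : f_(f) {}

    TOut operator()(TIn input) override {
        // One find() avoids the double lookup of count() followed by at().
        auto it = cache_.find(input);
        if (it != cache_.end()) {
            return it->second;          // cached result
        }
        TOut output = f_(input);        // first call: evaluate the wrapped functor
        cache_.emplace(input, output);
        return output;
    }

    // Hypothetical accessor, added only so the check can inspect the wrapped functor.
    const F& wrapped() const { return f_; }

private:
    F f_;
    std::unordered_map<TIn, TOut> cache_;
};

// Hypothetical functor that counts how many times it is actually evaluated.
class CountingSquare : public Functor<double, double> {
public:
    double operator()(double x) override {
        ++calls;
        return x * x;
    }
    int calls = 0;
};

// Calls the memoized functor twice with the same argument and returns how many
// times the wrapped functor really ran.
int evaluations_for_repeated_call() {
    MemoizedFunctor<CountingSquare, double, double> f{CountingSquare{}};
    f(3.0);   // computed
    f(3.0);   // served from the cache
    return f.wrapped().calls;
}
```

If the memoization works, evaluations_for_repeated_call() returns 1, so assert(evaluations_for_repeated_call() == 1) holds.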

I feel like there MUST be a way to eliminate having to specify all 3 template parameters. Given the function passed to the constructor of MemoizedFunctor, I would argue that all three template parameters could be deduced.

I am not sure how to rewrite the class so that using it would not require all the template specification.

I tried using a smart pointer to a Functor as the member variable in MemoizedFunctor. This eliminates the first template parameter, but now the user of the class must pass a smart pointer to the MemoizedFunctor class.
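For comparison, the smart-pointer variant described above might look roughly like the following. This is only a sketch of that idea and not the asker's actual code. PtrMemoizedFunctor and call_twice() are hypothetical names, and std::unique_ptr is assumed as the smart pointer. The sketch shows the trade-off. F disappears from the template parameter list, but the caller now has to build a std::unique_ptr and still has to spell out TOut and TIn.

```cpp
#include <cassert>
#include <memory>
#include <unordered_map>
#include <utility>

template <typename TOut, typename TIn>
class Functor {
public:
    virtual ~Functor() = default;   // needed: the wrapped object is deleted via a base pointer
    virtual TOut operator()(TIn input) = 0;
};

// Hypothetical variant: the wrapped functor is owned through a pointer to the
// Functor base class, so its concrete type is no longer a template parameter.
template <typename TOut, typename TIn>
class PtrMemoizedFunctor : public Functor<TOut, TIn> {
public:
    explicit PtrMemoizedFunctor(std::unique_ptr<Functor<TOut, TIn>> f)
        : f_(std::move(f)) {}

    TOut operator()(TIn input) override {
        auto it = cache_.find(input);
        if (it != cache_.end()) {
            return it->second;
        }
        TOut output = (*f_)(input);   // virtual dispatch to the wrapped functor
        cache_.emplace(input, output);
        return output;
    }

private:
    std::unique_ptr<Functor<TOut, TIn>> f_;
    std::unordered_map<TIn, TOut> cache_;
};

class YEqualsX : public Functor<double, double> {
public:
    double operator()(double x) override { return x; }
};

double call_twice(double x) {
    // TOut and TIn must still be written out: class template argument deduction
    // does not look through the conversion from std::unique_ptr<YEqualsX> to
    // std::unique_ptr<Functor<double, double>>.
    PtrMemoizedFunctor<double, double> f(std::make_unique<YEqualsX>());
    f(x);          // computed
    return f(x);   // cached
}
```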

In summary, I would like all the template arguments of MemoizedFunctor to be deduced automatically on construction. I believe this is possible because on construction all template arguments are unambiguous.


Solution

  • In summary, I would like all the template arguments of MemoizedFunctor to be deduced automatically on construction. I believe this is possible because on construction all template arguments are unambiguous.

    If I understand correctly, the first template parameter of MemoizedFunctor is always a Functor<TOut, TIn>, or something that inherits from some Functor<TOut, TIn>, where TOut and TIn are the second and third template parameters of MemoizedFunctor.

    It seems to me that you're looking for a deduction guide.

    To deduce the second and third template parameters, I propose declaring the following pair of functions (no definition is required, because they are used only inside decltype(), an unevaluated context):

    template <typename TOut, typename TIn>
    constexpr TIn getIn (Functor<TOut, TIn> const &);
    
    template <typename TOut, typename TIn>
    constexpr TOut getOut (Functor<TOut, TIn> const &);
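Here is a minimal check that these two declarations do what is claimed. It assumes a hypothetical functor, Twice, whose input type (int) and output type (long) differ, so that swapping TIn and TOut would be noticed. Twice only derives from Functor<long, int>, so the check also relies on template argument deduction accepting a derived-class argument for a Functor<TOut, TIn> const & parameter.

```cpp
#include <type_traits>
#include <utility>

template <typename TOut, typename TIn>
class Functor {
public:
    virtual ~Functor() = default;
    virtual TOut operator()(TIn input) = 0;
};

// Declarations only: they appear solely inside decltype, an unevaluated operand,
// so they are never called and need no definition.
template <typename TOut, typename TIn>
constexpr TIn getIn(Functor<TOut, TIn> const &);

template <typename TOut, typename TIn>
constexpr TOut getOut(Functor<TOut, TIn> const &);

// Hypothetical functor with distinct input (int) and output (long) types.
class Twice : public Functor<long, int> {
public:
    long operator()(int x) override { return 2L * x; }
};

// Deduction looks through the derived class Twice to its base Functor<long, int>.
static_assert(std::is_same_v<decltype(getIn(std::declval<Twice>())), int>);
static_assert(std::is_same_v<decltype(getOut(std::declval<Twice>())), long>);
```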
    

    Now, using decltype() and std::declval(), the user-defined deduction guide simply becomes

    template <typename F>
    MemoizedFunctor(F)
       -> MemoizedFunctor<F,
                          decltype(getOut(std::declval<F>())),
                          decltype(getIn(std::declval<F>()))>;
    

    The following is a full compiling example

    #include <unordered_map>
    #include <utility>  // std::declval
    
    template <typename TOut, typename TIn>
    class Functor
     {
       public:
        virtual ~Functor ()
         { }
    
        virtual TOut operator() (TIn input) = 0;
     };
    
    template <typename TOut, typename TIn>
    constexpr TIn getIn (Functor<TOut, TIn> const &);
    
    template <typename TOut, typename TIn>
    constexpr TOut getOut (Functor<TOut, TIn> const &);
    
    template <typename F, typename TOut, typename TIn>
    class MemoizedFunctor : public Functor<TOut, TIn>
     {
       public:
          MemoizedFunctor(F f) : f_{f}
           { }
    
          virtual ~MemoizedFunctor ()
           { }
    
          virtual TOut operator() (TIn input) override
           {
             if ( cache_.count(input) )
                return cache_.at(input);
             else
              {
                TOut output = f_(input);
                cache_.insert({input, output});
                return output;
              }
           }
    
       private:
          F f_;
          std::unordered_map<TIn, TOut> cache_;
     };
    
    class YEqualsX : public Functor<double, double>
     {
       public:
          virtual ~YEqualsX ()
           { }
    
          double operator() (double x) override
           { return x; }
     };
    
    template <typename F>
    MemoizedFunctor(F)
       -> MemoizedFunctor<F,
                          decltype(getOut(std::declval<F>())),
                          decltype(getIn(std::declval<F>()))>;
    
    int main ()
     {
       MemoizedFunctor f{YEqualsX{}};
    
       f(0); // First call
       f(0); // Cached call
     }
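The full example uses double for both TOut and TIn, so it cannot show whether the guide passes them in the right order. The small variation below checks the deduced type with a static_assert. It uses a hypothetical Twice functor (int in, long out) and a hypothetical twice_memoized() helper. As minor side changes, the constructor moves its argument and the cache lookup uses find() instead of count() plus at().

```cpp
#include <cassert>        // for the assert() usage check below
#include <type_traits>
#include <unordered_map>
#include <utility>

template <typename TOut, typename TIn>
class Functor {
public:
    virtual ~Functor() = default;
    virtual TOut operator()(TIn input) = 0;
};

template <typename TOut, typename TIn>
constexpr TIn getIn(Functor<TOut, TIn> const &);

template <typename TOut, typename TIn>
constexpr TOut getOut(Functor<TOut, TIn> const &);

template <typename F, typename TOut, typename TIn>
class MemoizedFunctor : public Functor<TOut, TIn> {
public:
    explicit MemoizedFunctor(F f) : f_(std::move(f)) {}

    TOut operator()(TIn input) override {
        auto it = cache_.find(input);
        if (it != cache_.end()) return it->second;   // cached
        TOut output = f_(input);                     // computed once
        cache_.emplace(input, output);
        return output;
    }

private:
    F f_;
    std::unordered_map<TIn, TOut> cache_;
};

template <typename F>
MemoizedFunctor(F)
    -> MemoizedFunctor<F,
                       decltype(getOut(std::declval<F>())),
                       decltype(getIn(std::declval<F>()))>;

// Hypothetical functor with distinct input (int) and output (long) types.
class Twice : public Functor<long, int> {
public:
    long operator()(int x) override { return 2L * x; }
};

// With distinct types, a guide that swapped TOut and TIn would make this fail.
static_assert(std::is_same_v<decltype(MemoizedFunctor{Twice{}}),
                             MemoizedFunctor<Twice, long, int>>);

long twice_memoized(int x) {
    MemoizedFunctor m{Twice{}};   // all three template arguments deduced
    m(x);                         // computed
    return m(x);                  // cached
}
```

With this in place, assert(twice_memoized(21) == 42) holds.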
    

    -- EDIT --

    Aschepler, in a comment, observed a possible drawback to this solution: some types can't be returned from a function.

    For example, a function can't return a C-style array.

    This isn't a problem when deducing TOut (the type returned by operator()), precisely because TOut is already returned by a member function, so getOut() can return it too.

    But, generally speaking, it can be a problem for TIn. Suppose TIn were, for example, int[4]. (It can't be in this case, because TIn is used as the key of an unordered_map, but it could be in general.) Then getIn() couldn't return an int[4], so deduction through getIn() would fail.

    You can work around this problem by (1) adding a type wrapper struct, as follows

    template <typename T>
    struct typeWrapper
     { using type = T; };
    

    (2) modifying getIn() to return TIn inside the wrapper

    template <typename TOut, typename TIn>
    constexpr typeWrapper<TIn> getIn (Functor<TOut, TIn> const &);
    

    and (3) modifying the deduction guide to extract TIn from the wrapper

    template <typename F>
    MemoizedFunctor(F)
       -> MemoizedFunctor<F,
                          decltype(getOut(std::declval<F>())),
                          typename decltype(getIn(std::declval<F>()))::type>;
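The three steps above can be checked in isolation. This is an illustration under stated assumptions. Id, Sum4, plainGetInWorks and getInWrapped are hypothetical names. getInWrapped stands in for the modified getIn() so that both versions can coexist in one example. The detection uses the C++17 std::void_t idiom. The sketch shows that the plain getIn() cannot be used when TIn is int[4], while the wrapped version recovers the type.

```cpp
#include <type_traits>
#include <utility>

template <typename TOut, typename TIn>
class Functor {
public:
    virtual ~Functor() = default;
    virtual TOut operator()(TIn input) = 0;
};

template <typename T>
struct typeWrapper
 { using type = T; };

// Original helper: returns TIn by value (declared only, used inside decltype).
template <typename TOut, typename TIn>
constexpr TIn getIn(Functor<TOut, TIn> const &);

// Wrapped helper, given a different (hypothetical) name only so both versions
// can live in one example; in the answer it simply replaces getIn.
template <typename TOut, typename TIn>
constexpr typeWrapper<TIn> getInWrapped(Functor<TOut, TIn> const &);

// Hypothetical functors: Id has an ordinary TIn, Sum4 has the array type int[4].
// Sum4's operator() parameter is adjusted to int*, but its base class is still
// Functor<int, int[4]>.
class Id : public Functor<double, double> {
public:
    double operator()(double x) override { return x; }
};

class Sum4 : public Functor<int, int[4]> {
public:
    int operator()(int x[4]) override { return x[0] + x[1] + x[2] + x[3]; }
};

// true if decltype(getIn(...)) is well-formed for F.
template <typename F, typename = void>
struct plainGetInWorks : std::false_type {};

template <typename F>
struct plainGetInWorks<F, std::void_t<decltype(getIn(std::declval<F>()))>>
    : std::true_type {};

// A function returning int[4] cannot be formed, so deduction of the plain
// getIn fails for Sum4 (but not for Id)...
static_assert(plainGetInWorks<Id>::value);
static_assert(!plainGetInWorks<Sum4>::value);

// ...while typeWrapper<int[4]> is a valid return type and ::type recovers int[4].
using Sum4In = decltype(getInWrapped(std::declval<Sum4>()))::type;
static_assert(std::is_same_v<Sum4In, int[4]>);
```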