Search code examples
Tags: c++, templates, variadic-functions, type-traits, variadic

C++ return type depending on the number of function arguments


I have the following struct:

// Random-vector generator from the question.
//
// Each get<T>(n1, ..., nk) overload (k = 1, 2, 3) returns a k-level nested
// std::vector whose innermost elements are draws from U converted to T.
struct A
{
  // Shorthand for std::vector. This alias template replaces the original
  // `#define vec std::vector`: it is scoped to A, so unlike a macro it cannot
  // rewrite unrelated tokens named `vec` elsewhere in the program.
  template <typename T>
  using vec = std::vector<T>;

  std::mt19937 rng;                          // default seed unless A(int) reseeds it
  std::uniform_real_distribution<double> U;  // default range [0, 1); A(int) sets [0, 100000)

  // Leaves rng with its default seed and U with its default-constructed range.
  A() = default;

  // Seeds the engine with `sed` and makes U draw from [0, 100000).
  A(int sed)
  {
    rng.seed(sed);
    U = std::uniform_real_distribution<double>(0, 100000);
  }

  // Returns `size` random values, each converted to T.
  template <typename T>
  vec<T> get(std::size_t size)
  {
    vec<T> rst(size);
    for (auto &x: rst) x = static_cast<T>(U(rng));  // U yields double; make the conversion explicit
    return rst;
  }

  // Returns a size1 x size2 nested vector of random values.
  template <typename T>
  vec<vec<T>> get(std::size_t size1, std::size_t size2)
  {
    vec<vec<T>> rst(size1);
    // T must be spelled out: it appears only in the return type of get, so a
    // bare `get(size2)` cannot deduce it and fails to compile.
    for (auto &x: rst) get<T>(size2).swap(x);
    return rst;
  }

  // Returns a size1 x size2 x size3 nested vector of random values.
  template <typename T>
  vec<vec<vec<T>>> get(std::size_t size1, std::size_t size2, std::size_t size3)
  {
    vec<vec<vec<T>>> rst(size1);
    for (auto &x: rst) get<T>(size2, size3).swap(x);
    return rst;
  }
};

I know this may be a long shot, but how can I write a member function magicGet() such that when I call:

auto u = magicGet<T>(3, 1, 2, 5);
auto v = magicGet<T>(7, 9, 6, 2, 2);
auto w = magicGet<T>(6);

I get u of type vec<vec<vec<vec<T>>>>, v of type vec<vec<vec<vec<vec<T>>>>>, w of type vec<T>, and so on?

If it's impossible, what would be the closest solution?

Update: combining the accepted answer with the post from @Shreeyash Shrestha, the simplest solution might be the following (it requires C++17 for `if constexpr`):

  // Returns a nested vector with one level per size argument:
  // magicGet<T>(n) -> vec<T>, magicGet<T>(n, m) -> vec<vec<T>>, and so on.
  // The innermost elements are draws from U(rng).
  template <typename T, typename... Args>
  auto magicGet(std::size_t size, Args... args)
  {
    if constexpr (sizeof...(args) > 0)
    {
      // Recursive case: each of the `size` entries is the result of the call
      // with the remaining sizes, so that call's type is the element type.
      using Inner = decltype(magicGet<T>(args...));
      vec<Inner> rst(size);
      for (std::size_t i = 0; i < rst.size(); ++i)
        magicGet<T>(args...).swap(rst[i]);
      return rst;
    }
    else
    {
      // Base case: a flat vector of `size` random values. Keeping the two
      // branches as if/else means the discarded branch's `return` does not
      // take part in deducing `auto`, so the differing return types never clash.
      vec<T> rst(size);
      for (auto &x : rst) x = U(rng);
      return rst;
    }
  }

Solution

  • Try this:

      // Base case of the recursion: one size argument produces a flat vector
      // of that many random draws from U. The argument's type T also serves as
      // the element type, and it must be convertible to size_t.
      template<typename T>
      std::vector<T> magic_func(T value) {
        static_assert(std::is_convertible_v<decltype(value), size_t>);
        const size_t count = static_cast<size_t>(value);
        std::vector<T> out(count);
        for (size_t i = 0; i < out.size(); ++i)
          out[i] = U(rng);
        return out;
      }
    
      // Recursive case: magic_func<T>(n1, n2, ..., nk) returns a k-level nested
      // std::vector of size n1 x n2 x ... x nk whose innermost element type is T.
      template<typename T, typename... Args>
      auto magic_func(T value, Args... args) {
        static_assert(std::is_convertible_v<T, std::size_t>);
        // Forward T explicitly. A bare `magic_func(args...)` re-deduces T from
        // the remaining size arguments, so magic_func<double>(2, 3) would have
        // produced std::vector<std::vector<int>> instead of ...<double>.
        using Inner = decltype(magic_func<T>(args...));
        std::vector<Inner> result(static_cast<std::size_t>(value));
        for (auto& sub_vec : result) {
          magic_func<T>(args...).swap(sub_vec);
        }
        return result;
      }
    

    Here is the complete code, together with a compile-time test:

    #include <random>
    #include <vector>
    #include <iostream>
    
    // Shorthand for std::vector, used by the question's code that goes where
    // "// Your code ..." appears; the answer's functions spell out std::vector.
    #define vec std::vector
    
    struct A {
      // Your code ...
    
      // Base case: a single size argument yields a flat std::vector<T> holding
      // `value` random draws from U. T is both the element type and the type
      // of the size argument, so it must be convertible to std::size_t.
      template<typename T>
      std::vector<T> magic_func(T value) {
        static_assert(std::is_convertible_v<T, std::size_t>);
        std::vector<T> result(static_cast<std::size_t>(value));
    
        for (auto& val : result) {
          val = static_cast<T>(U(rng));  // explicit: U yields a floating-point value
        }
    
        return result;
      }
    
      // Recursive case: magic_func<T>(n1, n2, ..., nk) returns a k-level nested
      // std::vector of size n1 x n2 x ... x nk whose innermost element type is T.
      template<typename T, typename... Args>
      auto magic_func(T value, Args... args) {
        static_assert(std::is_convertible_v<T, std::size_t>);
        // Forward T explicitly. A bare `magic_func(args...)` re-deduces T from
        // the remaining size arguments, so magic_func<double>(2, 3) would have
        // produced std::vector<std::vector<int>> instead of ...<double>.
        using Inner = decltype(magic_func<T>(args...));
        std::vector<Inner> result(static_cast<std::size_t>(value));
        for (auto& sub_vec : result) {
          magic_func<T>(args...).swap(sub_vec);
        }
        return result;
      }
    
    };
    
    // Remove the shorthand so the macro does not leak into main() or any code after it.
    #undef vec
    
    int main() {
      // Compile-time checks that the nesting depth of the vector returned by
      // magic_func<int> equals the number of size arguments passed to it.
      using Flat = std::vector<int>;
      using Grid = std::vector<Flat>;
      using Cube = std::vector<Grid>;
    
      A a;
    
      auto result1 = a.magic_func<int>(1);
      static_assert(std::is_same_v<decltype(result1), Flat>);
    
      auto result2 = a.magic_func<int>(1, 2);
      static_assert(std::is_same_v<decltype(result2), Grid>);
    
      auto result3 = a.magic_func<int>(1, 2, 3);
      static_assert(std::is_same_v<decltype(result3), Cube>);
    }