c++, templates, recursion, template-specialization, factorial

Template factorial function without template specialization


I don't understand the following behavior.

The following code, aimed at computing the factorial at compile time, doesn't even compile:

#include <iostream>
using namespace std;
template<int N>
int f() {
  if (N == 1) return 1; // we exit the recursion at 1 instead of 0
  return N*f<N-1>();
}
int main() {
  cout << f<5>() << endl;
  return 0;
}

and throws the following error:

...$ g++ factorial.cpp && ./a.out 
factorial.cpp: In instantiation of ‘int f() [with int N = -894]’:
factorial.cpp:7:18:   recursively required from ‘int f() [with int N = 4]’
factorial.cpp:7:18:   required from ‘int f() [with int N = 5]’
factorial.cpp:15:16:   required from here
factorial.cpp:7:18: fatal error: template instantiation depth exceeds maximum of 900 (use ‘-ftemplate-depth=’ to increase the maximum)
    7 |   return N*f<N-1>();
      |            ~~~~~~^~
compilation terminated.

whereas, upon adding the specialization for N == 0 (which the template above doesn't even reach),

template<>
int f<0>() {
  cout << "Hello, I'm the specialization.\n";
  return 1;
}

the code compiles and gives the correct output, even though the specialization is never called:

...$ g++ factorial.cpp && ./a.out 
120

Solution

  • The issue here is that your if statement is a run-time construct. When you have

    template<int N>
    int f() {
      if (N == 1) return 1; // we exit the recursion at 1 instead of 0
      return N*f<N-1>();
    }
    

    the call f<N-1>() is part of the function body, so the compiler must instantiate f<N-1> whenever it instantiates f<N>, whether or not the call is ever executed. The if only prevents the call at run time; it does not prevent the instantiation. So f<5> instantiates f<4>, which instantiates f<3>, and so on past f<1> to f<0>, f<-1>, f<-2>, ..., until the compiler's instantiation depth limit (900 in the error above) is reached.

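    The pre-C++17 way to stop this is to provide a specialization for 0, which breaks the chain: f<1> then refers to the explicit specialization f<0> instead of instantiating the primary template again. The specialization is needed only at compile time and is never called at run time, which is why its message is not printed. Starting in C++17 with constexpr if, this is no longer needed. Using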

    template<int N>
    int f() {
      if constexpr (N == 1) return 1; // we exit the recursion at 1 instead of 0
      else return N*f<N-1>();
    }
    

    guarantees that return N*f<N-1>(); is discarded and never instantiated when N == 1, so f<0> is never required and the chain of instantiations stops there.
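
For comparison, here is a minimal sketch that puts the two approaches side by side. The names factorial_spec and factorial_cx are hypothetical and chosen only for this illustration. Two things differ from the code above: the functions are marked constexpr so the result can be checked with static_assert, and the specialization is written for 1 rather than 0.

```cpp
// Sketch only: factorial_spec and factorial_cx are hypothetical names,
// and constexpr is an addition so the results can be checked at compile time.

// Pre-C++17: the primary template always names factorial_spec<N - 1>,
// so an explicit specialization is needed to end the instantiation chain.
template<int N>
constexpr int factorial_spec() {
  return N * factorial_spec<N - 1>();
}

// Explicit specialization: the chain of instantiations stops here.
template<>
constexpr int factorial_spec<1>() {
  return 1;
}

// C++17: if constexpr discards the else branch when N == 1,
// so factorial_cx<0> is never instantiated.
template<int N>
constexpr int factorial_cx() {
  if constexpr (N == 1) return 1;
  else return N * factorial_cx<N - 1>();
}

// Both versions are evaluated by the compiler.
static_assert(factorial_spec<5>() == 120, "5! should be 120");
static_assert(factorial_cx<5>() == 120, "5! should be 120");
```

Both functions can be called like f in the question, for example cout << factorial_cx<5>() << endl; in main. As in the question, instantiating either one with N less than 1 would again run into the instantiation depth limit, because neither the specialization nor the N == 1 branch is ever reached.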