
libtorch (PyTorch C++) weird class syntax


In the official PyTorch C++ examples on GitHub, you can see a strange definition of a class:

class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {...}

My understanding is that this defines a class CustomDataset which "inherits from" or "extends" torch::data::datasets::Dataset<CustomDataset>. This is weird to me, since the class we're creating inherits from another class that is parameterized by the very class we're creating... How does this even work? What does it mean? It seems to me like an Integer class inheriting from vector<Integer>, which sounds absurd.


Solution

  • This is the curiously recurring template pattern, or CRTP for short. A major advantage of this technique is that it enables so-called static polymorphism, meaning that functions in torch::data::datasets::Dataset can call into functions of CustomDataset without needing to make those functions virtual (and thus deal with the runtime mess of virtual method dispatch and so on). You can also perform compile-time metaprogramming, such as enable_ifs that depend on properties of the custom dataset type.

    In the case of PyTorch, BatchDataset (the superclass of Dataset) uses this technique heavily to support operations such as mapping and filtering:

      template <typename TransformType>
      MapDataset<Self, TransformType> map(TransformType transform) & {
        return datasets::map(static_cast<Self&>(*this), std::move(transform));
      }
    

    Note the static cast of this to the derived type (legal as long as CRTP is properly applied); datasets::map constructs a MapDataset object which is also parametrized by the dataset type, allowing the MapDataset implementation to statically call methods such as get_batch (or encounter a compile-time error if they do not exist).
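
    For illustration, here is a minimal, standalone sketch of the same static-dispatch mechanism; the names DatasetBase and StringDataset are invented for this example and are not part of the libtorch API:

      #include <iostream>
      #include <string>

      // The base class template is parameterized by the class deriving from it,
      // so it can call methods of the derived class through a static_cast,
      // with no virtual functions involved.
      template <typename Derived>
      class DatasetBase {
       public:
        void print(int index) {
          auto& self = static_cast<Derived&>(*this);  // legal under CRTP
          std::cout << self.get(index) << '\n';       // resolved at compile time
        }
      };

      // The derived class passes itself as the template argument, just like
      // CustomDataset : public torch::data::datasets::Dataset<CustomDataset>.
      class StringDataset : public DatasetBase<StringDataset> {
       public:
        std::string get(int index) { return "item " + std::to_string(index); }
      };

      int main() {
        StringDataset dataset;
        dataset.print(3);  // prints "item 3"
      }

    If StringDataset forgot to define get, the error would surface when DatasetBase<StringDataset>::print is instantiated, i.e. at compile time rather than at run time.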

    Furthermore, since MapDataset receives the custom dataset type as a type parameter, compile-time metaprogramming is possible:

      /// The implementation of `get_batch()` for the stateless case, which simply
      /// applies the transform to the output of `get_batch()` from the dataset.
      template <
          typename D = SourceDataset,
          typename = torch::disable_if_t<D::is_stateful>>
      OutputBatchType get_batch_impl(BatchRequestType indices) {
        return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
      }
    
      /// The implementation of `get_batch()` for the stateful case. Here, we follow
      /// the semantics of `Optional.map()` in many functional languages, which
      /// applies a transformation to the optional's content when the optional
      /// contains a value, and returns a new, empty optional (of a different type) if the
      /// original optional returned by `get_batch()` was empty.
      template <typename D = SourceDataset>
      torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
          BatchRequestType indices) {
        if (auto batch = dataset_.get_batch(std::move(indices))) {
          return transform_.apply_batch(std::move(*batch));
        }
        return nullopt;
      }
    

    Notice that the conditional enable depends on SourceDataset, which is only available because the dataset is parametrized with this CRTP pattern.
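
    As a further illustration, here is a minimal, standalone sketch of that compile-time selection technique, using std::enable_if_t on a boolean member of the dataset type. Loader, StatelessDataset, StatefulDataset and their members are invented names for this example, not libtorch API:

      #include <iostream>
      #include <type_traits>

      // Two next() overloads are written as member function templates;
      // std::enable_if_t removes the one whose condition on D::is_stateful
      // is false, so the choice happens entirely at compile time.
      template <typename Dataset>
      class Loader {
       public:
        template <typename D = Dataset, std::enable_if_t<!D::is_stateful, int> = 0>
        int next() {
          return dataset_.get();  // stateless: reading has no side effects
        }

        template <typename D = Dataset, std::enable_if_t<D::is_stateful, int> = 0>
        int next() {
          return dataset_.get_and_advance();  // stateful: reading mutates state
        }

       private:
        Dataset dataset_;
      };

      struct StatelessDataset {
        static constexpr bool is_stateful = false;
        int get() const { return 42; }
      };

      struct StatefulDataset {
        static constexpr bool is_stateful = true;
        int counter = 0;
        int get_and_advance() { return counter++; }
      };

      int main() {
        Loader<StatelessDataset> a;
        Loader<StatefulDataset> b;
        std::cout << a.next() << ' ' << b.next() << '\n';  // prints "42 0"
      }

    Because exactly one next() overload survives substitution for any given dataset type, the call involves no runtime branch, which is the same effect MapDataset achieves with its two get_batch_impl overloads.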