Tags: c++, pytorch, libtorch

Common class for Linear, Conv1d, Conv2d, ..., LSTM


Is there a class that torch::nn::Linear, torch::nn::Conv1d, torch::nn::Conv2d, ..., torch::nn::GRU, ... all inherit from? torch::nn::Module seems to be a good option, but there is an intermediate class, torch::nn::Cloneable, in between, so torch::nn::Module does not work. In addition, torch::nn::Cloneable is itself a template, so it needs a type argument in the declaration. I want to create a general model class with a std::vector<the common class> layers member, so that I can later fill layers with any type of layer I want, e.g., Linear, LSTM, etc. Is there such a capability in the current API? This is easy to do in Python, but in C++ the element type has to be declared, which takes away Python's convenience.

Thanks, Afshin


Solution

  • I found that nn::Sequential can be used for this purpose. It does not require you to write a forward implementation for the model, which is both an advantage and a limitation: nn::Sequential requires each contained module to have its own forward implementation, and it calls those forward functions one after another, in the order in which the modules were added. So you cannot build an ad hoc, non-standard forward pass such as DenseNet's with it, though it is good enough for general use.

    In addition, it seems that nn::Sequential simply uses a std::vector<nn::AnyModule> as its underlying module list, so a std::vector<nn::AnyModule> could also be used directly.