
Custom submodules in pytorch / libtorch C++


Full disclosure: I asked this same question on the PyTorch forums a few days ago and got no reply, so this is technically a repost, but I believe it's still a good question because I haven't been able to find an answer anywhere online. Here goes:

Can you show an example of using register_module with a custom module? The only examples I've found online register linear or convolutional layers as submodules.

I tried to write my own module and register it with another module, and I couldn't get it to work. My IDE reports: no instance of overloaded function "MyModel::register_module" matches the argument list -- argument types are: (const char [14], TreeEmbedding)

(TreeEmbedding is the name of another struct I made which extends torch::nn::Module.)

Am I missing something? An example of this would be very helpful.



Edit: Additional context follows below.

I have a header file "model.h" which contains the following:

#include <torch/torch.h>
#include <vector>

// Graph is the author's own type (definition not shown)
struct TreeEmbedding : torch::nn::Module {
    TreeEmbedding();
    torch::Tensor forward(Graph tree);
};

struct MyModel : torch::nn::Module {
    size_t embeddingSize;
    TreeEmbedding treeEmbedding;

    MyModel(size_t embeddingSize=10);
    torch::Tensor forward(std::vector<Graph> clauses, std::vector<Graph> contexts);
};

I also have a cpp file "model.cpp" which contains the following:

#include "model.h"

MyModel::MyModel(size_t embeddingSize) :
    embeddingSize(embeddingSize)
{
    // Error here: no register_module overload accepts a TreeEmbedding by value
    treeEmbedding = register_module("treeEmbedding", TreeEmbedding{});
}

This setup still produces the same error. The code in the documentation does work (with built-in components such as linear layers), but a custom module does not. After tracking down the definition of torch::nn::Linear, it looks as though it is a ModuleHolder (whatever that is...).

Thanks, Jack


Solution

I will accept a better answer if anyone can provide more details, but in case anyone is wondering, here is the little information I was able to find:

register_module takes a string as its first argument, and its second argument can be either a ModuleHolder (I don't know what this is...) or a std::shared_ptr to your module. So here's my example (for the assignment to compile, the member in model.h has to be declared as std::shared_ptr<TreeEmbedding> treeEmbedding; rather than TreeEmbedding treeEmbedding;):

treeEmbedding = register_module<TreeEmbedding>("treeEmbedding", std::make_shared<TreeEmbedding>());

This seemed to work for me so far.