Full disclosure, I asked this same question on the PyTorch forums about a few days ago and got no reply, so this is technically a repost, but I believe it's still a good question, because I've been unable to find an answer anywhere online. Here goes:
Can you show an example of using register_module with a custom module? The only examples I’ve found online are registering linear layers or convolutional layers as the submodules.
I tried to write my own module and register it with another module and I couldn’t get it to work.
My IDE is telling me no instance of overloaded function "MyModel::register_module" matches the argument list -- argument types are: (const char [14], TreeEmbedding)
(TreeEmbedding is the name of another struct I made which extends torch::nn::Module.)
Am I missing something? An example of this would be very helpful.
Edit: Additional context follows below.
I have a header file "model.h" which contains the following:
struct TreeEmbedding : torch::nn::Module {
TreeEmbedding();
torch::Tensor forward(Graph tree);
};
struct MyModel : torch::nn::Module{
size_t embeddingSize;
TreeEmbedding treeEmbedding;
MyModel(size_t embeddingSize=10);
torch::Tensor forward(std::vector<Graph> clauses, std::vector<Graph> contexts);
};
I also have a cpp file "model.cpp" which contains the following:
MyModel::MyModel(size_t embeddingSize) :
embeddingSize(embeddingSize)
{
treeEmbedding = register_module("treeEmbedding", TreeEmbedding{});
}
This setup still has the same error as above. The code in the documentation does work (using built-in components like linear layers), but using a custom module does not. After tracking down torch::nn::Linear, it looks as though that is a ModuleHolder
(Whatever that is...)
Thanks, Jack
I will accept a better answer if anyone can provide more details, but just in case anyone's wondering, I thought I would put up the little information I was able to find:
register_module takes in a string as its first argument and its second argument can either be a ModuleHolder (I don't know what this is...) or alternatively it can be a shared_ptr to your module. So here's my example:
treeEmbedding = register_module<TreeEmbedding>("treeEmbedding", make_shared<TreeEmbedding>());
This seemed to work for me so far.