
The equivalent of torch.nn.Parameter for LibTorch


I am trying to port a Python PyTorch model to LibTorch in C++.

In Python, inside a subclass of torch.nn.Module, I have the line self.A = nn.Parameter(A), where A is a torch.Tensor created with requires_grad=True.

What is the equivalent of this for a torch::Tensor member of a torch::nn::Module class in C++?

The autocomplete in my editor shows the classes ParameterDict, ParameterList, ParameterDictImpl and ParameterListImpl, but no Parameter. Do I need to wrap the tensor in a list of size 1, or is there something else I'm missing? I couldn't find what I needed through a Google search or in the documentation, though to be honest I wasn't sure precisely what to search for.


Solution

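  • LibTorch has no separate Parameter wrapper class: a parameter is an ordinary torch::Tensor that has been registered with the module, so there is no need for a ParameterList of size 1. To register a parameter (a tensor that requires gradients) with a module, call torch::nn::Module::register_parameter, typically inside the module's constructor:

    A = register_parameter("A", torch::ones({20, 1, 5, 5}), true);

    The third argument is requires_grad and defaults to true; note the C++ literal true rather than Python's True. The call returns a reference to the registered tensor, which is usually kept in a member (here A) so it can be used in forward().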