Tags: c++, pytorch, libtorch

Create a torch::Tensor in C++ to change the shape


I have a tensor and I want to change its shape. I tried to use view, but it raises an exception: "shape[1] is invalid for input of size 10000". Can anyone give me a tip about what this error means?

int shape[] = {1,1,100,100};
torch::Tensor img = torch::zeros((100,100),torch::kF32);
torch::Tensor tmg = img.view(*shape);

Solution

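  • C++ is not Python, so constructs like unpacking with * will not work: *shape does not unpack anything, it dereferences the array and yields only its first element (1), which is why the error mentions shape [1]. Likewise, (100,100) is not a tuple; it is the comma operator and evaluates to just 100. Instead, pass an object that can be implicitly converted to IntArrayRef.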
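    To see what the question's two expressions actually evaluate to, here is a small standard-C++ sketch with no libtorch involved. The function names comma_expression and dereference_first are hypothetical and exist only for illustration:

```cpp
#include <cassert>

// (100, 100) is the built-in comma operator: the left operand is
// evaluated and discarded, and the whole expression is the right 100.
// (Compilers may warn that the left operand has no effect.)
int comma_expression() {
    return (100, 100);
}

// *shape does not "unpack" the array: the array decays to a pointer to
// its first element, and dereferencing that pointer yields shape[0].
int dereference_first() {
    int shape[] = {1, 1, 100, 100};
    return *shape;
}
```

    So in the question, torch::zeros effectively received 100 and view effectively received 1.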

    Basics of creating the objects

    ArrayRef is a class template: a lightweight, non-owning view over a contiguous sequence of elements of some type T. IntArrayRef is an alias for ArrayRef<int64_t>. The class has several constructors (e.g. from a single element, a C-style array, std::vector, std::array or std::initializer_list).
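    To make the "non-owning view with many constructors" idea concrete, below is a heavily simplified stand-in written in plain C++. SimpleArrayRef, SimpleIntArrayRef, rank and dim_at are hypothetical names; this is only a sketch of the idea, not c10's actual ArrayRef implementation:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <vector>

// Hypothetical, simplified stand-in for c10::ArrayRef<T>:
// a (pointer, length) pair that never owns or copies the elements.
template <typename T>
class SimpleArrayRef {
public:
    // From a single element: a view of length 1.
    SimpleArrayRef(const T& one) : data_(&one), size_(1) {}
    // From a C-style array.
    template <std::size_t N>
    SimpleArrayRef(const T (&arr)[N]) : data_(arr), size_(N) {}
    // From a std::vector.
    SimpleArrayRef(const std::vector<T>& v) : data_(v.data()), size_(v.size()) {}
    // From a std::array.
    template <std::size_t N>
    SimpleArrayRef(const std::array<T, N>& a) : data_(a.data()), size_(N) {}
    // From a braced list such as {1, 1, 100, 100}.
    SimpleArrayRef(std::initializer_list<T> il) : data_(il.begin()), size_(il.size()) {}

    std::size_t size() const { return size_; }
    const T& operator[](std::size_t i) const { return data_[i]; }

private:
    const T* data_;
    std::size_t size_;
};

using SimpleIntArrayRef = SimpleArrayRef<std::int64_t>;

// Functions taking the view accept any of the sources above,
// much like view() and zeros() accept anything convertible to IntArrayRef.
std::size_t rank(SimpleIntArrayRef shape) { return shape.size(); }
std::int64_t dim_at(SimpleIntArrayRef shape, std::size_t i) { return shape[i]; }
```

    For example, rank({1, 1, 100, 100}) is 4, while rank(*shape) with the question's int shape[] compiles through the single-element constructor and gives 1, the same kind of conversion that turned img.view(*shape) into a request for shape [1].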

    Both torch::zeros and the view method of torch::Tensor take their shape argument as exactly this type.

    What you can do is:

    /* auto to feel more "Pythonic" */
    auto img = torch::zeros({100, 100}, torch::kF32);
    auto tmg = img.view({1, 1, 100, 100});
    

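    {1, 1, 100, 100} is a braced initializer list, so ArrayRef<int64_t> (a.k.a. IntArrayRef) is constructed from it through its std::initializer_list constructor. Nothing is copied or moved: the ArrayRef just points at the list's temporary backing array, which stays alive until the end of the full expression, i.e. for the duration of the call.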

    The same happens for the {100, 100} passed to torch::zeros.
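    The error in the question comes from the element-count rule behind view: the requested shape must describe exactly as many elements as the tensor holds. The plain-C++ sketch below models only that rule. The helper names numel and view_shape_is_valid are hypothetical, and the real libtorch check also handles an inferred -1 dimension and stride compatibility, both of which are left out here:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Number of elements described by a shape: the product of its dimensions.
std::int64_t numel(const std::vector<std::int64_t>& shape) {
    return std::accumulate(shape.begin(), shape.end(), std::int64_t{1},
                           std::multiplies<std::int64_t>());
}

// Simplified model of the element-count part of Tensor::view's check:
// reshaping is only possible if both shapes hold the same number of elements.
bool view_shape_is_valid(const std::vector<std::int64_t>& from,
                         const std::vector<std::int64_t>& to) {
    return numel(from) == numel(to);
}
```

    A [100, 100] tensor holds 10000 elements, so a target shape of [1] is rejected while [1, 1, 100, 100] is accepted.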

    Easier way for this case

    In this case, the same result can be achieved more simply with unsqueeze:

    auto img = torch::zeros({100, 100}, torch::kF32);
    auto unsqueezed = img.unsqueeze(0).unsqueeze(0);
    

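    where 0 is the dimension at which a new axis of size 1 is inserted; calling it twice turns the [100, 100] tensor into a [1, 1, 100, 100] one.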
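    The shape effect of those two unsqueeze(0) calls can be modelled in plain C++ as below. unsqueeze_shape is a hypothetical helper that tracks only the shape (libtorch's unsqueeze also returns a view sharing the data, which this sketch does not model); negative dims are assumed to count from the end, as in PyTorch:

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper mirroring only the *shape* effect of Tensor::unsqueeze(dim):
// a new axis of size 1 is inserted at position dim.
std::vector<std::int64_t> unsqueeze_shape(std::vector<std::int64_t> shape, std::int64_t dim) {
    const auto rank = static_cast<std::int64_t>(shape.size());
    if (dim < 0) {
        dim += rank + 1;  // e.g. -1 means "append a new last axis"
    }
    if (dim < 0 || dim > rank) {
        throw std::out_of_range("unsqueeze_shape: dim out of range");
    }
    shape.insert(shape.begin() + dim, 1);
    return shape;
}
```

    Applying it twice with dim 0 to {100, 100} yields {1, 1, 100, 100}, the same shape the view call above requests.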

    About libtorch

    All in all, read the reference and check the argument types, at least if you want to work with C++. I agree the docs could use some work, but if you know some C++ it shouldn't be too hard to follow them, or even the source code, which is sometimes necessary.