python c++ indexing pytorch libtorch

Is there an equivalent of Python's array slicing in C++ (libtorch)?


In Python with PyTorch, if you have an array:

torch.linspace(0, 10, 10)

you can use, e.g., only the first four elements by writing

reduced_tensor = torch.linspace(0, 10, 10)[:4]

Is there an analog to the [:] array slicing in C++/libtorch? If not, how can I achieve this easily?


Solution

  • Yes, you can use Slice together with Tensor::index in libtorch, for example:

    using namespace torch::indexing;  // brings Slice and None into scope
    auto tensor = torch::linspace(0, 10, 10).index({ Slice(None, 4) });
    

    You can read more about indexing in the PyTorch C++ documentation on the Tensor Indexing API.
    As that documentation explains:

    The main difference is that, instead of using the []-operator similar to the Python API syntax, in the C++ API the indexing methods are:

    torch::Tensor::index

    torch::Tensor::index_put_

    It’s also important to note that index types such as None / Ellipsis / Slice live in the torch::indexing namespace, and it’s recommended to put using namespace torch::indexing before any indexing code for convenient use of those index types.

    For convenience, here are some of the Python-to-C++ translations from that page:

    Here are some examples of translating Python indexing code to C++:

    Getter
    ------
    
    +----------------------------------------------------------+--------------------------------------------------------------------------------------+
    | Python                                                   | C++  (assuming  using namespace torch::indexing )                                    |
    +==========================================================+======================================================================================+
    |  tensor[None]                                            |  tensor.index({None})                                                                |
    +----------------------------------------------------------+--------------------------------------------------------------------------------------+
    |  tensor[Ellipsis, ...]                                   |  tensor.index({Ellipsis, "..."})                                                     |
    +----------------------------------------------------------+--------------------------------------------------------------------------------------+
    |  tensor[1, 2]                                            |  tensor.index({1, 2})                                                                |
    +----------------------------------------------------------+--------------------------------------------------------------------------------------+
    |  tensor[True, False]                                     |  tensor.index({true, false})                                                         |
    +----------------------------------------------------------+--------------------------------------------------------------------------------------+
    |  tensor[1::2]                                            |  tensor.index({Slice(1, None, 2)})                                                   |
    +----------------------------------------------------------+--------------------------------------------------------------------------------------+
    |  tensor[torch.tensor([1, 2])]                            |  tensor.index({torch::tensor({1, 2})})                                               |
    +----------------------------------------------------------+--------------------------------------------------------------------------------------+
    |  tensor[..., 0, True, 1::2, torch.tensor([1, 2])]        |  tensor.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})})            |
    +----------------------------------------------------------+--------------------------------------------------------------------------------------+
    
    
    Translating between Python/C++ index types
    ------------------------------------------
    
    The one-to-one translation between Python and C++ index types is as follows:
    
    +-------------------------+------------------------------------------------------------------------+
    | Python                  | C++ (assuming  using namespace torch::indexing )                       |
    +=========================+========================================================================+
    |  None                   |  None                                                                  |
    +-------------------------+------------------------------------------------------------------------+
    |  Ellipsis               |  Ellipsis                                                              |
    +-------------------------+------------------------------------------------------------------------+
    |  ...                    |  "..."                                                                 |
    +-------------------------+------------------------------------------------------------------------+
    |  123                    |  123                                                                   |
    +-------------------------+------------------------------------------------------------------------+
    |  True                   |  true                                                                  |
    +-------------------------+------------------------------------------------------------------------+
    |  False                  |  false                                                                 |
    +-------------------------+------------------------------------------------------------------------+
    |  :  or  ::              |  Slice()  or  Slice(None, None)  or  Slice(None, None, None)           |
    +-------------------------+------------------------------------------------------------------------+
    |  1:  or  1::            |  Slice(1, None)  or  Slice(1, None, None)                              |
    +-------------------------+------------------------------------------------------------------------+
    |  :3  or  :3:            |  Slice(None, 3)  or  Slice(None, 3, None)                              |
    +-------------------------+------------------------------------------------------------------------+
    |  ::2                    |  Slice(None, None, 2)                                                  |
    +-------------------------+------------------------------------------------------------------------+
    |  1:3                    |  Slice(1, 3)                                                           |
    +-------------------------+------------------------------------------------------------------------+
    |  1::2                   |  Slice(1, None, 2)                                                     |
    +-------------------------+------------------------------------------------------------------------+
    |  :3:2                   |  Slice(None, 3, 2)                                                     |
    +-------------------------+------------------------------------------------------------------------+
    |  1:3:2                  |  Slice(1, 3, 2)                                                        |
    +-------------------------+------------------------------------------------------------------------+
    |  torch.tensor([1, 2])   |  torch::tensor({1, 2})                                                 |
    +-------------------------+------------------------------------------------------------------------+