
How do I cast a raw pointer to a PyTorch tensor of a specific shape?


I get a raw pointer from a C++ library which I would like to interpret (in a reinterpret_cast-like fashion) as a PyTorch tensor of a specific shape. Since the code runs in a performance-critical section, I really want to make sure that no heap allocations and/or copy operations are performed.

Here is what I got right now:

import ctypes
import numpy
import torch

def as_tensor(pointer, shape):
    return torch.from_numpy(numpy.array(numpy.ctypeslib.as_array(pointer, shape=shape)))

shape = (2, 3, 4)
x = torch.zeros(shape)

p = ctypes.cast(x.data_ptr(), ctypes.POINTER(ctypes.c_float))
y = as_tensor(p, shape)

Is it really necessary to go through a NumPy array first? I'm also not 100% sure whether the call to numpy.array(...) copies the data that the as_array() result points to.
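One way to answer the second question empirically is numpy.shares_memory, which reports whether two arrays overlap in memory. The sketch below assumes NumPy is installed and uses a ctypes array `buf` as a hypothetical stand-in for the memory owned by the C++ library. It compares the view returned by numpy.ctypeslib.as_array with the result of numpy.array(...), which copies by default (copy=True).

```python
import ctypes
import numpy

# Hypothetical stand-in for memory owned by the C++ library: 24 zero-initialized floats
buf = (ctypes.c_float * 24)()
pointer = ctypes.cast(buf, ctypes.POINTER(ctypes.c_float))

view = numpy.ctypeslib.as_array(pointer, shape=(2, 3, 4))  # wraps the memory, no copy
copied = numpy.array(view)  # copy=True by default, so new memory is allocated

print(numpy.shares_memory(view, copied))  # False: numpy.array made a copy

buf[0] = 1.0  # write through the original memory
print(view[0, 0, 0], copied[0, 0, 0])  # 1.0 0.0: only the view sees the change
```

So the numpy.array(...) wrapper in the question does copy; numpy.ctypeslib.as_array on its own does not.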


Solution

  • You can create a one-dimensional ctypes array object from the pointer and the shape. It implements the buffer protocol, so torch.frombuffer can turn it into a one-dimensional tensor, which is finally reshaped with view().

    The last lines of the code show that x and y share the same memory.

    import torch
    import ctypes
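    from math import prod  # Python 3.8+
    
    # torch_type must be the torch dtype matching the ctypes element type
    # (e.g. torch.float for ctypes.c_float)
    def as_tensor(pointer, shape, torch_type):
        # ctypes array of prod(shape) elements placed at the pointer's address (no copy)
        arr = (pointer._type_ * prod(shape)).from_address(
            ctypes.addressof(pointer.contents))
        
        # frombuffer (PyTorch 1.10+) shares memory with arr; view() reshapes without copying
        return torch.frombuffer(arr, dtype=torch_type).view(*shape)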
    
    shape = (2, 3, 4)
    x = torch.zeros(shape)
    
    p = ctypes.cast(x.data_ptr(), ctypes.POINTER(ctypes.c_float))
    
    y = as_tensor(p, shape, torch.float)
    
    print(y)  # Print created tensor
    
    x[1,1,0] = 3.  # Modify original
    
    print(y)  # Print again
    

    Output:

    tensor([[[0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.]],
    
            [[0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.]]])
    tensor([[[0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.]],
    
            [[0., 0., 0., 0.],
             [3., 0., 0., 0.],
             [0., 0., 0., 0.]]])
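If you prefer to keep the NumPy route from the question, dropping the numpy.array(...) wrapper should also avoid the copy, because torch.from_numpy shares memory with the array it receives. The sketch below (the name as_tensor_np is hypothetical) checks this by comparing data_ptr() values. In both approaches the resulting tensor does not own the memory, so whatever owns the buffer (x here, the C++ library in practice) has to stay alive for as long as the tensor is used.

```python
import ctypes
import numpy
import torch

def as_tensor_np(pointer, shape):
    # Hypothetical helper: as_array wraps the memory behind `pointer` without copying,
    # and torch.from_numpy shares memory with the NumPy array it is given
    return torch.from_numpy(numpy.ctypeslib.as_array(pointer, shape=shape))

shape = (2, 3, 4)
x = torch.zeros(shape)
p = ctypes.cast(x.data_ptr(), ctypes.POINTER(ctypes.c_float))

y = as_tensor_np(p, shape)

print(y.data_ptr() == x.data_ptr())  # True: same underlying buffer
x[1, 1, 0] = 3.  # modify the original
print(y[1, 1, 0].item())  # 3.0
```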