Tags: pytorch, jit, torchscript

Tracing Tensor Sizes in TorchScript


I'm exporting a PyTorch model via TorchScript tracing, but I'm facing issues. Specifically, I have to perform some operations on tensor sizes, but the JIT compiler hardcodes the variable shapes as constants, breaking compatibility with tensors of different sizes.

For example, create the class:

import torch
from torch import nn


class Foo(nn.Module):
    """Toy class that plays with tensor shape to showcase tracing issue.

    It creates a new tensor with the same shape as the input one, except
    for the last dimension, which is doubled. This new tensor is filled
    based on the values of the input.
    """
    def __init__(self):
        nn.Module.__init__(self)

    def forward(self, x):
        new_shape = (x.shape[0], 2*x.shape[1])  # incriminated instruction
        x2 = torch.empty(size=new_shape)
        x2[:, ::2] = x
        x2[:, 1::2] = x + 1
        return x2

and run the test code:

x = torch.randn((3, 5))  # create example input

foo = Foo()
traced_foo = torch.jit.trace(foo, x)  # trace
print(traced_foo(x).shape)  # obviously this works
print(traced_foo(x[:, :4]).shape)  # but fails with a different shape!

I could solve the issue by scripting, but in this case I really need to use tracing. Moreover, I think that tracing should be able to handle tensor size manipulations correctly.


Solution

  • but in this case I really need to use tracing

    You can freely mix torch.jit.script and torch.jit.trace wherever needed. For example, one could do this:

    import torch
    
    
    class MySuperModel(torch.nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.scripted = torch.jit.script(Foo(*args, **kwargs))
            self.traced = Bar(*args, **kwargs)
    
        def forward(self, data):
            return self.scripted(self.traced(data))
    
    model = MySuperModel()
    torch.jit.trace(model, example_data)  # example_data: a single example input for forward
    
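    To make the pattern above concrete, here is a self-contained sketch (the Bar module and its contents are placeholders invented for illustration): the shape-dependent Foo is scripted, the shape-agnostic Bar is left to the tracer, and the traced model then generalizes across input widths.

```python
import torch
from torch import nn


class Foo(nn.Module):
    """Shape-dependent part; must be scripted to stay dynamic."""
    def forward(self, x):
        new_shape = (x.shape[0], 2 * x.shape[1])
        x2 = torch.empty(size=new_shape)
        x2[:, ::2] = x
        x2[:, 1::2] = x + 1
        return x2


class Bar(nn.Module):
    """Shape-agnostic part; safe to trace (placeholder logic)."""
    def forward(self, x):
        return x * 2


class MySuperModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.scripted = torch.jit.script(Foo())  # scripted submodule
        self.traced = Bar()                      # traced along with the rest

    def forward(self, data):
        return self.scripted(self.traced(data))


# Tracing preserves calls into scripted submodules, so shapes stay dynamic.
model = torch.jit.trace(MySuperModel(), torch.randn(3, 5))
print(model(torch.randn(3, 5)).shape)  # torch.Size([3, 10])
print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 8])
```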

    You could also move the shape-dependent part of the functionality to a separate function and decorate it with @torch.jit.script:

    import torch
    from torch import nn
    
    
    @torch.jit.script
    def _forward_impl(x):
        new_shape = (x.shape[0], 2*x.shape[1])  # incriminated instruction
        x2 = torch.empty(size=new_shape)
        x2[:, ::2] = x
        x2[:, 1::2] = x + 1
        return x2
    
    class Foo(nn.Module):
        def forward(self, x):
            return _forward_impl(x)
    
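    Rerunning the question's test against this version (a quick sketch using the same toy shapes), the traced module now handles a narrower input instead of failing:

```python
import torch
from torch import nn


@torch.jit.script
def _forward_impl(x):
    new_shape = (x.shape[0], 2 * x.shape[1])
    x2 = torch.empty(size=new_shape)
    x2[:, ::2] = x
    x2[:, 1::2] = x + 1
    return x2


class Foo(nn.Module):
    def forward(self, x):
        return _forward_impl(x)


x = torch.randn(3, 5)
traced_foo = torch.jit.trace(Foo(), x)
# The trace records a call to the scripted function, so shapes stay dynamic.
print(traced_foo(x).shape)          # torch.Size([3, 10])
print(traced_foo(x[:, :4]).shape)   # torch.Size([3, 8]) instead of an error
```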

    There is no way around scripting for this, as the compiler has to understand your code. Tracing merely records the operations you perform on the tensor and has no knowledge of control flow that depends on the data (or on the shape of the data).
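    A minimal illustration of that limitation (a toy function made up for this example, not from the question): tracing bakes in whichever branch the example input happens to take.

```python
import torch


def f(x):
    if x.sum() > 0:  # data-dependent branch: invisible to the tracer
        return x + 1
    return x - 1


pos = torch.ones(3)
traced = torch.jit.trace(f, pos)  # records only the `x + 1` path (emits a TracerWarning)

neg = -torch.ones(3)
print(f(neg))       # tensor([-2., -2., -2.]) — eager Python takes the other branch
print(traced(neg))  # tensor([0., 0., 0.])   — the traced graph always adds 1
```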

    Anyway, this should cover most cases; if it doesn't, you should be more specific about yours.