
Creating tensors on M1 GPU by default on PyTorch using Jupyter


Right now, if I want to create a tensor on the GPU, I have to do it manually. For context, I'm sure that GPU support is available, since

import torch
print(torch.backends.mps.is_available())  # True if MPS can actually be used on this machine
print(torch.backends.mps.is_built())      # True if this PyTorch installation was built with MPS support

both return True.

I've been doing this every time:

dtype = torch.float32  # whichever dtype you need
device = torch.device("mps")
a = torch.randn((), device=device, dtype=dtype)  # 0-dim (scalar) tensor on the GPU

Is there a way to specify, for a Jupyter notebook, that all my tensors should be created on the GPU by default?


Solution

    The convenient way

    As of 2022-12-22, there is no convenient way to set the default device to MPS, per the discussion on this issue.
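PyTorch 2.0, released in March 2023 (after this answer was written), added exactly this: torch.set_default_device for a session-wide default, and torch.device usable as a context manager for a scoped default. A minimal sketch, assuming PyTorch 2.0 or later; the CPU fallback is an illustrative addition so the cell also runs on machines without MPS:

```python
import torch

# Requires PyTorch 2.0+; fall back to CPU where MPS is unavailable (illustrative choice)
dev = "mps" if torch.backends.mps.is_available() else "cpu"

# Scoped default: only factory calls inside the block are affected
with torch.device(dev):
    a = torch.randn(3)
print(a.device)

# Session-wide default: affects factory calls for the rest of the notebook
torch.set_default_device(dev)
b = torch.zeros(2)
print(b.device)
```

Both mechanisms only change where newly constructed tensors are allocated; existing tensors and modules still need .to(dev).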


    The inconvenient way

    You can get the behavior "I don't want to pass device= to tensor constructors, just use MPS" by intercepting calls to tensor constructors with a torch.overrides.TorchFunctionMode:

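    import torch

    class MPSMode(torch.overrides.TorchFunctionMode):
        """Inject device='mps' into tensor-constructor calls that don't specify a device."""
        def __init__(self):
            super().__init__()
            # incomplete list; see the linked issue for the full list
            self.constructors = {getattr(torch, x) for x in "empty ones arange eye full fill linspace rand randn randint randperm range zeros tensor as_tensor".split()}
        def __torch_function__(self, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            # only touch constructors, and never override an explicit device=
            if func in self.constructors and 'device' not in kwargs:
                kwargs['device'] = 'mps'
            return func(*args, **kwargs)
    
    # sensible usage: the mode is active only inside the with-block
    with MPSMode():
        print(torch.empty(1).device) # prints mps:0
    
    # sneaky usage: the mode is entered but never exited, so it stays active for the rest of the session
    MPSMode().__enter__()
    print(torch.empty(1).device) # prints mps:0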
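If you want the injected device to be configurable, the same idea can be parameterized. This is a hypothetical generalization (the name DefaultDeviceMode is made up here, and it uses a shortened version of the same incomplete constructor list): it takes the device as an argument and, like MPSMode, leaves any explicit device= untouched. The demo uses the "meta" device only because it works on any machine.

```python
import torch
from torch.overrides import TorchFunctionMode

class DefaultDeviceMode(TorchFunctionMode):
    """Hypothetical variant of MPSMode: inject a chosen device into factory calls."""

    # Incomplete list of tensor factory functions to intercept
    _NAMES = "empty ones arange eye full linspace rand randn randint randperm zeros tensor as_tensor"

    def __init__(self, device):
        super().__init__()
        self.device = torch.device(device)
        self.constructors = {getattr(torch, n) for n in self._NAMES.split()}

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = dict(kwargs or {})
        # Only constructors are changed, and only when no device was given
        if func in self.constructors and "device" not in kwargs:
            kwargs["device"] = self.device
        return func(*args, **kwargs)

# "meta" tensors carry only shape/dtype, so this demo runs anywhere
with DefaultDeviceMode("meta"):
    x = torch.zeros(3)                # device injected
    y = torch.zeros(3, device="cpu")  # explicit device is respected
print(x.device, y.device)
```

Usage in the notebook would then be DefaultDeviceMode("mps") in place of MPSMode().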
    

    The recommended way

    I would lean towards just putting your device in a config at the top of your notebook and using it explicitly:

    import torch

    class Conf:
        dev = torch.device("mps")  # the one place to change the device
    # ...
    a = torch.randn(1, device=Conf.dev)
    

    This requires you to write device=Conf.dev throughout the code, but you can switch the whole notebook to a different device by editing one line, and there is no implicit global state to worry about.
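A minimal sketch of the same config approach, assuming you want the notebook to keep running on machines without MPS. The CPU fallback and the small nn.Linear model are illustrative additions, not part of the original answer; they show that new tensors take device=Conf.dev while existing modules are moved with .to(Conf.dev).

```python
import torch
import torch.nn as nn

class Conf:
    # Prefer the Apple-silicon GPU (MPS) when usable, otherwise fall back to CPU
    dev = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

a = torch.randn(1, device=Conf.dev)   # new tensor created directly on Conf.dev
model = nn.Linear(4, 2).to(Conf.dev)  # existing module moved to Conf.dev
out = model(torch.randn(8, 4, device=Conf.dev))
print(a.device, out.shape)
```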