
PyTorch error on MPS (Apple silicon metal)


When I use PyTorch on the CPU, it works fine, but when I try to use the mps device, it fails. I'm using Miniconda for osx-arm64, and I've tried both Python 3.8 and 3.11 with both the stable and the nightly PyTorch builds.

According to the PyTorch website (https://pytorch.org/get-started/locally/), MPS acceleration is now available in the stable release, without needing a nightly build.

The code I've written is as follows:

import torch
mps_device = torch.device("mps")

float_32_tensor1 = torch.tensor([3.0, 6.0, 9.0],
                               dtype=torch.float32,
                               device=mps_device,
                               requires_grad=False)

float_32_tensor2 = torch.tensor([3.0, 6.0, 9.0],
                               dtype=torch.float32,
                               device=mps_device,
                               requires_grad=False)

print(float_32_tensor1.mul(float_32_tensor2))

This results in the following (fairly long) error: https://pastebin.com/svwZj8Ke

The first line of the error is:

RuntimeError: Failed to create indexing library, error: Error Domain=MTLLibraryErrorDomain Code=3 "program_source:168:1: error: type 'const constant ulong3 *' is not valid for attribute 'buffer'

How would I go about solving this?

Edit: Meta says Pastebin shouldn't be used, but the error is too long to include in the question.

Edit 2: Note that torch.backends.mps.is_available() returns True.
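Because is_available() returning True only says the MPS backend can be used by this interpreter, a common pattern is to choose the device at runtime and fall back to the CPU otherwise. A minimal sketch, assuming PyTorch 1.12 or later (where torch.backends.mps exists); the helper name pick_device is hypothetical:

```python
try:
    import torch
except ImportError:  # this interpreter has no PyTorch installed at all
    torch = None


def pick_device() -> str:
    """Return "mps" if Apple's Metal backend is usable here, else "cpu".

    Hypothetical helper: checks that this PyTorch build was compiled with
    MPS support and that the current machine can actually use it.
    """
    if torch is None:
        return "cpu"
    mps = getattr(torch.backends, "mps", None)  # absent before PyTorch 1.12
    if mps is not None and mps.is_built() and mps.is_available():
        return "mps"
    return "cpu"


if __name__ == "__main__":
    print("Using device:", pick_device())
```

The returned string can be passed straight to torch.device(...) in place of the hard-coded "mps" in the question.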

Edit 3: It seems to work normally in the console, but Jupyter throws this error.
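One way to check whether the console and Jupyter are running different environments is to print the same diagnostics in both and compare them. A minimal sketch using only the standard library (PyTorch is imported only if it is installed); the function name describe_environment is hypothetical:

```python
import importlib.util
import platform
import sys


def describe_environment() -> dict:
    """Collect facts that identify which Python environment is running this code."""
    spec = importlib.util.find_spec("torch")
    info = {
        "python_executable": sys.executable,  # interpreter actually in use
        "python_version": platform.python_version(),
        "machine": platform.machine(),  # "arm64" natively on Apple silicon, "x86_64" under Rosetta
        "torch_location": spec.origin if spec is not None else None,
        "torch_version": None,
    }
    if spec is not None:
        import torch  # only imported when installed in this environment
        info["torch_version"] = torch.__version__
    return info


for key, value in describe_environment().items():
    print(f"{key}: {value}")
```

If python_executable or torch_location differ between the console and the notebook, the notebook kernel is using a different environment from the one you tested in the console.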


Solution

  • Jupyter was still using the old kernel (the dev one), even though I had switched interpreters. Restarting Jupyter with the new Anaconda environment (Python 3.8 and the release version of PyTorch) fixed it.
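After restarting Jupyter on the new environment, a quick confirmation is to rerun the multiplication from the question on MPS and compare it with the CPU result. A minimal sketch; the helper name mps_matches_cpu is hypothetical, and any RuntimeError raised by the MPS backend (like the one above) is deliberately left to propagate so it stays visible:

```python
try:
    import torch
except ImportError:  # PyTorch missing from this kernel's environment
    torch = None


def mps_matches_cpu() -> bool:
    """Multiply the question's tensor by itself on MPS and compare with the CPU.

    Returns False when PyTorch or the MPS backend is unavailable.
    """
    if torch is None:
        return False
    mps = getattr(torch.backends, "mps", None)  # absent before PyTorch 1.12
    if mps is None or not mps.is_available():
        return False
    values = [3.0, 6.0, 9.0]
    cpu_tensor = torch.tensor(values, dtype=torch.float32)
    mps_tensor = torch.tensor(values, dtype=torch.float32, device="mps")
    cpu_result = cpu_tensor.mul(cpu_tensor)
    mps_result = mps_tensor.mul(mps_tensor)
    # Copy the MPS result back to the CPU so both tensors live on one device
    return torch.allclose(mps_result.cpu(), cpu_result)


if __name__ == "__main__":
    print("MPS result matches CPU:", mps_matches_cpu())
```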