pytorch, driver, nvidia

Apparently fine, simple PyTorch code throws a RuntimeError


RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`

I changed my GPUs from RTX 3090 to RTX A6000. Since then, any code that uses the GPU raises the error above.

I wrote and ran a very simple script like the one below, but the error still occurred.

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)


if __name__ == "__main__":
    x = torch.randn(10).cuda()
    model = Model().cuda()

    print(model(x))

Where should I look to solve this problem? Is it a graphics driver issue?


Solution

  • I solved this by:

    1. reinstalling CUDA and the graphics driver
    2. creating a new conda environment and reinstalling PyTorch
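
After reinstalling, it can help to confirm what the fresh environment actually sees before rerunning the failing script. The following is only a minimal sketch: the function name `describe_environment` and the dictionary keys are made up for illustration, and the answer does not say which component was mismatched, so this reports the installed versions without diagnosing the cause.

```python
import importlib.util


def describe_environment():
    """Collect version and device facts for the current PyTorch install.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    info = {"torch_installed": importlib.util.find_spec("torch") is not None}
    if not info["torch_installed"]:
        return info

    import torch

    info["torch_version"] = torch.__version__
    # CUDA version PyTorch was built against; None for CPU-only builds.
    info["built_with_cuda"] = torch.version.cuda
    info["cuda_available"] = torch.cuda.is_available()
    if info["cuda_available"]:
        info["device_name"] = torch.cuda.get_device_name(0)
        # (major, minor) compute capability of the first visible GPU.
        info["compute_capability"] = torch.cuda.get_device_capability(0)
        # GPU architectures this PyTorch binary was compiled for.
        info["compiled_arch_list"] = torch.cuda.get_arch_list()
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

If `cuda_available` is reported as False, or the device name is not the new card, the driver or PyTorch install is probably still not in a usable state, and the simple `nn.Linear` script above is unlikely to run on the GPU either.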