Tags: python, pytorch, torchvision

Segfault while importing torchvision.transforms


I'm getting a segfault in Python during imports. This code:

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
print("was there")
from torchvision import transforms
print("didn't get there")
from torchvision import datasets
from torchvision import models

returns this:

$ python3 -u classifier.py 
was there
Segmentation fault (core dumped)
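The shell only reports the signal, with no Python traceback. A minimal sketch, not part of the original post: enabling the standard-library faulthandler module as the very first thing in classifier.py makes Python print the stack of each thread when the process receives SIGSEGV, which shows which import (and which module inside torchvision or matplotlib) was executing at the time of the crash. Running python3 -X faulthandler -u classifier.py, or setting PYTHONFAULTHANDLER=1, should have the same effect without editing the file.

```python
import faulthandler

# Install handlers for SIGSEGV, SIGFPE, SIGABRT, SIGBUS and SIGILL that dump
# the Python traceback of every thread to sys.stderr before the process dies.
faulthandler.enable(all_threads=True)

# The original imports follow here, unchanged and in their original order.
# If one of them segfaults, the dump names the file and line being executed.
import os
```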

So torchvision.transforms seems to be responsible. I've tried swapping the order of the imports, and torchvision.models segfaults as well.

I've also tried importing torchvision.transforms on its own, and that works without any problems. What could cause this?
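The manual experiments above (swapping lines, importing one module alone) can be automated. A sketch under the assumption that each attempt should run in a fresh interpreter via subprocess, so a segfault only kills the child process and modules already imported in the parent cannot influence the result. run_imports, describe and check_pairs are hypothetical helper names, not part of any library.

```python
import itertools
import signal
import subprocess
import sys


def run_imports(modules):
    """Import `modules` in order inside a fresh interpreter; return its exit code."""
    source = "\n".join(f"import {name}" for name in modules)
    proc = subprocess.run([sys.executable, "-c", source], capture_output=True)
    return proc.returncode


def describe(returncode):
    """Turn a child exit code into a short human-readable status."""
    if returncode == 0:
        return "ok"
    if returncode < 0:
        # On POSIX, a negative code -N means the child was killed by signal N.
        return f"killed by {signal.Signals(-returncode).name}"
    # A positive code usually means an uncaught exception such as ImportError.
    return f"exited with code {returncode}"


def check_pairs(modules):
    """Try every ordered pair of modules and map each pair to its status."""
    results = {}
    for first, second in itertools.permutations(modules, 2):
        results[(first, second)] = describe(run_imports([first, second]))
    return results
```

For example, check_pairs(["matplotlib.pyplot", "torchvision.transforms"]) tries both orders; if the behaviour described in this question reproduces, one of the two pairs should be reported as killed by SIGSEGV.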

Edit:

I'm working on Ubuntu 20.04.4 and installed torchvision through pip.
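Because torch and torchvision came from pip, the exact versions of the packages involved are useful to include in a report like this. A small sketch using the standard-library importlib.metadata module (Python 3.8+, the version Ubuntu 20.04 ships); the list of distribution names is an assumption based on the imports in the question. Running pip list from the shell gives similar information.

```python
from importlib import metadata


def installed_versions(distributions):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in distributions:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distributions behind the imports in the question (pip names, assumed).
for name, version in installed_versions(
        ["torch", "torchvision", "matplotlib", "numpy"]).items():
    print(f"{name:12} {version or 'not installed'}")
```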


Solution

  • I moved the torchvision.transforms import above the matplotlib.pyplot import, and somehow neither torchvision.transforms nor torchvision.models causes a segfault anymore. Placing torchvision.transforms directly after matplotlib.pyplot still segfaults.

    Here is what the final code looks like:

    import os
    from torchvision import transforms
    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    from torch import nn
    from torch import optim
    import torch.nn.functional as F
    from torchvision import datasets
    from torchvision import models
    

    At least my code works now, but I feel like there must be an underlying problem that this doesn't address...
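One frequently reported cause of import-order-dependent segfaults, offered here as a hypothesis and not a confirmed diagnosis for this post, is two extension modules loading different builds of the same native library (for example a copy bundled inside a pip wheel and a system copy), which may then interfere depending on which is loaded first. The Linux-only sketch below lists the shared objects mapped into the current process via /proc/self/maps and groups them by library name, so it can be run after the imports that do work to look for one library loaded from two paths. The regular expression that strips an 8-hex-digit hash suffix (the naming pattern auditwheel-repaired wheels typically use) is an assumption; adjust it to the file names actually seen.

```python
import os
import re

# Assumed normalisation: strip an auditwheel-style "-<8 hex digits>" hash so
# "libpng16-1a2b3c4d.so.16" and "libpng16.so.16" land in the same group.
_HASH_SUFFIX = re.compile(r"-[0-9a-f]{8}$")


def loaded_shared_libs(maps_path="/proc/self/maps"):
    """Return the sorted, de-duplicated paths of .so files mapped into this process."""
    libs = set()
    with open(maps_path) as maps:
        for line in maps:
            # Fields: address perms offset dev inode [pathname]
            fields = line.split(maxsplit=5)
            if len(fields) == 6 and ".so" in os.path.basename(fields[5]):
                libs.add(fields[5].strip())
    return sorted(libs)


def duplicated_libraries(paths):
    """Group library paths by base name; keep only names loaded from >1 path."""
    groups = {}
    for path in paths:
        stem = os.path.basename(path).split(".so")[0]
        stem = _HASH_SUFFIX.sub("", stem)
        groups.setdefault(stem, set()).add(path)
    return {stem: sorted(found) for stem, found in groups.items() if len(found) > 1}


if __name__ == "__main__" and os.path.exists("/proc/self/maps"):
    # In practice, first import the modules in the order that works
    # (torchvision.transforms, then matplotlib.pyplot), then inspect.
    for stem, paths in duplicated_libraries(loaded_shared_libs()).items():
        print(stem)
        for path in paths:
            print("   ", path)
```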