Tags: python, deep-learning, computer-vision, pytorch, torchvision

Is there a way to load torchvision model by string?


Currently, I load a pretrained torchvision model using the following code:

import torchvision
torchvision.models.resnet101(pretrained=True)

However, I'd like to pass the model name as a string parameter and then load the pretrained model using that string. Pseudocode for this would look something like:

model_name = 'resnet101'
torchvision.models.get(model_name)(pretrained=True)

Is there a simple way to accomplish this?


Solution

  • You can use Python's built-in getattr to look up the model builder by name:

    getattr(torchvision.models, 'resnet101')(pretrained=True)
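A slightly more defensive version of the same getattr idea might look like the sketch below. The helper name load_model_by_name is hypothetical, and the code assumes you pass in a module such as torchvision.models. It raises a clear error for unknown names, and it also rejects attributes that exist but are not callable (for example submodules like torchvision.models.detection). Any keyword arguments are forwarded unchanged to the builder.

```python
from typing import Any


def load_model_by_name(models_module: Any, name: str, **kwargs: Any) -> Any:
    """Look up a model builder by its string name and call it.

    Hypothetical helper: `models_module` is assumed to be something like
    `torchvision.models`, and `kwargs` (e.g. pretrained=True) are passed
    straight through to the builder.
    """
    # The third argument makes getattr return None instead of raising
    # AttributeError when the name does not exist on the module.
    builder = getattr(models_module, name, None)

    # Reject missing names and non-callable attributes (e.g. submodules).
    if builder is None or not callable(builder):
        raise ValueError(f"Unknown model name: {name!r}")

    return builder(**kwargs)


# Usage (requires torchvision to be installed):
# import torchvision
# model = load_model_by_name(torchvision.models, "resnet101", pretrained=True)
```

Note that newer torchvision releases deprecate pretrained=True in favor of a weights argument (e.g. weights="DEFAULT"), and recent versions also provide torchvision.models.get_model(name, weights=...), which covers this use case directly. Check the documentation for the version you have installed.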