I have two setups - one takes approx. 10 minutes to run, the other is still going after an hour:

10 minutes:
```python
import pretrainedmodels
from torch import nn
from fastai.vision import *  # fastai v1: provides cnn_learner, accuracy, error_rate

def resnext50_32x4d(pretrained=False):
    pretrained = 'imagenet' if pretrained else None
    model = pretrainedmodels.se_resnext50_32x4d(pretrained=pretrained)
    return nn.Sequential(*list(model.children()))

learn = cnn_learner(data, resnext50_32x4d, pretrained=True, cut=-2, split_on=lambda m: (m[0][3], m[1]), metrics=[accuracy, error_rate])
```
Not finishing:
```python
import torchvision.models as models
from fastai.vision import *  # fastai v1: provides Learner, accuracy, error_rate

def get_model(pretrained=True, model_name='resnext50_32x4d', **kwargs):
    arch = models.resnext50_32x4d(pretrained, **kwargs)
    return arch

learn = Learner(data, get_model(), metrics=[accuracy, error_rate])
```
This is all copied and hacked together from other people's code, so there are parts I do not understand. The most perplexing part is why one would be so much faster than the other. I would like to use the second option because it's easier for me to understand and I can just swap out the pretrained model to test different ones.
The two architectures are different. I assume you are using pretrained-models.pytorch.

Please notice you are using SE-ResNeXt in your first example and ResNeXt (the standard one from torchvision) in the second.

The first version uses a faster block architecture (Squeeze and Excitation); the research paper describing it is here.
I'm not sure about the exact differences between the two architectures and implementations, other than the building block used, but you could print both models and check for differences.
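For example, a minimal way to do that (a sketch; both constructors are the ones used in your snippets above, loaded without pretrained weights just for inspection):

```python
import pretrainedmodels
import torchvision.models as models

se_model = pretrainedmodels.se_resnext50_32x4d(pretrained=None)
tv_model = models.resnext50_32x4d(pretrained=False)

print(se_model)  # look for SE-specific submodules inside each block
print(tv_model)  # plain ResNeXt bottleneck blocks, no SE part
```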
Finally, here is a nice article summarizing what Squeeze-and-Excitation is. Basically, you do GlobalAveragePooling on all channels (in PyTorch it would be torch.nn.AdaptiveAvgPool2d(1) followed by a flatten), push the result through two linear layers (with a ReLU activation in-between) finished by a sigmoid, in order to get a weight for each channel. Finally, you multiply the channels by those weights.
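As a rough illustration, here is a minimal sketch of that idea (not the exact implementation from either library; the reduction ratio of 16 is just the value commonly used in the paper):

```python
import torch
import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)               # global average pooling per channel
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x):
        b, c, _, _ = x.shape
        w = self.pool(x).view(b, c)                       # flatten to (batch, channels)
        w = torch.relu(self.fc1(w))                       # two linear layers with ReLU in-between
        w = torch.sigmoid(self.fc2(w)).view(b, c, 1, 1)   # sigmoid -> one weight per channel
        return x * w                                      # re-weight the channels
```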
Additionally, you are doing something strange with the modules by transforming them into torch.nn.Sequential. There may be some logic in the forward call of the pretrained network that you are removing by copying the modules; that may play a part as well.
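To illustrate the general point with a toy module (hypothetical, not the actual SENet code): anything written inline in forward, such as pooling or flattening, is lost when you rebuild the network from its children:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        x = self.conv(x)
        x = x.mean(dim=(2, 3))   # pooling/flatten done inline in forward, not as a module
        return self.fc(x)

x = torch.randn(1, 3, 16, 16)
Toy()(x)                                # works
nn.Sequential(*Toy().children())(x)     # raises a shape error: the inline pooling is gone
```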