python, pytorch, resnet

Extract features from pretrained resnet50 in pytorch


Hi guys, I want to extract the in_features of the fully connected layer of my pretrained resnet50.

I previously wrote a function that gives me the feature vector:

def get_vector(image):
    #layer = model._modules.get('fc')
    layer = model.fc
    my_embedding = torch.zeros(2048)  # 2048 is the in_features of FC, output of avgpool

    def copy_data(m, i, o):
        my_embedding.copy_(o.data)

    h = layer.register_forward_hook(copy_data)
    tmp = model(image)
    h.remove()

    # return the vector
    return my_embedding

Then I call it here:

column = ["FlickrID", "Features"]

path = "./train_dataset/train_imgs/"
pathCSV = "./train_dataset/features/img_info_TRAIN.csv"

f_id = []
features_extr = []

df = pd.DataFrame(columns=column)

tr = transforms.Compose([transforms.Resize(256),
                         transforms.CenterCrop(224),
                         transforms.ToTensor(),
                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

test = Dataset(path, pathCSV, transform=tr)

test_loader = DataLoader(test, batch_size=1, num_workers=2, shuffle=False)

# Read the images
for batch in test_loader:
    nome = batch['FlickrID']
    f_id.append(nome)
    image = batch['image']

    #print(image)
    with torch.no_grad():
        pred = get_vector(image)

    features_extr.append(pred)

df["FlickrID"] = f_id
df["Features"] = features_extr

df.to_hdf("Places.h5", key='df', mode='w')

I get this error: output with shape [2048] doesn't match the broadcast shape [1, 2048, 1, 2048]

How can I get the in_features (the input) of the fully connected layer of this resnet50? The Dataset is a custom Dataset class.

Sorry for my bad English.
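The leading dimension in these shapes comes from the DataLoader: its default collate function stacks each field of the samples into a batch, so even with batch_size=1 the image that reaches the model is [1, 3, 224, 224] rather than [3, 224, 224]. The real custom Dataset is not shown, so this is only a minimal sketch, assuming a hypothetical ToyDataset whose items mimic the dict structure used in the loop above (a "FlickrID" string and an "image" tensor):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for the custom Dataset: each item is a dict
    with a string ID and one 3x224x224 image tensor (no batch dimension)."""

    def __init__(self, n=3):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return {"FlickrID": f"id_{idx}", "image": torch.randn(3, 224, 224)}


loader = DataLoader(ToyDataset(), batch_size=1, shuffle=False)
batch = next(iter(loader))

# The default collate function stacks tensors along a new dim 0 and
# gathers strings into a list, even when batch_size=1.
print(batch["image"].shape)  # torch.Size([1, 3, 224, 224])
print(batch["FlickrID"])     # ['id_0']
```

The same collation is why nome in the loop is a one-element list rather than a plain string.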


Solution

  • The model takes batched inputs, which means the input to the fully connected layer has size [batch_size, 2048]. Because you are using a batch size of 1, that becomes [1, 2048], which doesn't fit into the tensor torch.zeros(2048); it should be torch.zeros(1, 2048) instead.

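    You are also trying to use the output (o) of the layer model.fc instead of its input (i). Note that a forward hook receives the input as a tuple of the layer's positional arguments, so the features are i[0].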

    Besides that, hooks are more complicated than needed here. A much simpler way to get the features is to replace model.fc with nn.Identity, which returns its input unchanged. Since the features are the input to model.fc, the output of the entire model then is the feature vector.

    from torch import nn

    model.fc = nn.Identity()
    
    features = model(image)