
How to delete tensors from the PyTorch graph?

I'm running object-detection predictions on images in a for loop. I actually ran into the same issue with TensorFlow and hoped I could solve it with PyTorch. At least I now seem to have found out what the issue is (naively assuming it's the same for TensorFlow).

I predict like this:

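    import cv2
    import numpy as np
    import torch
    from torchvision.models import detection
    from tqdm import tqdm

    # `train` and `train_img_paths` are defined elsewhere (not shown)
    model = detection.fasterrcnn_resnet50_fpn(pretrained=True)  # further arguments were cut off in the original post
    for i in tqdm(range(train.shape[0])):
        image = cv2.imread(train_img_paths[i])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.transpose((2, 0, 1))  # HWC -> CHW
        image = image / 255.0
        image = np.expand_dims(image, axis=0)
        image = torch.FloatTensor(image)
        # image = ...  (this line was cut off in the original post)
        predictions = model(image)[0]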

Using the garbage collector, I found that every single image stays in the graph. Is there a way to avoid this?
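A minimal sketch of what is likely going on, using a single hypothetical nn.Conv2d as a stand-in for the detector's first layer (not Faster R-CNN itself). In the default mode, autograd records a backward node for each operation whose parameters require gradients, and that node keeps the tensors it needs for the backward pass, including the input image. Under torch.no_grad() nothing is recorded, so the output holds no reference back to the input.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the first layer of a detection model
conv = nn.Conv2d(3, 4, kernel_size=3)
image = torch.rand(1, 3, 8, 8)  # plain input tensor; requires_grad is False

# Default mode: the op is recorded because conv.weight requires grad,
# and the backward node saves the input it needs to compute gradients.
out_tracked = conv(image)

# Inside no_grad nothing is recorded, so no graph references the input.
with torch.no_grad():
    out_untracked = conv(image)

print(out_tracked.requires_grad, out_tracked.grad_fn is not None)  # True True
print(out_untracked.requires_grad, out_untracked.grad_fn is None)  # False True
```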

I have not been able to use DataLoader or Dataset with the detection models (the same happened with TensorFlow Hub).
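For reference, a sketch of one way a Dataset/DataLoader pair could feed a torchvision detection model, under stated assumptions: ImageArrayDataset and collate_as_list are hypothetical names, and the images are taken to be in-memory HWC uint8 RGB arrays so the example is self-contained (in the question's setting, __getitem__ would read train_img_paths[idx] with cv2.imread and convert BGR to RGB). The key point is the custom collate function: torchvision detection models accept a list of CHW tensors, and the default collate would try to stack images of different sizes into one tensor and fail.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ImageArrayDataset(Dataset):
    """Hypothetical dataset over HWC uint8 RGB numpy arrays."""

    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # HWC -> CHW, made contiguous so torch.from_numpy can wrap it
        chw = np.ascontiguousarray(self.images[idx].transpose(2, 0, 1))
        return torch.from_numpy(chw).float() / 255.0  # scale to [0, 1]


def collate_as_list(batch):
    # Keep the batch as a list of (possibly differently sized) CHW tensors
    # instead of stacking it into a single 4D tensor.
    return list(batch)


images = [
    np.zeros((4, 5, 3), np.uint8),
    np.full((6, 7, 3), 255, np.uint8),
    np.zeros((4, 5, 3), np.uint8),
]
loader = DataLoader(ImageArrayDataset(images), batch_size=2, collate_fn=collate_as_list)
for batch in loader:
    print([tuple(t.shape) for t in batch])
# [(3, 4, 5), (3, 6, 7)]
# [(3, 4, 5)]
```

Each batch can then be passed directly as model(batch), since the detector takes a list of images.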


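  • Don't forget to turn off gradient tracking when you're doing testing; otherwise autograd records a graph for every forward pass, and that graph keeps references to the input tensors. You can do this by either wrapping your code like: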

    with torch.no_grad():
        out = model(x)

    or, if your code is a function, using a decorator to do the same thing:

    @torch.no_grad()  # no graph is recorded for any call to this function
    def model_proc(model, x):
        return model(x)
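Putting this together, the question's loop could be restructured roughly as below. This is a sketch under assumptions, not a tested drop-in: predict_all is a hypothetical helper, the images are HWC uint8 RGB arrays (as produced by cv2.imread followed by cv2.cvtColor), and copying each result to the CPU is an extra precaution beyond the advice above. It also calls model.eval(), which torchvision detection models need in order to return predictions instead of losses.

```python
import numpy as np
import torch


@torch.no_grad()  # no autograd graph is recorded for anything inside
def predict_all(model, images):
    """Hypothetical helper: run a torchvision-style detector on HWC uint8 RGB arrays.

    Assumes the model takes a list of CHW float tensors in [0, 1] and returns
    a list of dicts of tensors, one dict per input image.
    """
    model.eval()  # inference behaviour (torchvision detectors return predictions)
    device = next(model.parameters()).device
    results = []
    for image in images:
        chw = np.ascontiguousarray(image.transpose(2, 0, 1))  # HWC -> CHW
        tensor = torch.from_numpy(chw).float().div(255.0).to(device)
        output = model([tensor])[0]
        # Store CPU copies so the kept results do not hold device memory
        results.append({name: value.cpu() for name, value in output.items()})
    return results
```

Because the whole function runs under no_grad, none of the stored prediction tensors has a grad_fn, so nothing in the results list keeps earlier images alive.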