Semantic Segmentation runtime error at loss function


I am using a custom model for segmentation (SETRModel). The model's output shape is (nBatch, 256, 256), as the code below confirms (note that the channel dimension is squeezed out). The target has the same shape (it is a PILMask).

When I start training, I get a runtime error (see below) related to the loss function. What am I doing wrong?

```
size = 480
half= (256, 256) 
splitter = FuncSplitter(lambda o: Path(o).parent.name == 'validation')

dblock = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
               get_items=get_relevant_images,
               splitter=splitter,
               get_y=get_mask, 
               item_tfms=Resize((size,size)),
               batch_tfms=[*aug_transforms(size=half), Normalize.from_stats(*imagenet_stats)])

dls = dblock.dataloaders(path/'images', bs=4)

model = SETRModel(patch_size=(32, 32), 
            in_channels=3, 
            out_channels=1, 
            hidden_size=1024, 
            num_hidden_layers=8, 
            num_attention_heads=16, 
            decode_features=[512, 256, 128, 64])


# Create a Learner using a custom model
loss = nn.BCEWithLogitsLoss()
learn = Learner(dls, model, loss_func=loss, lr=1.0e-4, cbs=callbacks, metrics=[Dice()])


# Let's test and make sure the loss function is happy with its inputs
learn.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

t1 = torch.rand(4, 3, 256, 256).to(device)
print("input: " + str(t1.shape))

pred = learn.model(t1).to(device)
print("output: " + str(pred.shape))

# prints this:
# input: torch.Size([4, 3, 256, 256])
# output: torch.Size([4, 256, 256])

target = next(iter(learn.dls.train))[1]
target = target.type(torch.float32).to(device)
target.size(), pred.size()

# prints this:
# (torch.Size([4, 256, 256]), torch.Size([4, 256, 256]))

loss(pred, target)

# prints this:
# TensorMask(0.6844, device='cuda:0', grad_fn=<AliasBackward>)

# so, the loss function is happy with its inputs

learn.fine_tune(50)

# prints this:
# ---------------------------------------------------------------------------
# RuntimeError                              Traceback (most recent call last)
# <ipython-input-114-0e514c73651a> in <module>()
# ----> 1 learn.fine_tune(50)

# 19 frames
# /usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
#    2827 pixel_shuffle = _add_docstr(torch.pixel_shuffle, r"""
#    2828 Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a
# -> 2829 tensor of shape :math:`(*, C, H \times r, W \times r)`.
#    2830 
#    2831 See :class:`~torch.nn.PixelShuffle` for details.

# RuntimeError: result type Float can't be cast to the desired output type Long
```

Solution

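  • This happens because fastai's MaskBlock yields integer (Long) masks, while nn.BCEWithLogitsLoss expects a float target. Your manual check passed only because you cast the target to float32 by hand; during training, the Learner passes the Long mask straight to the loss (I believe this should be fixed in fastai).

    Just create a custom loss_func that casts the target to the dtype the underlying loss expects. For example, with cross-entropy, which needs a Long target and a model that emits one logit per class:

    def loss_func(out, targ): return CrossEntropyLossFlat()(out, targ.long())

    With BCEWithLogitsLoss and a single-logit output, as in the question, cast with targ.float() instead.

    Then pass it when creating the Learner (not the DataBlock, which has no loss_func argument):

    learn = Learner(dls, model, loss_func=loss_func, ...)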
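
To make the dtype issue concrete, here is a minimal sketch in plain PyTorch (no fastai). It assumes random stand-in tensors in place of the real model output and mask, with a smaller spatial size than the question's 256×256. The function names `bce_loss_float_target` and `ce_loss_long_target` are hypothetical and chosen only for illustration.

```python
import torch
import torch.nn as nn

# Single-logit case (as in the question): output and mask are both (N, H, W).
_bce = nn.BCEWithLogitsLoss()

def bce_loss_float_target(out, targ):
    """BCE-with-logits that casts an integer mask to float before the loss."""
    return _bce(out, targ.float())

# Multi-class case: logits are (N, C, H, W); the mask holds class indices, (N, H, W).
_ce = nn.CrossEntropyLoss()

def ce_loss_long_target(out, targ):
    """Cross-entropy that casts the mask to Long class indices before the loss."""
    return _ce(out, targ.long())

# Stand-in data (assumption: a binary mask stored as integers, like a segmentation mask).
torch.manual_seed(0)
logits_1ch = torch.randn(4, 16, 16)        # one logit per pixel
logits_2ch = torch.randn(4, 2, 16, 16)     # one logit per class per pixel
mask = torch.randint(0, 2, (4, 16, 16))    # integer mask, dtype torch.int64

print(mask.dtype)
print(bce_loss_float_target(logits_1ch, mask))
print(ce_loss_long_target(logits_2ch, mask))
```

Either function can then be passed as `loss_func=` when constructing the Learner. The BCE variant matches a model with `out_channels=1` and a squeezed channel dimension; the cross-entropy variant only applies if the model emits one channel per class.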