Search code examples
pytorch

PyTorch: how to use torchvision.transforms.AugMix with torch.float32?


PyTorch: how to use torchvision.transforms.AugMix with torch.float32?

I am trying to apply data augmentation to an image dataset using torchvision.transforms.AugMix, but I get the following error: TypeError: Only torch.uint8 image tensors are supported, but found torch.float32. I tried converting the tensor to int, but then I got a different error.


Here is my code, where I am trying to use AugMix:

# Preprocessing pipeline applied to every patch image, exactly as posted in the question.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),  # resize to 224x224
        torchvision.transforms.ToTensor(),  # image -> tensor (torch.float32, per the error quoted above)
        torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # per-channel (mean, std) normalization
        # NOTE: AugMix runs last here, so it receives the float32 tensor produced above and raises
        # "TypeError: Only torch.uint8 image tensors are supported, but found torch.float32".
        torchvision.transforms.AugMix()
    ]
)
to_tensor = torchvision.transforms.ToTensor()  # standalone converter; not used in the code shown here
Image.MAX_IMAGE_PIXELS = None  # turn off Pillow's pixel-count safety limit so very large images can be opened


class BreastDataset(torch.utils.data.Dataset):
    """Dataset of patch "bags" listed in a JSON index file.

    Every entry of the JSON index must provide the keys ``"label"``, ``"id"``
    and ``"patch_paths"``; the patch paths are joined onto ``data_dir_path``
    before the images are opened.
    """

    def __init__(self, json_path, data_dir_path='./dataset', clinical_data_path=None, is_preloading=True):
        """Read the JSON index and, if requested, preload every bag tensor.

        Args:
            json_path: Path of the JSON file that lists the samples.
            data_dir_path: Directory the entries' patch paths are relative to.
            clinical_data_path: Accepted for interface compatibility; not
                used by this class.
            is_preloading: If true, every bag is loaded (and transformed)
                once here and ``__getitem__`` only indexes into
                ``bag_tensor_list``. Because the transform then runs a single
                time per bag, pass ``False`` when random augmentations should
                be re-sampled on every access.
        """
        self.data_dir_path = data_dir_path
        self.is_preloading = is_preloading

        with open(json_path) as f:
            print(f"load data from {json_path}")
            self.json_data = json.load(f)

        # __getitem__ reads self.bag_tensor_list in preloading mode, so it has
        # to be populated here; otherwise the default configuration fails with
        # an AttributeError on first access.
        if self.is_preloading:
            self.bag_tensor_list = [
                self.load_bag_tensor(self._full_paths(item["patch_paths"]))
                for item in self.json_data
            ]

    def __len__(self):
        """Return the number of entries in the JSON index."""
        return len(self.json_data)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict.

        The dict holds ``"bag_tensor"``, ``"label"`` (converted to ``int``),
        ``"patient_id"`` and ``"patch_paths"`` (as stored in the JSON index).
        """
        item = self.json_data[index]
        label = int(item["label"])
        patient_id = item["id"]
        patch_paths = item["patch_paths"]

        if self.is_preloading:
            bag_tensor = self.bag_tensor_list[index]
        else:
            bag_tensor = self.load_bag_tensor(self._full_paths(patch_paths))

        return {
            "bag_tensor": bag_tensor,
            "label": label,
            "patient_id": patient_id,
            "patch_paths": patch_paths,
        }

    def _full_paths(self, patch_paths):
        """Join each relative patch path onto ``data_dir_path``."""
        return [os.path.join(self.data_dir_path, p_path) for p_path in patch_paths]

    def load_bag_tensor(self, patch_paths):
        """Load a bag data as tensor with shape [N, C, H, W]"""
        patch_tensors = []
        for p_path in patch_paths:
            # The context manager closes the underlying file once the pixel
            # data has been converted, instead of leaving it to the GC.
            with Image.open(p_path) as image:
                patch_tensors.append(transform(image.convert("RGB")))  # [C, H, W]

        # Stacking along a new leading axis yields [N, C, H, W].
        return torch.stack(patch_tensors, dim=0)

Any help is appreciated! Thank you in advance!


Solution

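  • For me, applying AugMix first and then ToTensor() worked. As the error message says, AugMix only supports torch.uint8 image tensors, so it has to run before ToTensor() (and Normalize) turn the image into a torch.float32 tensor: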

    # AugMix comes first so that it sees the image before ToTensor() converts
    # it to a float tensor; the remaining transforms run after that conversion.
    transformation = transforms.Compose(
        [
            transforms.AugMix(severity=6, mixture_width=2),
            transforms.ToTensor(),
            transforms.RandomErasing(),
            transforms.RandomGrayscale(p=0.35),
        ]
    )