Tags: python, pytorch, image-segmentation, detectron, coco

Error Training Custom COCO Dataset with Detectron2


I'm trying to train Detectron2 on a custom COCO-format dataset with PyTorch. My datasets are JSON files in COCO format, and each item in the "annotations" section looks like this:

[annotation sample image]
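For reference, a standard COCO annotation entry has this shape (the values below are illustrative, not taken from my actual file):

{
    "id": 1,
    "image_id": 4,
    "category_id": 3,
    "bbox": [120.0, 85.0, 230.0, 170.0],
    "area": 39100.0,
    "segmentation": [[120.0, 85.0, 350.0, 85.0, 350.0, 255.0, 120.0, 255.0]],
    "iscrowd": 0
}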

The code for setting up Detectron2 and registering the training & validation datasets is as follows:

from detectron2.data.datasets import register_coco_instances
for d in ["train", "validation"]:
    register_coco_instances(f"segmentation_{d}", {}, f"/content/drive/MyDrive/Segmentation Annotations/{d}.json", f"/content/drive/MyDrive/Segmentation Annotations/imgs")

import os

from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("segmentation_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025  # small learning rate suited to fine-tuning
cfg.SOLVER.MAX_ITER = 1000    # total training iterations
cfg.SOLVER.STEPS = []         # no learning-rate decay
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 20  # number of classes in the custom dataset

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()

However, when I run the training, I get the following error after the first iteration:

KeyError                                  Traceback (most recent call last)
<ipython-input-12-2aaec108c313> in <module>()
     17 trainer = DefaultTrainer(cfg)
     18 trainer.resume_or_load(resume=False)
---> 19 trainer.train()

8 frames
/usr/local/lib/python3.7/dist-packages/detectron2/engine/defaults.py in train(self)
    482             OrderedDict of results, if evaluation is enabled. Otherwise None.
    483         """
--> 484         super().train(self.start_iter, self.max_iter)
    485         if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
    486             assert hasattr(

/usr/local/lib/python3.7/dist-packages/detectron2/engine/train_loop.py in train(self, start_iter, max_iter)
    147                 for self.iter in range(start_iter, max_iter):
    148                     self.before_step()
--> 149                     self.run_step()
    150                     self.after_step()
    151                 # self.iter == max_iter can be used by `after_train` to

/usr/local/lib/python3.7/dist-packages/detectron2/engine/defaults.py in run_step(self)
    492     def run_step(self):
    493         self._trainer.iter = self.iter
--> 494         self._trainer.run_step()
    495 
    496     @classmethod

/usr/local/lib/python3.7/dist-packages/detectron2/engine/train_loop.py in run_step(self)
    265         If you want to do something with the data, you can wrap the dataloader.
    266         """
--> 267         data = next(self._data_loader_iter)
    268         data_time = time.perf_counter() - start
    269 

/usr/local/lib/python3.7/dist-packages/detectron2/data/common.py in __iter__(self)
    232 
    233     def __iter__(self):
--> 234         for d in self.dataset:
    235             w, h = d["width"], d["height"]
    236             bucket_id = 0 if w > h else 1

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
--> 521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
   1181             if len(self._task_info[self._rcvd_idx]) == 2:
   1182                 data = self._task_info.pop(self._rcvd_idx)[1]
-> 1183                 return self._process_data(data)
   1184 
   1185             assert not self._shutdown and self._tasks_outstanding > 0

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1227         self._try_put_index()
   1228         if isinstance(data, ExceptionWrapper):
-> 1229             data.reraise()
   1230         return data
   1231 

/usr/local/lib/python3.7/dist-packages/torch/_utils.py in reraise(self)
    423             # have message field
    424             raise self.exc_type(message=msg)
--> 425         raise self.exc_type(msg)
    426 
    427 

KeyError: Caught KeyError in DataLoader worker process 1.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 28, in fetch
    data.append(next(self.dataset_iter))
  File "/usr/local/lib/python3.7/dist-packages/detectron2/data/common.py", line 201, in __iter__
    yield self.dataset[idx]
  File "/usr/local/lib/python3.7/dist-packages/detectron2/data/common.py", line 90, in __getitem__
    data = self._map_func(self._dataset[cur_idx])
  File "/usr/local/lib/python3.7/dist-packages/detectron2/utils/serialize.py", line 26, in __call__
    return self._obj(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/detectron2/data/dataset_mapper.py", line 189, in __call__
    self._transform_annotations(dataset_dict, transforms, image_shape)
  File "/usr/local/lib/python3.7/dist-packages/detectron2/data/dataset_mapper.py", line 132, in _transform_annotations
    annos, image_shape, mask_format=self.instance_mask_format
  File "/usr/local/lib/python3.7/dist-packages/detectron2/data/detection_utils.py", line 400, in annotations_to_instances
    segms = [obj["segmentation"] for obj in annos]
  File "/usr/local/lib/python3.7/dist-packages/detectron2/data/detection_utils.py", line 400, in <listcomp>
    segms = [obj["segmentation"] for obj in annos]
KeyError: 'segmentation'

Does anyone have an idea why this might be happening and, if so, what can be done to fix it? Any input is appreciated.

Thanks!


Solution

  • It's difficult to give a concrete answer without looking at the full annotation file, but a KeyError is raised when code tries to access a key that is not in a dictionary. From the traceback you've posted, the missing key is 'segmentation', which suggests that at least one annotation in your dataset lacks that field.

    This is not in your code snippet, but before even getting into network training, have you done any exploration or inspection of the registered datasets? Basic inspection exposes problems with a dataset early in your development process, as opposed to letting the trainer catch them, in which case the error messages can get long and confounding. A quick way to do this is sketched right below.
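    For example, a visual sanity check (a sketch following the standard Detectron2 tutorial pattern; cv2_imshow assumes you're running on Colab) is to draw a few registered samples with their annotations:

    # Visually inspect a few random samples from the registered training set
    import random
    import cv2
    from detectron2.data import DatasetCatalog, MetadataCatalog
    from detectron2.utils.visualizer import Visualizer
    from google.colab.patches import cv2_imshow

    dataset_dicts_train = DatasetCatalog.get('segmentation_train')
    metadata_train = MetadataCatalog.get('segmentation_train')

    for d in random.sample(dataset_dicts_train, 3):
        img = cv2.imread(d["file_name"])
        visualizer = Visualizer(img[:, :, ::-1], metadata=metadata_train, scale=0.5)
        out = visualizer.draw_dataset_dict(d)
        cv2_imshow(out.get_image()[:, :, ::-1])

    If the drawn boxes and masks look wrong (or the drawing itself errors out), the problem is in the data, not the training loop.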

    In any case, for your specific issue, you can take the registered training dataset and check whether all annotations have the 'segmentation' field. A simple code snippet to do this is below.

    # Register datasets
    from detectron2.data.datasets import register_coco_instances
    for d in ["train", "validation"]:
        register_coco_instances(f"segmentation_{d}", {}, f"/content/drive/MyDrive/Segmentation Annotations/{d}.json", f"/content/drive/MyDrive/Segmentation Annotations/imgs")
    
    # Check if all annotations in the registered training set have the segmentation field
    from detectron2.data import DatasetCatalog
    
    dataset_dicts_train = DatasetCatalog.get('segmentation_train')
    
    for d in dataset_dicts_train:
        for obj in d['annotations']:
            if 'segmentation' not in obj:
                print(f'{d["file_name"]} has an annotation with no segmentation field')
    

    It would be strange if some images had annotations with no 'segmentation' field, but if they do, it points to a problem in your upstream annotation process.
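    If the check above does flag annotations, a stopgap (a minimal sketch; the cleaned-file name is just an example, and it assumes your JSON follows the standard top-level COCO layout) is to drop the offending annotations before registering:

    # Write a cleaned copy of the train json without segmentation-less annotations
    import json

    src = "/content/drive/MyDrive/Segmentation Annotations/train.json"
    dst = "/content/drive/MyDrive/Segmentation Annotations/train_clean.json"  # example name

    with open(src) as f:
        coco = json.load(f)

    n_before = len(coco["annotations"])
    coco["annotations"] = [a for a in coco["annotations"] if "segmentation" in a]
    print(f"dropped {n_before - len(coco['annotations'])} annotations without a segmentation field")

    with open(dst, "w") as f:
        json.dump(coco, f)

    The real fix, though, is to regenerate the annotations so that every instance has a valid segmentation.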