Tags: python, numpy, machine-learning, pytorch, segment-anything

segment_anything causing error with numpy.uint8


I am trying to run https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb locally on an M2 MacBook running macOS Sonoma 14.5, but I keep running into the following error at step 11:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[75], line 1
----> 1 masks = mask_generator.generate(image)

File ~/opt/anaconda3/envs/ve_env/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/opt/anaconda3/envs/ve_env/lib/python3.9/site-packages/segment_anything/automatic_mask_generator.py:163, in SamAutomaticMaskGenerator.generate(self, image)
    138 """
    139 Generates masks for the given image.
    140 
   (...)
    159          the mask, given in XYWH format.
    160 """
    162 # Generate masks
--> 163 mask_data = self._generate_masks(image)
    165 # Filter small disconnected regions and holes in masks
    166 if self.min_mask_region_area > 0:

File ~/opt/anaconda3/envs/ve_env/lib/python3.9/site-packages/segment_anything/automatic_mask_generator.py:206, in SamAutomaticMaskGenerator._generate_masks(self, image)
    204 data = MaskData()
    205 for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
--> 206     crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
    207     data.cat(crop_data)
    209 # Remove duplicate masks between crops

File ~/opt/anaconda3/envs/ve_env/lib/python3.9/site-packages/segment_anything/automatic_mask_generator.py:236, in SamAutomaticMaskGenerator._process_crop(self, image, crop_box, crop_layer_idx, orig_size)
    234 cropped_im = image[y0:y1, x0:x1, :]
    235 cropped_im_size = cropped_im.shape[:2]
--> 236 self.predictor.set_image(cropped_im)
    238 # Get points for this crop
    239 points_scale = np.array(cropped_im_size)[None, ::-1]

File ~/opt/anaconda3/envs/ve_env/lib/python3.9/site-packages/segment_anything/predictor.py:57, in SamPredictor.set_image(self, image, image_format)
     55 # Transform the image to the form expected by the model
     56 input_image = self.transform.apply_image(image)
---> 57 input_image_torch = torch.as_tensor(input_image, device=self.device)
     58 input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
     60 self.set_torch_image(input_image_torch, image.shape[:2])

RuntimeError: Could not infer dtype of numpy.uint8
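
The failing call is `torch.as_tensor(input_image, device=self.device)` inside `SamPredictor.set_image`. A minimal sketch like the one below reproduces that NumPy-to-torch conversion outside segment_anything, which helps separate an environment problem from a problem in the model code. The helper name `check_numpy_bridge` and the 4x4x3 test array are illustrative assumptions, and the check uses the CPU device only, so it says nothing about the MPS backend.

```python
from typing import Optional

import numpy as np
import torch


def check_numpy_bridge() -> Optional[torch.Tensor]:
    """Try the same conversion SamPredictor.set_image performs.

    Returns the tensor on success, or None if torch cannot convert
    the uint8 NumPy array (hypothetical helper for diagnosis only).
    """
    # Small HxWx3 uint8 array, shaped like an RGB image
    image = np.zeros((4, 4, 3), dtype=np.uint8)
    try:
        tensor = torch.as_tensor(image, device="cpu")
    except RuntimeError as exc:
        # In the broken setup this is expected to raise
        # "Could not infer dtype of numpy.uint8"
        print(f"NumPy -> torch conversion failed: {exc}")
        return None
    print(f"OK: dtype={tensor.dtype}, shape={tuple(tensor.shape)}")
    return tensor


if __name__ == "__main__":
    check_numpy_bridge()
```

If this fails in the notebook but succeeds when run with the same environment's `python` from a terminal, the notebook kernel is probably not using the interpreter or packages you expect.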

I am using a conda environment with Python 3.9.19 and have also tested with Python 3.11. Based on online comments I suspected a numpy version issue, but after trying multiple versions I have not found a working combination. My current versions are:

numpy==1.24.4
torch==1.9.0
torchvision==0.10.0
opencv-python==4.10.0.84

Running the same notebook on Google Colab works fine, and the versions indicated there are:

import numpy as np
import torch
import cv2

print(np.__version__)
print(torch.__version__)
print(cv2.__version__)

1.25.2
2.3.0+cu121
4.8.0

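This is using Python 3.10.12. I have not been able to reproduce this combination on my Mac (for example, 2.3.0+cu121 is a CUDA build of torch, which is not available for macOS), so I am stuck.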

How can I find out why numpy.uint8 is not being recognized, and how can I fix this error? Most online comments point to upgrading numpy, but I have tried several numpy versions without luck. Any help is appreciated.
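
One way to investigate is to print which interpreter and package builds the running kernel actually loaded, then compare the output with the same check run from a terminal after `conda activate ve_env` (the environment name is taken from the paths in the traceback). This is a diagnostic sketch; `environment_report` is a hypothetical helper, not part of numpy, torch, or segment_anything.

```python
import sys

import numpy as np
import torch


def environment_report() -> dict:
    """Collect which interpreter and package builds this kernel loaded."""
    return {
        # Python executable actually running this kernel
        "python": sys.executable,
        "python_version": sys.version.split()[0],
        # Import locations: a path outside the expected conda env
        # suggests the kernel is mixing packages from another environment
        "numpy_version": np.__version__,
        "numpy_path": np.__file__,
        "torch_version": torch.__version__,
        "torch_path": torch.__file__,
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key:>15}: {value}")
```

If `numpy_path` or `torch_path` in the notebook points outside `~/opt/anaconda3/envs/ve_env`, changing the numpy version inside `ve_env` may not affect what the notebook imports at all.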


Solution

  • For anyone else who runs into the same issue: the cause appears to be JetBrains PyCharm's Jupyter Notebook support, and I am filing a bug report with JetBrains. Running jupyter-notebook outside PyCharm showed that the expected version of numpy was being used, and the code works as expected there.
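
The answer does not say exactly what PyCharm's notebook support got wrong. Assuming the root cause was a kernel running from a different environment than intended, a guard cell at the top of the notebook can fail fast instead of surfacing later as a dtype error. The helper `assert_kernel_env` is hypothetical and assumes the standard conda layout, where an environment's prefix directory is named after the environment.

```python
import sys
from pathlib import Path


def assert_kernel_env(expected_env: str) -> Path:
    """Raise RuntimeError unless sys.prefix is the named conda env.

    Assumes the conda layout .../envs/<name>, so the last path
    component of sys.prefix is the environment name.
    """
    prefix = Path(sys.prefix)
    if prefix.name != expected_env:
        raise RuntimeError(
            f"Kernel is using {sys.executable}, not the '{expected_env}' env; "
            "check the notebook's interpreter/kernel setting."
        )
    return prefix
```

For example, calling `assert_kernel_env("ve_env")` as the first cell would raise immediately if the kernel were launched from another interpreter.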