Tags: python, tensorflow, opencv, object-detection

Print Objects from TensorFlow Object Detection API


I'm trying to return a list of the objects found by the detector, or at least their names.

My code:

while True: 

    ret, frame = cap.read()
    image_np = np.array(frame)
    
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
    detections = detect_fn(input_tensor)
    
    num_detections = int(detections.pop('num_detections'))
    detections = {key: value[0, :num_detections].numpy()
                  for key, value in detections.items()}
    detections['num_detections'] = num_detections

    # detection_classes should be ints.
    detections['detection_classes'] = detections['detection_classes'].astype(np.int64)

    label_id_offset = 1
    image_np_with_detections = image_np.copy()

    viz_utils.visualize_boxes_and_labels_on_image_array(
                image_np_with_detections,
                detections['detection_boxes'],
                detections['detection_classes']+label_id_offset,
                detections['detection_scores'],
                category_index,
                use_normalized_coordinates=True,
                max_boxes_to_draw=1,
                min_score_thresh=.85,
                agnostic_mode=False)
    classes=detections['detection_classes'].astype(np.int64)
    scores=detections['detection_scores']
    
    #label_names = [i[0] for i in category_index.items()]
    #label_names = np.array(label_names)
    #print(label_names[detections['detection_classes']])
    
    
    
    print ([category_index.get(value) for index,value in enumerate(classes[0]) if scores[0,index] > 0.8])
    cv2.imshow('object detection',  cv2.resize(image_np_with_detections, (800, 600)))

    
    if cv2.waitKey(1) & 0xFF == ord('q'):
        cap.release()
        break

Running it raises the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-12-a749c9f2a4b4> in <module>
     37 
     38 
---> 39     print ([category_index.get(value) for index,value in enumerate(classes[0]) if scores[0,index] > 0.8])
     40     cv2.imshow('object detection',  cv2.resize(image_np_with_detections, (800, 600)))
     41 

TypeError: 'numpy.int64' object is not iterable

Solution

  • You probably want something like this:

    print([category_index.get(class_ + label_id_offset) for class_, score in zip(classes, scores) if score > 0.8])
    

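    The error occurs because the dictionary comprehension near the top of the loop (value[0, :num_detections]) has already removed the batch dimension, so classes and scores are 1-D arrays. classes[0] is therefore a single numpy.int64, and enumerate() cannot iterate over a scalar. Indexing scores[0, index] would fail for the same reason. Iterating over classes and scores together with zip() avoids both problems. Also, because the visualization call adds label_id_offset to the class IDs before looking them up in category_index, the same offset is most likely needed when printing; otherwise every name will be off by one.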

    As a side note, the variable is named class_ because class is a reserved keyword in Python. PEP 8 recommends appending a single trailing underscore to avoid such a collision: https://www.python.org/dev/peps/pep-0008/#function-and-method-arguments
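To see why classes[0] is a scalar, and to pull out just the names, here is a minimal NumPy-only sketch. The detection values and category_index below are made-up stand-ins for the real detect_fn output and label map, and detected_names is a hypothetical helper, not part of the Object Detection API. It assumes, as the question's label_id_offset = 1 suggests, that the model returns 0-based class IDs.

```python
import numpy as np

# Hypothetical stand-ins for the real label map and model output.
# In the question these come from the label map and detect_fn.
category_index = {
    1: {"id": 1, "name": "person"},
    2: {"id": 2, "name": "cat"},
    3: {"id": 3, "name": "dog"},
}
raw = {
    "num_detections": np.array([3.0]),
    "detection_classes": np.array([[0.0, 2.0, 1.0]]),  # batch of 1, 0-based IDs
    "detection_scores": np.array([[0.95, 0.40, 0.88]]),
}

# Same post-processing as the question (minus .numpy()): strip the batch axis.
num_detections = int(raw.pop("num_detections")[0])
detections = {key: value[0, :num_detections] for key, value in raw.items()}
classes = detections["detection_classes"].astype(np.int64)
scores = detections["detection_scores"]

print(classes.ndim)  # the batch axis is gone, so this is 1

# Reproduce the original error: classes[0] is a single numpy.int64.
try:
    enumerate(classes[0])
except TypeError as err:
    print("TypeError:", err)


def detected_names(classes, scores, category_index, threshold=0.8, label_id_offset=1):
    """Return label names for detections whose score exceeds threshold.

    label_id_offset mirrors the offset added before visualization in the
    question; pass 0 if the model already returns 1-based class IDs.
    """
    names = []
    for class_id, score in zip(classes, scores):
        if score > threshold:
            entry = category_index.get(int(class_id) + label_id_offset)
            # Fall back to "N/A" when the ID is missing from the label map.
            names.append(entry["name"] if entry else "N/A")
    return names


print(detected_names(classes, scores, category_index))  # ['person', 'cat']
```

In the real loop you would call detected_names(classes, scores, category_index) right after building classes and scores, and print or collect the result per frame.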