Search code examples
pythonperformanceimage-processingoptimizationpyqtgraph

Improving performance of code under real time data acquisition and understanding weird bug


I'm working with a physical experimental system to track its evolution in real time using a camera. Specifically, I trigger the camera to retrieve images in a rate of 2.5Hz, and so the efficiency of the system is crucial in order to keep up with the experiment.

My current program takes the first image in a dedicated folder, path, and asks the user to select a region of interest on which it performs the desired operation. Next, it calculates the mean brightness per pixel of an acquired image, and of a batch of n images, img_round.

The mean brightness per image batch is plotted in real time with respect to the number of iteration of img_round.

Currently, when I run the program on static data it works great. However, when I attempt to run it in the expected experimental setting of processing where the images are actively added to the folder, I get false values for the brightness in my plot.

In general, I'm afraid my code isn't as efficient as it could be, and I would like to optimize it as much as possible.

Please find the code below:

import os
import cv2
import numpy as np
import pyqtgraph as pg
from scipy.optimize import minimize_scalar
from pyqtgraph.Qt import QtCore, QtGui, QtWidgets
import time 
 
center = (0, 0)
radius = (0)
is_dragging_center = False
is_dragging_radius = False
global avg_brightness_per_img_round
avg_brightness_per_img_round = 0
 
img_round = 5
run_count = 0
 
brightness_history = []
std_history = []
func_history = []
global scatter_item
scatter_item = None
 
def update_display_image():
    global resized_image
    if resized_image is not None:
        display_image = resized_image.copy()
        cv2.circle(display_image, center, radius, (0, 255, 0), 2)
        cv2.circle(display_image, center, 5, (0, 0, 255), thickness=cv2.FILLED)
 
class UpdateDisplaySignal(QtCore.QObject):
    update_display_signal = QtCore.pyqtSignal()
 
update_display_signal_obj = UpdateDisplaySignal()
update_display_signal_obj.update_display_signal.connect(update_display_image)
 
 
def on_mouse(event, x, y, flags, param):
    global center, radius, is_dragging_center, is_dragging_radius
 
    if event == cv2.EVENT_LBUTTONDOWN:
        if np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2) < 20:
            is_dragging_center = True
        else:
            is_dragging_radius = True
 
    elif event == cv2.EVENT_LBUTTONUP:
        is_dragging_center = False
        is_dragging_radius = False
 
    elif event == cv2.EVENT_MOUSEMOVE:
        if is_dragging_center:
            center = (x, y)
        elif is_dragging_radius:
            radius = int(np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2))
 
app = QtWidgets.QApplication([])
pw = pg.PlotWidget(title='Mean Brightness vs image round')
pw.setLabel('left', 'Mean Brightness')
pw.setLabel('bottom', 'Image round')
scatter = pg.ScatterPlotItem(size=10, pen=pg.mkPen(None), brush=pg.mkBrush(255, 0, 0, 120))
line = pg.PlotDataItem(pen=pg.mkPen(color=(0,0,255), width=2))
pw.addItem(line)
pw.addItem(scatter)
 
def update_scatter():
    global scatter_item
    indices, values = zip(*enumerate(brightness_history, start=1))
    x = list(indices)
    y = list(values)
 
    if scatter_item is None:
        scatter_item = pg.ScatterPlotItem(size=10, pen=pg.mkPen(None), brush=pg.mkBrush(255, 0, 0, 120))
        pw.addItem(scatter_item)
 
    if isinstance(x, int):  
        x = [x]  
 
    if len(x) > 1:
        line.setData(x=x, y=y)
        scatter.setData(x=x, y=y, symbol='o', size=10, pen=pg.mkPen(None), brush=pg.mkBrush(255, 0, 0, 120))
 
        for i, (xi, yi) in enumerate(zip(x, y)):
            label = pg.TextItem(text=f'{yi:.2f}', anchor=(0, 0))
            label.setPos(xi, yi)
            pw.addItem(label)
            
win = QtWidgets.QMainWindow()
win.setCentralWidget(pw)
win.show()
 
path = r'C:\Users\blehe\Desktop\Betatron\images'
 
def calc_xray_count(image_path, center, radius):
    original_image = cv2.imread(image_path, cv2.IMREAD_ANYDEPTH)
 
    median_filtered_image = cv2.medianBlur(original_image, 5)
 
    mask = np.zeros(original_image.shape, dtype=np.uint8)
    cv2.circle(mask, center, radius, 255, thickness=cv2.FILLED)
 
    median_filtered_image += 1  # Avoid not counting black pixels in image
    result = cv2.bitwise_and(median_filtered_image, median_filtered_image, mask=mask)
 
    pixel_count = np.count_nonzero(result)
 
    img_brightness_sum = np.sum(result)
    img_var = np.var(result)
 
    if (pixel_count > 0):
        img_avg_brightness = (img_brightness_sum/pixel_count) -1 # Subtract back to real data
    else:
        img_avg_brightness = 0
 
    return img_avg_brightness, img_var
 
#-----------------------------------------------------------------------

image_files = []
for filename in os.listdir(path):
    if filename.endswith('.TIF'):
        image_files.append(os.path.join(path, filename))
 
first_image_path = image_files[0]
image = cv2.imread(first_image_path)
 
scale_percent = 60 
width = int(image.shape[1] * scale_percent / 100)
height = int(image.shape[0] * scale_percent / 100)
dim = (width, height)
gray_img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
colored_image = cv2.applyColorMap(gray_img, cv2.COLORMAP_PINK)
resized_image = cv2.resize(colored_image, dim, interpolation=cv2.INTER_AREA)
 
center = (resized_image.shape[1] // 2, resized_image.shape[0] // 2)
radius = min(resized_image.shape[1] // 3, resized_image.shape[0] // 3)
 
cv2.namedWindow("Adjust the circle (press 'Enter' to proceed)")
cv2.setMouseCallback("Adjust the circle (press 'Enter' to proceed)", on_mouse)
 
while True:
    display_image = resized_image.copy()
 
    cv2.circle(display_image, center, radius, (0, 255, 0), 2)
    cv2.circle(display_image, center, 5, (0, 0, 255), thickness=cv2.FILLED)
    cv2.imshow("Adjust the circle (press 'Enter' to proceed)", display_image)
 
    key = cv2.waitKey(1) & 0xFF
    if key == 13: 
        break
 
cv2.destroyAllWindows()
            
center = (int(center[0] / scale_percent * 100), int(center[1] / scale_percent * 100))
radius = int(radius / scale_percent * 100)
 
img_round_brightness_sum = 0
img_round_var_sum = 0
 
def process_images():
    global run_count, img_round_brightness_sum, img_round_var_sum
 
    while run_count < len(os.listdir(path)):   
        for i, image_path in enumerate(image_files, start=1):
            img_avg_brightness, img_var = calc_xray_count(image_path, center, radius)
            img_round_brightness_sum += img_avg_brightness
            img_round_var_sum += img_var
 
            run_count += 1
 
            if run_count % img_round == 0:
                avg_brightness_per_img_round = (img_round_brightness_sum/img_round)
                deviation_per_img_round = np.sqrt(img_round_var_sum/img_round)
 
                brightness_history.append(avg_brightness_per_img_round)
                std_history.append(deviation_per_img_round)
 
                update_scatter()

                img_round_brightness_sum = 0
                img_round_var_sum = 0
 
                img_avg_brightness = 0
                img_var = 0
 
                QtCore.QCoreApplication.processEvents()
                QtCore.QThread.msleep(100)
 
if __name__ == "__main__":
    timer = QtCore.QTimer() 
    timer.timeout.connect(process_images)
    timer.start(100)  
    app.exec_()

I would really appreciate any help. Thank you kindly.


Solution

  • it's really difficult to understand what you're trying to achieve they way you have written your code, so please clean it up. Here is an example and some suggestions for improvements but I quickly gave up on refactoring the entire thing. You can apply a LOT of cleanup and refactoring to this code, for example by breaking your functions into smaller units, and then chaining those functions. Using a class to save the state of your processor instead of relying on global variables everywhere, etc.

    import numpy as np
    import matplotlib.pyplot as plt
    import cv2
    
    class Processor:
        def __init__(self,img):
            self.img = img # Self.img can act as your image buffer
            self.center = (75,75) # I hard coded values but you can calculate and initialize center here
            self.radius = 20
            self.is_dragging_center = False # Initialize all of your global variables here
            self.is_dragging = False
            self.dragging_center = False
            
        def draw_circles(self, img): #Either explicitly pass the image or use self.img
            cv2.circle(img, self.center, self.radius, (0, 255, 0), 2)
            cv2.circle(img, self.center, 5, (0, 0, 255), thickness=cv2.FILLED)
        
        def resize_img(self,img, scale_percent=60):
            width = img.shape[1] * scale_percent // 100
            height = img.shape[0] * scale_percent // 100
            return cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
        
        def transform_color(self, img, cma=cv2.COLORMAP_PINK):
            return cv2.applyColorMap(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), cmap)
        
        def on_mouse(self, event, x, y, flags, param):
            cx,cy = self.center
            if event == cv2.EVENT_LBUTTONDOWN:
                # <20 should probably not be hard coded?
                # if you want performance do you actually care about dragging in a perfectly circular radius?
                # For example using "if abs(x-cx) + abs(y-cy) < 20" creates a diamond pattern but it's 3 to 4x faster
                # you can also simplify this to ((x - cx) ** 2 + (y - cy) ** 2) < 400
                # "is_dragging_radius" looks to be : "not is_dragging_center", do you really need two variables here?          
    #             self.is_dragging = True
    #             self.dragging_center = ((x - cx) ** 2 + (y - cy) ** 2) < 400
                if np.sqrt((x - cx) ** 2 + (y - cy) ** 2) < 20:
                    self.is_dragging_center = True
                else:
                    self.is_dragging_radius = True
            # All of the "elif"s can be simplified to just "if"s
            elif event == cv2.EVENT_LBUTTONUP:
    #             self.is_dragging = False
                self.is_dragging_center = False
                self.is_dragging_radius = False
    
    #         if event == cv2.EVENT_MOUSEMOVE and self.is_dragging:
    #             if dragging_center:
    #                 center = (x, y)
    #             else:
    #                 radius = int(np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2))
            elif event == cv2.EVENT_MOUSEMOVE:
                if is_dragging_center:
                    center = (x, y)
                elif is_dragging_radius:
                    radius = int(np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2))
                    
    p = Processor(np.zeros((150,150,3), dtype=np.uint8))
    p.draw_circles(p.img)
    plt.figure()
    plt.imshow(p.img)
    plt.show()
    

    regarding the performance issues and how to process the images while saving them to disk, you can do something like:

    from concurrent.futures import ThreadPoolExecutor as TPE, wait
    # below imports are only used in the second example
    from collections import deque
    from threading import Thread, Event
    # below imports are just used for this example, you can ignore them
    import time
    from random import random, choice
    from string import ascii_letters, digits
    from threading import Lock
    
    def generator(max_delay): # Acts as the data generator, i.e: a camera in your case
        alphabet = list(ascii_letters)+list(digits)
        while True:
            yield ''.join(choice(alphabet) for i in range(20))
            time.sleep(random()*max_delay)
            
            
    def save(x, lock):
        time.sleep(random()*0.5)
        with lock:
            print(f'saved {x}')
    
    def process(x, lock):
        time.sleep(random()*0.8)
        with lock:
            print(f'Processed {x} to {x[:10].swapcase()}')
            
    def producer(generator, queue, event, lock):
        for x in generator:
            if event.is_set():
                break
            queue.append(x)
        with lock:
            print(f'{"*"*20} Done prodcuing. Shutting down. {"*"*20}')
    
    
    lock = Lock() # Lock is used to make sure multiple threads don't print at the same time
    ############################## 
    ##############################
    ###     FIRST EXAMPLE      ###
    ##############################
    ##############################
    # In this example we don't mind if we miss captured frames since we don't have a large enough memory to store
    # all of the incoming data and our processing might be too slow to process in real time
    g = generator(0.3)
    with TPE() as executor:
        start = time.time()
        while (time.time()-start)<5: # Run experiment for 5 seconds
            x = next(g) # Get the next input form "camera"
            # Save the image and process it in parallel (you may want to pass a copy to be processed)
            futures = [executor.submit(save,x,lock), executor.submit(process,x,lock)]
            # Wait for saving and processing to finish before processing the next frame
            wait(futures)
            print('#'*50)
        print('Done processing and saving. Shutting down.')
    
    
    print('_'*100)
    print('_'*100)
    print('_'*100)
    print('_'*100)
    ############################## 
    ##############################
    ###     SECOND EXAMPLE     ###
    ##############################
    ##############################
    # This second example deals with a case that your input rate is higher than your processing throughput
    # If you can process everything in real time OR if you don't mind missing frames, then you can ignore this part.
    # This is only advised if you have enough ram available to hold all the captured data
    
    
    q = deque()
    shutdown_event = Event()
    t = Thread(target=producer, args=(generator(0.1), q, shutdown_event, lock))
    t.start()
    
    with TPE() as executor:
        start = time.time()
        while (time.time()-start)<5: # Run experiment for 5 seconds
            # If queue is empty, wait for it to be populated (in case that input rate is lower than processing throughput)
            while len(q)==0:
                pass
            x = q.popleft() # Get the next "frame" from the queue
            # Save the image and process it in parallel (you may want to pass a copy to be processed)
            futures = [executor.submit(save,x,lock), executor.submit(process,x,lock)]
            # Wait for saving and processing to finish before processing the next frame
            wait(futures)
            print('#'*50)
        # now that your experiment is done, shutdown the producer
        shutdown_event.set()
        # Keep processing until the queue is empty
        while len(q)!=0:
            x = q.popleft() # Get the next input form "camera"
            # Save the image and process it in parallel (you may want to pass a copy to be processed)
            futures = [executor.submit(save,x,lock), executor.submit(process,x,lock)]
            # Wait for saving and processing to finish before processing the next frame
            wait(futures)
            print('#'*50)
        print('Done processing and saving. Shutting down.')
    t.join()
    

    note that the second example continues processing even when the experiment is over and you have stopped capturing frames. It also outputs a lot more data since no input is missed (and the max_delay is 0.1 instead of 0.3 seconds).