Search code examples
pythonmultithreadingpyqtthread-safetyqthread

Communicating properly between threads in a QApplication


This is a minimal reproducible example of a much larger application where a plot is generated from data coming in from a separate thread. Here I just generate some dummy data, but the simulator and value reader are both much more complex in the actual code. The user should be able to start/stop the plotting (interrupt the data stream temporarily), but the data can come in with no end. I communicate between threads using a signal class. The signal class is initialized in the main window and is then passed to all other classes to establish communication. There are different modes that can be switched between to generate different behavior in the long-running portion. This appears to work; however, I am worried about a couple of things:

  1. If you look at the logging info, the thread receiving signals inside of the ThreadWorker class is the "main thread". Is there an issue with the signals being received on the main thread inside of the ThreadWorker class? Is there a more appropriate way to set up this communication? (I don't want to have each class in the application send their own signals. Having a signals class is much easier to track.)
  2. If there is a third thread, where decisions are made based on information from the second thread, how should communication be handled so that changes aren't made based on stale data? I can easily imagine a scenario where the third thread could make wrong decisions because of how the signal is getting to the third thread. Both of these separate threads would still need to have their own communication with the main thread.

Please feel free to let me know if you see any other issues/suggestions with this code.

from PyQt5.QtCore import QThread, pyqtSignal, QObject
from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QHBoxLayout, \
    QVBoxLayout, QPushButton
import math
import time
import random
import logging
import pyqtgraph as pg
from statistics import mean
import sys
import threading
from random import randint
import numpy as np
from collections import deque

# Guards first-time instance creation in the Singleton metaclass.
lock = threading.Lock()

# Module logger: DEBUG and above go through a stream handler (stderr by
# default) whose format includes the emitting thread id (%(thread)d), which
# shows which thread each log line - and therefore each slot - ran on.
formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(thread)d - %(message)s')
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
ch.setFormatter(formatter)

log = logging.getLogger('thread_tester')
log.setLevel(logging.DEBUG)
log.addHandler(ch)

# Wall-clock timestamp taken at import (not referenced in the visible code).
timeit = time.time()


class Signals(QObject):
    """Single container for the signals shared between the window and worker.

    MainWindow creates one instance and passes it to ThreadWorker, which
    passes it on to ValueReader / SimulationGenerator.
    """
    changeMode = pyqtSignal(str)  # new worker mode: "calibrate" or "running"
    refreshGraphs = pyqtSignal(float, int)  # (latest ratio, plot index 0..99)
    start = pyqtSignal()  # resume data collection
    stop = pyqtSignal(bool)  # pause collection; True also ends the worker loop
    calibrationVal = pyqtSignal(float)  # mean of the collected calibration samples


class Singleton(type):
    """Metaclass giving each class that uses it exactly one instance.

    The first call constructs and caches the instance; every later call
    returns the cached object and ignores its arguments. Creation is guarded
    by the module-level ``lock`` with a second membership check inside it, so
    concurrent first calls still build only one instance.

    NOTE(review): ``lock`` is a plain (non-reentrant) ``threading.Lock``, so a
    singleton whose ``__init__`` instantiates another singleton class would
    deadlock.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        # Fast path: already built, no locking needed.
        if cls in registry:
            return registry[cls]
        with lock:
            # Re-check: another thread may have built it while we waited.
            if cls not in registry:
                registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]


class SimulationGenerator(object):
    """Generate noisy dummy ratio values for the plot."""

    def __init__(self, signals):
        # initial values from the control widget
        self.signals = signals
        self.peak_intensity = 10
        self.radius = 0.025  # not used by sim()
        self.center = 0.03  # not used by sim()
        self.max = 10  # not used by sim()
        self.bg = 0.05

    def sim(self):
        """Return one simulated ``diff / i0`` ratio.

        ``diff`` is the background ``bg`` scaled by a uniform factor in
        [0.5, 1.5); ``i0`` is ``peak_intensity`` plus uniform noise in
        [-bg/2, bg/2). Draws exactly two values from the ``random`` module
        (first for ``diff``, then for ``i0``).
        """
        a = random.random()
        c = random.random()
        diff = self.bg * (1 + (a - 0.5))
        i0 = self.peak_intensity + self.bg * (c - 0.5)
        return diff / i0


class ValueReader(metaclass=Singleton):
    """Shared source of data values (one instance per process via Singleton).

    Because of the metaclass, only the first ``signals`` argument is kept;
    later ``ValueReader(...)`` calls return the same object.
    """

    def __init__(self, signals):
        self.signals = signals
        self.ratio = 1
        self.simgen = SimulationGenerator(signals)

    def read_value(self):
        """Fetch a fresh simulated ratio, store it on ``self.ratio`` and return it."""
        log.debug("read value from ValueReader")
        latest = self.simgen.sim()
        self.ratio = latest
        return latest


class ThreadWorker(QObject):
    """Long-running data collector driven by flags flipped via ``signals``.

    ``start_com`` loops until the owning QThread is asked to stop. Each pass
    calls ``run_data_thread``, which reads values while ``paused`` is False.
    The ``start``/``stop``/``changeMode`` signals toggle ``paused`` and
    ``mode``.

    NOTE(review): ``__init__`` connects the shared signals before MainWindow
    calls ``moveToThread``, and the question's logs show these slots running
    on the main thread. The design relies on that: ``start_com`` never
    returns control to the worker thread's event loop, so slot calls queued
    to the worker thread would presumably never be delivered. Changing the
    connection setup needs a matching change to the loop (e.g. processing
    events inside it).
    """

    def __init__(self, signals):
        log.info("Inside thread init")
        super(ThreadWorker, self).__init__()
        self.signals = signals
        self.mode = "running"
        self.paused = True
        self._count = 0  # plot slot index, cycles 0..99
        self.refresh_rate = 1  # samples per second
        self.average = 0
        self.current_value = 0
        self.ratio = 0
        self.buffer = deque([np.nan], 100)  # keeps at most the last 100 values
        self.cal_vals = []
        self.reader = ValueReader(self.signals)
        self.signals.changeMode.connect(self.set_mode)
        self.signals.start.connect(self.start_it)
        self.signals.stop.connect(self.stop_it)

    def set_mode(self, m):
        """Switch the collection mode ("running" or "calibrate")."""
        log.info("Inside of the set_mode method of ThreadWorker.")
        self.mode = m

    def start_com(self):
        """Thread entry point: poll for data until interruption is requested."""
        log.info("Inside of the start_com method of ThreadWorker.")
        while not self.thread().isInterruptionRequested():
            self.run_data_thread()
            # run_data_thread returns immediately while paused; this sleep
            # keeps the idle loop from spinning a CPU core.
            time.sleep(1/10)

    def start_it(self):
        """Resume data collection."""
        log.info("Inside of the start_it method of ThreadWorker.")
        self.paused = False

    def stop_it(self, abort):
        """Pause collection; if ``abort`` is True also end the ``start_com`` loop."""
        self.paused = True
        if abort:
            print('abort')
            self.thread().requestInterruption()
        else:
            print('Pause')

    def run_data_thread(self):
        """Long-running task to collect data points.

        Until paused, reads one value per ``1 / refresh_rate`` seconds. In
        "running" and "calibrate" modes the value is buffered and emitted via
        ``refreshGraphs`` with the current plot index. In "calibrate" mode it
        is also passed to ``calibrate``.
        """
        while not self.paused:
            log.info("Inside of the run method of ThreadWorker.")
            self.current_value = self.reader.read_value()
            # Snapshot the mode once: it can be changed from another thread
            # (set_mode / calibrate) while this iteration is running.
            mode = self.mode
            if mode in ("running", "calibrate"):
                self.ratio = self.current_value
                self.buffer.append(self.current_value)
                self.signals.refreshGraphs.emit(self.ratio, self._count)
                if mode == "calibrate":
                    self.calibrate(self.current_value)
            # Sleep in every mode so an unrecognised mode cannot turn this
            # into a full-speed busy loop.
            time.sleep(1 / self.refresh_rate)
            self._count = (self._count + 1) % 100

    def calibrate(self, ratio):
        """Accumulate a calibration sample; publish the mean once enough exist.

        After more than 50 samples, emits ``calibrationVal`` with their mean,
        clears the samples and switches back to "running" mode.
        """
        log.info("Inside of the calibrate method of ThreadWorker.")
        self.cal_vals.append(ratio)
        if len(self.cal_vals) > 50:
            self.signals.calibrationVal.emit(mean(self.cal_vals))
            # Reset to an empty list: placeholder entries such as [[], [], []]
            # would make mean() raise TypeError on the next calibration run.
            self.cal_vals = []
            self.mode = "running"


class MainWindow(QMainWindow):
    """Main window: Start/Stop/Calibrate buttons next to a live plot.

    Creates the shared ``Signals`` instance, hands it to a ``ThreadWorker``
    moved onto ``thread1``, and redraws a fixed 100-point plot whenever the
    worker emits ``refreshGraphs``.
    """

    def __init__(self, *args, **kwargs):
        log.info("Inside main window")
        super(MainWindow, self).__init__(*args, **kwargs)
        self.qw = QWidget()
        self.setCentralWidget(self.qw)
        self.signals = Signals()
        self.x = list(range(100))  # 100 time points
        self.y = [np.nan for _ in range(100)]
        self.cal_y = 0  # replaced by the calibration line's y-values
        self.cal_line = None  # calibration plot item, created on first result

        self.layout1 = QHBoxLayout()
        self.qw.setLayout(self.layout1)
        self.layout2 = QVBoxLayout()

        # make buttons/controls
        self.start = QPushButton("Start")
        self.stop = QPushButton("Stop")
        self.cal = QPushButton("Calibrate")

        self.layout2.addWidget(self.start)
        self.layout2.addWidget(self.stop)
        self.layout2.addWidget(self.cal)
        self.layout1.addLayout(self.layout2)

        # make graph
        self.graphWidget = pg.PlotWidget()
        self.data_line = self.graphWidget.plot(self.x, self.y)
        styles = {"color": "#f00", "font-size": "20px"}
        self.graphWidget.setLabel("left", "Signal", **styles)
        self.graphWidget.setLabel("bottom", "Seconds", **styles)
        self.graphWidget.addLegend()
        self.graphWidget.showGrid(x=True, y=True)
        self.layout1.addWidget(self.graphWidget)

        # start thread
        self.thread1 = QThread()
        self.thread_worker = ThreadWorker(self.signals)
        # NOTE(review): the worker connects to self.signals inside its own
        # constructor, i.e. before this move; which thread its slots run on
        # depends on that ordering - verify before reordering.
        self.thread_worker.moveToThread(self.thread1)
        self.signals.refreshGraphs.connect(self.plot_data)
        self.thread1.started.connect(self.thread_worker.start_com)
        self.start.clicked.connect(self._start)
        self.stop.clicked.connect(self._stop)
        self.cal.clicked.connect(self._calibrate)
        self.signals.calibrationVal.connect(self.add_cal_plot)
        self.thread1.start()

    def plot_data(self, r, count):
        """Store ratio ``r`` at plot slot ``count`` and redraw the data line."""
        self.y[count] = r
        self.data_line.setData(self.x, self.y)

    def _start(self):
        """Resume collection, or restart the worker thread if it has finished."""
        if self.thread1.isRunning():
            self.signals.start.emit()
        else:
            # NOTE(review): restarting does not emit ``start``; the worker
            # keeps ``paused`` set, so Start must be clicked again.
            self.thread1.start()

    def _stop(self):
        """Pause the worker (without ending its loop) if it is collecting."""
        if not self.thread_worker.paused:
            self.signals.stop.emit(False)

    def _calibrate(self):
        """Switch the worker into calibration mode, if it is running."""
        if self.thread_worker.paused:
            # Implicit concatenation: a backslash continuation inside the
            # literal would embed the next line's indentation in the message.
            print("You are not running so there's "
                  "nothing to calibrate.. maybe hit start first")
        else:
            self.signals.changeMode.emit("calibrate")

    def add_cal_plot(self, cal_val):
        """Draw, or move, the horizontal calibration line at ``cal_val``."""
        self.cal_y = [cal_val for _ in range(100)]
        if self.cal_line is None:
            self.cal_line = self.graphWidget.plot(self.x, self.cal_y)
        else:
            self.cal_line.setData(self.x, self.cal_y)

    def closeEvent(self, event):
        """Abort the worker loop and wait for its thread before closing.

        ``wait()`` may block for up to one worker sample period, since the
        collection loop sleeps between samples before re-checking its flags.
        """
        self.signals.stop.emit(True)
        self.thread1.quit()
        self.thread1.wait()
        super(MainWindow, self).closeEvent(event)


class App(QApplication):
    """QApplication that applies the Fusion style and shows the main window."""

    def __init__(self, sys_argv):
        super().__init__(sys_argv)
        log.debug("Supplying Thread information from init of QApplication")
        self.setStyle("Fusion")
        window = MainWindow()
        # Keep a reference on the app so the window is not garbage-collected.
        self.mainWindow = window
        window.setWindowTitle("jet-tracker")
        window.show()


def main():
    """Build the application, run its event loop and exit with its status."""
    application = App(sys.argv)
    status = application.exec_()
    sys.exit(status)


if __name__ == '__main__':
    main()

Solution

  • You should understand the thread affinity of a QObject. If you want the slots of ThreadWorker to be executed on the worker thread, you should connect the signals after the object's thread affinity has already been changed by QObject.moveToThread().

    So, do like this.

    ...
    class ThreadWorker(QObject):
        def init_after_move(self, signals):
            self.signals = signals
            ...
            self.signals.changeMode.connect(self.set_mode)
            self.signals.start.connect(self.start_it)
            self.signals.stop.connect(self.stop_it)
        ...
    ...
    class MainWindow(QMainWindow):
        ...
        def __init__(self, *args, **kwargs):
            ...
            self.thread_worker = ThreadWorker()
            self.thread_worker.moveToThread(self.thread1)
            self.thread_worker.init_after_move(self.signals)
            ...
    ...
    

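    In this way, slots of the ThreadWorker will be executed on the worker thread. Thus, this will also answer your second question. Note, however, that slot calls queued to the worker thread are only delivered when that thread processes events. A blocking loop such as start_com, which never returns to the thread's event loop, must therefore periodically call QCoreApplication.processEvents() (or be restructured, e.g. around a QTimer) for these slots to run.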

    And finally, it's a bad idea to define various signals inside a single container, because it hurts readability and maintainability.