Search code examples
pythonpandasmatplotlibpyqt5tablemodel

How to create a scatterplot GUI linked to change in a tablemodel(dataframe)


I have a GUI, which consist of a Qtableview and dataframe scatter plot widget. Plot draws the X, Y value from the table and has colormap with Z value. Here's the point. When I modify Y value in Qtableview, then press the refresh button on the bottom of GUI, the altered dataframe value can be seen on console. Of course, however, graphs are not updated in response to changes in data. How can I display the altered value on the graph when I pushed the refresh button?

import sys

import pandas as pd
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtCore import Qt, QSize
from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget, QPushButton


import matplotlib
import pandas as pd
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as canvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure



class MplCanvas(canvas):
    def __init__(self, parent=None, width=5, height=4, dpi=100):
        fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes = fig.add_subplot(111)
        super().__init__(fig)


class TableModel(QtCore.QAbstractTableModel):
    def __init__(self, data):
        super().__init__()
        self._data = data

    def data(self, index, role=Qt.DisplayRole):
        if index.isValid():
            if role == Qt.DisplayRole or role == Qt.EditRole:
                value = self._data.iloc[index.row(), index.column()]
                return str(value)

    def setData(self, index, value, role):
        if role == Qt.EditRole:
            self._data.iloc[index.row(), index.column()] = value
            return True
        return False

    def headerData(self, col, orientation, role):
        if orientation == Qt.Horizontal and role == Qt.DisplayRole:
            return self._data.columns[col]

    def flags(self, index):
        return Qt.ItemIsSelectable | Qt.ItemIsEnabled | Qt.ItemIsEditable

    def rowCount(self, index):
        return self._data.shape[0]

    def columnCount(self, index):
        return self._data.shape[1]

    def headerData(self, section, orientation, role):
        if role == Qt.DisplayRole:
            if orientation == Qt.Horizontal:
                return str(self._data.columns[section])

            if orientation == Qt.Vertical:
                return str(self._data.index[section])


class MainWindow(QtWidgets.QMainWindow):
    def __init__(self):
        super().__init__()

                
        layout = QVBoxLayout()

        self.table = QtWidgets.QTableView()

        self.data = pd.DataFrame(
            [[1, 9, 2], [2, 0, -1], [3, 5, 2], [4, 3, 2], [5, 8, 9],],
            columns=["X", "Y", "Z"])

        # Pandas data model setting
        self.model = TableModel(self.data)
        self.table.setModel(self.model)
                
        
                
        # define scatterplot
        sc = MplCanvas(self, width=5, height=4, dpi=100)
        self.data.plot.scatter(x='X', y='Y', c = 'Z', colormap = 'jet',  ax=sc.axes)
        toolbar = NavigationToolbar(sc, self)

        # Refresh button
        button_refresh = QPushButton("Refresh")
        button_refresh.clicked.connect(self.refresh_btn)

        layout.addWidget(self.table)
        layout.addWidget(toolbar)
        layout.addWidget(sc)
        layout.addWidget(button_refresh)


        
        widget = QWidget()
        widget.setLayout(layout)
        self.setCentralWidget(widget)
        

    def refresh_btn(self):
        print("data: \n", self.data)
        


app = QtWidgets.QApplication(sys.argv)
window = MainWindow()
window.show()
app.exec_()

enter image description here


Solution

  • First, You have to modify Your setData method to convert edited value into float. That way only numbers are stored in new data column (str is default). Then You have to update matplotlib plot again manually. You've got reference to the plot in self.data.plot object.

    Here is modified code, which is updating plot after Refresh button is clicked:

    import sys
    
    import pandas as pd
    from PyQt5 import QtCore, QtWidgets
    from PyQt5.QtCore import Qt
    from PyQt5.QtWidgets import QVBoxLayout, QWidget, QPushButton
    from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as canvas
    from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
    from matplotlib.figure import Figure
    
    
    class MplCanvas(canvas):
        def __init__(self, parent=None, width=5, height=4, dpi=100):
            fig = Figure(figsize=(width, height), dpi=dpi)
            self.axes = fig.add_subplot(111)
            super().__init__(fig)
    
    
    class TableModel(QtCore.QAbstractTableModel):
        def __init__(self, data):
            super().__init__()
            self._data = data
    
        def data(self, index, role=Qt.DisplayRole):
            if index.isValid():
                if role == Qt.DisplayRole or role == Qt.EditRole:
                    value = self._data.iloc[index.row(), index.column()]
                    return str(value)
    
        def setData(self, index, value, role):
            if role == Qt.EditRole:
                self._data.iloc[index.row(), index.column()] = float(value)
                return True
            return False
    
        def headerData(self, col, orientation, role):
            if orientation == Qt.Horizontal and role == Qt.DisplayRole:
                return self._data.columns[col]
    
        def flags(self, index):
            return Qt.ItemIsSelectable | Qt.ItemIsEnabled | Qt.ItemIsEditable
    
        def rowCount(self, index):
            return self._data.shape[0]
    
        def columnCount(self, index):
            return self._data.shape[1]
    
        def headerData(self, section, orientation, role):
            if role == Qt.DisplayRole:
                if orientation == Qt.Horizontal:
                    return str(self._data.columns[section])
    
                if orientation == Qt.Vertical:
                    return str(self._data.index[section])
    
    
    class MainWindow(QtWidgets.QMainWindow):
        def __init__(self):
            super().__init__()
    
            layout = QVBoxLayout()
    
            self.table = QtWidgets.QTableView()
    
            self.data = pd.DataFrame(
                [[1, 9, 2], [2, 0, -1], [3, 5, 2], [4, 3, 2], [5, 8, 9], ],
                columns=["X", "Y", "Z"])
    
            # Pandas data model setting
            self.model = TableModel(self.data)
            self.table.setModel(self.model)
    
            # define scatterplot
            sc = MplCanvas(self, width=5, height=4, dpi=100)
            self.plot = self.data.plot.scatter(x='X', y='Y', c='Z', colormap='jet', ax=sc.axes)
            toolbar = NavigationToolbar(sc, self)
    
            # Refresh button
            button_refresh = QPushButton("Refresh")
            button_refresh.clicked.connect(self.refresh_btn)
    
            layout.addWidget(self.table)
            layout.addWidget(toolbar)
            layout.addWidget(sc)
            layout.addWidget(button_refresh)
    
            widget = QWidget()
            widget.setLayout(layout)
            self.setCentralWidget(widget)
    
        def refresh_btn(self):
            self.plot.clear()
            self.plot.scatter(self.data["X"], self.data["Y"], c=self.data["Z"], cmap="jet", s=20, alpha=0.9)
            self.plot.figure.canvas.draw()
            self.plot.figure.canvas.flush_events()
    
    
    app = QtWidgets.QApplication(sys.argv)
    window = MainWindow()
    window.show()
    app.exec_()