Tags: python, for-loop, matplotlib, plot, subplot

Subplots repeating the same graph 6 times and producing 6 figures instead of one


So I have this code:

def scatter(df, column_name):
  values = {data: list(df[data]) for data in column_name}

  data = list(values.values())
  labels = list(values.keys())
  
  for i in range(len(data)):
    for j in range(len(data)):
      if i == j:
        continue
      elif (i == 1) & (j == 0):
        continue
      elif (i == 2) & ((j == 0)|(j == 1)):
        continue
      elif (i == 3) & ((j == 0)|(j == 1)|(j == 2)):
        continue
      else:
        for k in range(6):
          ax = plt.subplot(3, 2, k+1)
          plt.scatter(data[i], data[j])
          plt.xlabel(labels[i])
          plt.ylabel(labels[j])
          plt.title('{} vs {}'.format(labels[i], labels[j]))
        plt.show()
        plt.clf()

scatter(roller_coasters, ['speed', 'height', 'length', 'num_inversions'])

But it produces 6 figures instead of one, and each figure shows the same graph repeated 6 times.

Please help me solve this.


Solution

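  • Each time your loop reaches the else branch, the inner loop over k creates 6 subplots for that one (i, j) combination. For example, for i=0, j=1 the k loop draws the same speed vs height scatter plot into all six subplots; then plt.show() displays that figure and plt.clf() clears it. On the next pass (i=0, j=2) another set of 6 identical subplots is created, and so on. With four columns there are 6 qualifying (i, j) pairs, so you get 6 figures, each showing a single graph 6 times.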
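To check this diagnosis without drawing anything, the question's loop can be traced in plain Python. In this sketch every drawing call is replaced by a string, and the list `figures` is a hypothetical stand-in for the figures that plt.show() would display. The if/elif skip chain is condensed into the single test `j <= i`, which skips exactly the same combinations for four columns.

```python
labels = ['speed', 'height', 'length', 'num_inversions']
n = len(labels)

figures = []  # hypothetical: one entry per plt.show() call in the original code
for i in range(n):
    for j in range(n):
        if j <= i:  # same combinations the question's if/elif chain skips
            continue
        # the k loop puts the same (i, j) plot into all 6 subplot slots
        panels = [f'{labels[i]} vs {labels[j]}' for _ in range(6)]
        figures.append(panels)  # then plt.show() and plt.clf(), once per pair

print(len(figures))  # → 6
print(figures[0])    # six copies of the same title
```

Every "figure" in the trace holds six copies of a single title, which matches the symptom described in the question.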

    You can simplify things by starting the loop over j at i+1, so no skip conditions are needed. The index of the next subplot can then be kept in a counter k that is incremented each time a subplot is added, and plt.show() is called only once, after the loops have finished.
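A small sketch of just the loop structure, with no plotting, assuming the four column names from the question. It also computes the grid size the same way the example code does, and checks that the j-from-i+1 loop visits the same pairs, in the same order, as `itertools.combinations` from the standard library, a common shortcut for this pattern.

```python
from itertools import combinations

columns = ['speed', 'height', 'length', 'num_inversions']
n = len(columns)
total = n * (n - 1) // 2              # number of distinct pairs: 6 for 4 columns
ncols = 2
nrows = (total + ncols - 1) // ncols  # ceiling division: 3 rows

pairs = []
k = 1  # 1-based subplot index, as plt.subplot(nrows, ncols, k) expects
for i in range(n):
    for j in range(i + 1, n):  # start at i + 1: no self-pairs, no mirrored duplicates
        pairs.append((k, columns[i], columns[j]))
        k += 1

# combinations() yields the same pairs in the same order
expected = [(idx, a, b) for idx, (a, b) in enumerate(combinations(columns, 2), start=1)]
assert pairs == expected
print(nrows, pairs)
```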

    Here is some example code:

    from matplotlib import pyplot as plt
    import pandas as pd
    import numpy as np
    
    def scatter(df, column_names):
        fig = plt.figure(figsize=(10, 12)) # set the size of the whole figure that holds all subplots
        n = len(column_names)
        total = n * (n - 1) // 2
        ncols = 2
        nrows = (total + (ncols - 1)) // ncols
        k = 1
        for i in range(n):
            col_i = column_names[i]
            for j in range(i + 1, n):
                col_j = column_names[j]
                ax = plt.subplot(nrows, ncols, k)
                plt.scatter(df[col_i], df[col_j])
                plt.xlabel(col_i)
                plt.ylabel(col_j)
                plt.title(f'{col_i} vs {col_j}')
                k += 1
        plt.tight_layout() # fit labels and ticks nicely together
        plt.show() # only called once, at the end of the function
    
    columns = ['speed', 'height', 'length', 'num_inversions']
    # random demo data standing in for the real roller_coasters DataFrame
    roller_coasters = pd.DataFrame(np.random.rand(20, len(columns)), columns=columns)
    scatter(roller_coasters, columns)
    

    [Figure: a single figure containing the 6 pairwise scatter plots in a 3 x 2 grid]
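A variant of the same idea uses matplotlib's object-oriented interface: `plt.subplots` creates the whole grid up front, and each scatter plot is drawn into an explicit Axes object instead of the "current" axes. This is a sketch, not a replacement for the code above; the function name `scatter_pairs`, the use of `itertools.combinations`, and the `Agg` backend line (there only so the sketch runs without a display) are my own choices.

```python
import itertools

import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def scatter_pairs(df, column_names, ncols=2):
    """Draw one scatter plot per pair of columns, all in a single figure."""
    pairs = list(itertools.combinations(column_names, 2))
    nrows = -(-len(pairs) // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(10, 12), squeeze=False)
    flat_axes = axes.ravel()  # row-major order, same as plt.subplot's k index
    for ax, (col_x, col_y) in zip(flat_axes, pairs):
        ax.scatter(df[col_x], df[col_y])
        ax.set_xlabel(col_x)
        ax.set_ylabel(col_y)
        ax.set_title(f'{col_x} vs {col_y}')
    # hide any grid cells left over when the pair count does not fill the grid
    for ax in flat_axes[len(pairs):]:
        ax.set_visible(False)
    fig.tight_layout()
    return fig


columns = ['speed', 'height', 'length', 'num_inversions']
roller_coasters = pd.DataFrame(np.random.rand(20, len(columns)), columns=columns)
fig = scatter_pairs(roller_coasters, columns)
# in an interactive session, call plt.show() here, once
```

Because each Axes is addressed explicitly, there is no counter to keep in sync, and any unused cell in the grid is simply hidden.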