Tags: python, python-3.x, matplotlib, plot, axis-labels

X and Y Ticks on a 4x4 multiplot using matplotlib in Python


I am using Python 3.6.5 from an Anaconda install.

I have 16 data files, each containing two columns of data, and I am trying to show all of the data in one figure as a 4x4 grid of subplots. I have managed to plot all 16 data sets in the grid, but I can't adjust the x and y ticks. The x values range from 0 to 2000 and the y values range from 0 to 4.5.

This is my current script:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math

ph_values = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5]

all_xs = []
all_ys = []
for ph in ph_values:
    xs = []
    ys = []
    with open('rmsd_ph' + str(ph) + '.dat', "r") as f:
        for line in f:
            if line[0] != "#":
                x, y = line.split()
                xs.append(float(x))
                ys.append(float(y))
    all_xs.append(xs)
    all_ys.append(ys)

fig, axes = plt.subplots(nrows=math.ceil(len(ph_values)/4), ncols=4, figsize=(6,6))

axes = axes.flatten()
for index,ph in enumerate(ph_values):
    axes[index].plot(np.asarray(all_xs[index]),np.asarray(all_ys[index]))

plt.xticks(np.arange(0, 2000, step=500))
plt.tight_layout()
plt.savefig('test.pdf')
plt.show()
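As an aside, the file-reading loop above can probably be shortened with `np.loadtxt`, which treats text after `#` as a comment by default, skips blank lines, and with `unpack=True` returns one array per column. This is a minimal sketch that assumes the `rmsd_ph*.dat` files are whitespace-separated with `#` header lines (as the loop's `line.split()` and `"#"` check suggest); `load_two_columns` is a hypothetical helper name, and the sample file is made up so the example runs on its own.

```python
import os
import tempfile

import numpy as np


def load_two_columns(path):
    """Hypothetical helper: read a two-column text file into (xs, ys) float arrays.

    np.loadtxt treats text after '#' as a comment (its default comments='#')
    and skips blank lines; unpack=True returns one array per column.
    """
    xs, ys = np.loadtxt(path, unpack=True)
    return xs, ys


# Made-up sample in the assumed rmsd_ph*.dat layout: a '#' header, then x y pairs.
sample = "# step rmsd\n0 0.0\n500 1.2\n1000 2.5\n"

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "rmsd_ph1.0.dat")
    with open(path, "w") as f:
        f.write(sample)
    xs, ys = load_two_columns(path)
```

In the question's loop, the whole `with open(...)` block could then become `xs, ys = load_two_columns('rmsd_ph' + str(ph) + '.dat')`. The results are already NumPy arrays, so the `np.asarray` calls in the plotting loop would no longer be needed.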

Currently the script outputs something that looks like this.

[Image: output of the current script]

As you can see, only the last subplot has its x-axis ticks adjusted. I have not tried adjusting the y-axis yet, because I have not been able to get the x-axis working.
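The reason only the last panel changes: `plt.xticks` acts on pyplot's "current" Axes, and right after `plt.subplots` returns, that is the last subplot created. Calling `axes[index].plot(...)` does not change which Axes is current, so the single `plt.xticks` call at the end of the script only reaches the bottom-right panel. A minimal sketch that reproduces this, assuming a 2x2 grid instead of 4x4 (the size does not matter) and the non-interactive Agg backend so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(nrows=2, ncols=2)
axes = axes.flatten()

# After plt.subplots, the "current" Axes is the last subplot created.
current_is_last = plt.gca() is axes[-1]

# plt.xticks therefore changes only that one subplot.
plt.xticks(np.arange(0, 2000, step=500))

# Collect the x ticks of every subplot to compare them.
ticks_per_axes = [list(ax.get_xticks()) for ax in axes]
plt.close(fig)
```

Only the last entry of `ticks_per_axes` holds the requested ticks; the other three keep matplotlib's automatic ticks.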

Overall, I would like four ticks on both the x and y axes of every subplot.
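For "four ticks per axis", the choice of NumPy helper matters. `np.arange` excludes its stop value, so `np.arange(0, 2000, step=500)` already yields exactly four ticks but stops at 1500. `np.linspace` includes both endpoints and returns exactly `num` values, which suits the 0 to 4.5 y range well but gives non-round x ticks for 0 to 2000. The sketch below just computes the candidate tick arrays; which ones to use is a judgment call, and any of them can be passed to `plt.xticks`/`plt.yticks` or to `ax.set_xticks`/`ax.set_yticks`.

```python
import numpy as np

# np.arange excludes the stop value: four ticks, ending at 1500.
x_ticks_arange = np.arange(0, 2000, step=500)

# Raising the stop just past 2000 includes it, at the cost of a fifth tick.
x_ticks_with_end = np.arange(0, 2001, step=500)

# np.linspace includes both endpoints and returns exactly `num` values.
y_ticks = np.linspace(0, 4.5, num=4)            # 0.0, 1.5, 3.0, 4.5
x_ticks_linspace = np.linspace(0, 2000, num=4)  # four ticks, but not round numbers
```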


Solution

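  • This is what I found that solved the problem. The key change is calling plt.sca(axes[index]) inside the loop, which makes each subplot the current Axes before plt.xticks and plt.yticks are called, so the ticks are applied to every subplot rather than only to the last one.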

    # Continues the question's script: reuses its imports, ph_values, all_xs and all_ys.
    fig, axes = plt.subplots(nrows=math.ceil(len(ph_values)/4), ncols=4, figsize=(9,9))
    
    axes = axes.flatten()
    for index,ph in enumerate(ph_values):
        axes[index].scatter(np.asarray(all_xs[index]),np.asarray(all_ys[index]), s=1)
        plt.sca(axes[index])                    # fixed the problem: make this subplot the current Axes
        plt.xticks([0, 500, 1000, 1500, 2000])  # fixed the problem: now applies to this subplot
        plt.yticks([0, 1, 2, 3, 4, 5])          # fixed the problem
        plt.title('pH:' + str(ph))
        if (index % 4 == 0):
            plt.ylabel('RMSD [$\\rm{\\AA}$]')   # y label only on the left column
        if (index >= 12):
            plt.xlabel('Steps')                 # x label only on the bottom row
    
    plt.tight_layout()
    plt.savefig('test.pdf')                     # same output file as the question's script
    plt.show()
    

    Here is an image of the result.

    [Image: the final 4x4 grid of plots]
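The `plt.sca` approach works, but the same result can be had without switching pyplot's current Axes at all, by calling the tick and label methods on each Axes object directly. A sketch under these assumptions: the random `all_xs`/`all_ys` below are hypothetical stand-ins for the data read from the `rmsd_ph*.dat` files, and the Agg backend is selected only so the example runs without a display (drop that line and add `plt.show()` to view the figure interactively).

```python
import math

import matplotlib
matplotlib.use("Agg")  # assumption: off-screen rendering; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

ph_values = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5,
             5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5]

# Hypothetical stand-in data; the real script reads these from rmsd_ph*.dat files.
rng = np.random.default_rng(0)
all_xs = [np.arange(0, 2000, 10) for _ in ph_values]
all_ys = [rng.uniform(0, 4.5, size=200) for _ in ph_values]

ncols = 4
nrows = math.ceil(len(ph_values) / ncols)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(9, 9))
axes = axes.flatten()

for ax, ph, xs, ys in zip(axes, ph_values, all_xs, all_ys):
    ax.scatter(xs, ys, s=1)
    # Set ticks on this specific Axes instead of pyplot's "current" Axes.
    ax.set_xticks([0, 500, 1000, 1500, 2000])
    ax.set_yticks([0, 1, 2, 3, 4, 5])
    ax.set_title('pH: ' + str(ph))

# Axis labels only on the left column and the bottom row.
for ax in axes[::ncols]:
    ax.set_ylabel(r'RMSD [$\rm{\AA}$]')
for ax in axes[-ncols:]:
    ax.set_xlabel('Steps')

fig.tight_layout()
fig.savefig('test.pdf')
```

If all panels should use identical limits, passing `sharex=True, sharey=True` to `plt.subplots` is another option; matplotlib then shares the axes between panels and only creates tick labels on the outer subplots.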