python numpy matplotlib plots.jl

Stack plots generated in a loop


I am running a loop with an if condition and want to arrange the resulting plots in a grid. This is my sample code, which generates two random variables and plots a third quantity when two conditions are met:

import numpy as np
import matplotlib.pyplot as plt

# Set seed for reproducibility
np.random.seed(42)

# Generate 10 realizations of the uniform distribution between -1 and 1 for x and y
x = np.random.uniform(low=-1, high=1, size=10)
y = np.random.uniform(low=-1, high=1, size=10)

# Create empty list to store valid plots
valid_plots = []

# Initialize an empty list to store the current row of plots
current_row = []

# Loop through each realization of x and y
for i in range(len(x)):
    # Check if x is positive and y is negative
    if x[i] > 0 and y[i] < 0:
        # Generate 100 values of z
        z = np.linspace(-1, 1, 100)
        # Compute the function z = xy*z^2
        z_func = x[i] * y[i] * z*z
        # Plot the function
        fig, ax = plt.subplots()
        ax.plot(z, z_func)
        # If there are now two plots in the current row, append the row to valid_plots and start a new row
        if len(current_row) % 2 == 1:
            valid_plots.append(current_row)
            current_row = []
        # Append the current plot to the current row
        current_row.append(ax)
# If there is only one plot in the last row, append the row to valid_plots
if len(current_row) > 0 and len(current_row) % 2 == 1:
    current_row.append(plt.gca())
    valid_plots.append(current_row)

# Create a figure with subplots for each valid plot
num_rows = len(valid_plots)
fig, axes = plt.subplots(num_rows, 2, figsize=(12, 4 * num_rows))
for i, row in enumerate(valid_plots):
    for j, ax in enumerate(row):
        # Check if the plot has any lines before accessing ax.lines[0]
        if len(ax.lines) > 0:
            axes[i, j].plot(ax.lines[0].get_xdata(), ax.lines[0].get_ydata())
plt.show()

The problem with the output is that it generates two empty graphs and then starts stacking the plots vertically:

output

Could you help me out? I would also be interested in more efficient ways of achieving this result.


Solution

  • As I understand it, you get a random number of plots and want to draw only those where x[i]>0 and y[i]<0 for each pair of x and y, without any blank plots.

    The main issue in your implementation is that you create every plot first, each in its own figure, and then try to arrange them in an Nx2 grid. In the code below, I first check all the pairs for x[i]>0 and y[i]<0 and collect only their indices in a list called valid_ids, without plotting anything yet.

    Then, since len(valid_ids) tells us how many subplots are required, I plot them all in a single for loop. Note that math.ceil(len(valid_ids)/2) rounds the number of plots divided by 2 up to the next integer, so we get 2 rows for 4 plots or 3 rows for 5 plots. This way, all the curves end up as subplots of one figure. Finally, if the number of subplots is odd, fig.delaxes() removes the extra empty plot. Hope this is what you are looking for.

    import numpy as np
    import matplotlib.pyplot as plt
    
    # Set seed for reproducibility
    np.random.seed(42)
    
    # Generate 10 realizations of the uniform distribution between -1 and 1 for x and y
    x = np.random.uniform(low=-1, high=1, size=10)
    y = np.random.uniform(low=-1, high=1, size=10)
    
    valid_ids=[]
    for i, (j, k) in enumerate(zip(x, y)):
        if j > 0 and k < 0:
            valid_ids.append(i)
    
    import math
    num_rows=math.ceil(len(valid_ids)/2)
    
    if len(valid_ids) > 0:  ## At least one valid pair; otherwise there is nothing to plot
        if len(valid_ids) > 2: ## More than one row, plot as axes[row, col]
            fig, axes = plt.subplots(num_rows, 2, figsize=(12, 4 * num_rows))
            for i in range(len(valid_ids)):
                # Generate 100 values of z
                z = np.linspace(-1, 1, 100)
                # Compute the function z_func = x*y*z^2
                z_func = x[valid_ids[i]] * y[valid_ids[i]] * z*z
                # Plot the function in row i // 2, column i % 2
                axes[i // 2, i % 2].plot(z, z_func)
            # If ODD number of plots, then remove the last blank plot    
            if len(valid_ids)%2 == 1:
                fig.delaxes(axes[num_rows-1,1])
        else: ## Single row, plot as axes[col] 
            fig, axes = plt.subplots(num_rows, 2, figsize=(12, 4))
            for i in range(len(valid_ids)):
                # Generate 100 values of z
                z = np.linspace(-1, 1, 100)
                # Compute the function z = xy*z^2
                z_func = x[valid_ids[i]] * y[valid_ids[i]] * z*z
                # Plot the function
                axes[i].plot(z, z_func)
            # If just one plot, then remove the second one
            if len(valid_ids) == 1:
                fig.delaxes(axes[1])
            
        plt.show()
    

    Output plot - 3 plots in my case...

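
On the request for a more compact method: the two branches in the code above exist because plt.subplots() returns a 1-D array of axes when there is only one row. A possible simplification, offered as a sketch rather than the only way to do it, is to pass squeeze=False so that the axes array is always 2-D, and then iterate over it flattened. The function name plot_valid_pairs and the ncols parameter are made up for this illustration; np.flatnonzero replaces the explicit index-collecting loop.

```python
import math

import matplotlib.pyplot as plt
import numpy as np


def plot_valid_pairs(x, y, ncols=2):
    """Plot x[i]*y[i]*z**2 in a grid for every i with x[i] > 0 and y[i] < 0.

    Hypothetical helper for illustration; returns (fig, list_of_used_axes),
    or (None, []) when no pair satisfies the condition.
    """
    # Indices of the pairs that satisfy the condition
    valid_ids = np.flatnonzero((x > 0) & (y < 0))
    if valid_ids.size == 0:
        return None, []

    num_rows = math.ceil(valid_ids.size / ncols)
    # squeeze=False keeps `axes` 2-D even when there is a single row,
    # so no special case is needed for one or two plots
    fig, axes = plt.subplots(
        num_rows, ncols, figsize=(6 * ncols, 4 * num_rows), squeeze=False
    )
    flat_axes = axes.ravel()

    z = np.linspace(-1, 1, 100)
    for ax, i in zip(flat_axes, valid_ids):
        ax.plot(z, x[i] * y[i] * z**2)

    # Remove the unused cells at the end of the last row
    for ax in flat_axes[valid_ids.size:]:
        fig.delaxes(ax)

    return fig, fig.axes
```

With the x and y arrays from the question, this would be called as plot_valid_pairs(x, y) followed by plt.show(). Because the loop walks the flattened axes array, the row and column indices never have to be computed by hand.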