Search code examples
Tags: python, matplotlib, optimization, subplot

Is there a way to optimize this code for subplots in Python?


I wrote code that creates 8 subplots, one per year, showing the monthly-aggregated net values in my data. I tried to shorten it with two for loops, but I don't know how to handle the query part on the pandas DataFrame. Is there a better way to write or optimize this long code?

VF_data is just a pandas DataFrame with numeric positive and negative values aggregated per month and per year. The other columns are month, year, and date.

Thank you all in advance!!

def plot_MTY(df, aggregate_col='NET', years=None, ncols=4):
    """Plot the monthly aggregate of ``aggregate_col`` for each year in its own subplot.

    Every year gets the same treatment: filter the rows for that year,
    aggregate them with ``aggregate_data``, draw the line, add a horizontal
    reference line at y=0 and rotate the x tick labels vertically.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data. Rows are selected with ``df['YEAR'] == year``, so the
        entries of ``years`` must match the dtype of the ``YEAR`` column
        (strings by default, mirroring the original ``"YEAR == '2015'"``
        queries). ``df`` must also contain the ``'DATES'`` column that is
        passed to ``aggregate_data``.
    aggregate_col : str, default 'NET'
        Column name forwarded to ``aggregate_data`` as the value to aggregate.
    years : iterable, optional
        Years to plot, one subplot each, in the given order. Defaults to the
        strings ``'2015'`` through ``'2022'`` (the original 2x4 grid).
    ncols : int, default 4
        Maximum number of subplot columns; rows are added as needed.

    Raises
    ------
    ValueError
        If ``years`` is empty or ``ncols`` is less than 1.
    """
    if years is None:
        years = [str(year) for year in range(2015, 2023)]
    years = list(years)
    if not years:
        raise ValueError("years must contain at least one year")
    if ncols < 1:
        raise ValueError("ncols must be at least 1")

    ncols = min(ncols, len(years))
    nrows = -(-len(years) // ncols)  # ceiling division: enough rows for every year

    # squeeze=False keeps `axes` 2-D even for a single row/column, so ravel() is always valid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 8), squeeze=False)
    flat_axes = axes.ravel()

    for ax, year in zip(flat_axes, years):
        # A boolean mask replaces the per-year query string (no quoting/formatting needed).
        year_data = df[df['YEAR'] == year]
        aggregated_target = aggregate_data(year_data, 'DATES', aggregate_col)

        ax.plot(aggregated_target, linestyle='-')
        ax.axhline(y=0, color='b', linestyle='-')
        ax.set_title(str(year))
        # tick_params stores the rotation on the axis, so it also applies to
        # tick labels matplotlib creates later (plt.setp only touched the
        # label objects that existed when plt.xticks() was called).
        ax.tick_params(axis='x', labelrotation=90)

    # Hide unused grid cells when len(years) is not a multiple of ncols.
    for ax in flat_axes[len(years):]:
        ax.set_visible(False)

    fig.tight_layout()  # keep the vertical tick labels from overlapping the next row
    plt.show()

Solution

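  • You can loop over `df.groupby("YEAR")`: each iteration yields one `(year, group)` pair, so every year can be drawn in its own subplot without writing a separate query for each year.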

    Below is an example:

    df = pd.DataFrame({
        "YEAR": ["2022", "2022", "2023", "2023"],
        "x": [1, 2, 3, 4],
        "y": [1, 2, 3, 4],
    })

    # Size the grid from the data instead of hard-coding 1x2, so plt.subplot
    # stays valid however many years there are (at most 4 columns, as in the
    # question's 2x4 layout).
    groups = df.groupby("YEAR")
    n_years = groups.ngroups
    ncols = min(4, n_years)
    nrows = -(-n_years // ncols)  # ceiling division

    for i, (year, gr) in enumerate(groups):
        plt.subplot(nrows, ncols, i + 1)
        plt.plot(gr["x"], gr["y"])
        plt.title(str(year))

    plt.show()