python pandas pivot-table nan imputation

How to remove NaN values from pivot table only if each column has more than x NaN values?


I have a pivot table that I create with the line

pivot_table = pd.pivot_table(df, values='trip_duration', index=['day_of_month', 'hour'], columns='location_pair_id', aggfunc=np.mean, dropna=True), which looks like this:

[image: pivot table]

For each column, I want to impute the NaN values, but only if the column has fewer than x NaN values, say x=10. Columns with more than x NaN values should be removed.

So far, I have tried passing a subset of columns to the dropna function:

pivot_table = pivot_table.dropna(axis=1, subset=nan_values_idx),

where nan_values_idx is calculated as follows:

nan_values = pivot_table.isnull().sum()
nan_values_idx = list(nan_values[nan_values>10].keys()),

which gives a list of location_pair_id's: ['(164, 170)', '(186, 230)', '(186, 48)',...,'(79, 79)']

However, when I run

pivot_table = pivot_table.dropna(axis=1, subset=nan_values_idx)

I get the error:

in DataFrame.dropna(self, axis, how, thresh, subset, inplace)
   6548     check = indices == -1
   6549     if check.any():
-> 6550         raise KeyError(np.array(subset)[check].tolist())
   6551     agg_obj = self.take(indices, axis=agg_axis)
   6553 if thresh is not no_default:

KeyError: ['(164, 170)', '(186, 230)', '(186, 48)', '(186, 68)', '(230, 186)', '(230, 230)', '(230, 48)', '(230, 50)', '(263, 141)', '(263, 75)', '(48, 142)', '(48, 164)', '(48, 186)', '(48, 230)', '(48, 48)', '(48, 50)', '(48, 68)', '(68, 246)', '(68, 48)', '(68, 68)', '(79, 107)', '(79, 79)']

I appreciate any hint!


Solution

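  • The KeyError arises because, with axis=1, the subset argument of dropna expects row (index) labels rather than column labels, so the location_pair_id values are not found. Instead, you can count the number of NaN values in each column and keep only the columns where that count is at most 10 (or another threshold):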

    cols = [col for col, no_na in pivot_table.isna().sum().items() if no_na <= 10]
    pivot_table = pivot_table[cols]
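
The filter above only drops columns. The question also asks for the remaining NaN values to be imputed. A minimal sketch of both steps follows, using a small made-up frame in place of the real pivot table. The toy data, the threshold x = 1, and the choice of column-mean imputation are assumptions for illustration; the question does not say which imputation method is wanted.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the pivot table (hypothetical data; labels mimic location_pair_id).
pivot_table = pd.DataFrame(
    {
        "(48, 48)": [1.0, np.nan, 3.0, 4.0],        # 1 NaN  -> kept, then imputed
        "(79, 79)": [np.nan, np.nan, np.nan, 2.0],  # 3 NaNs -> dropped
        "(68, 68)": [5.0, 6.0, 7.0, 8.0],           # 0 NaNs -> kept unchanged
    }
)

x = 1  # maximum number of NaN values a column may have (the question uses 10)

# Boolean mask over columns: True where the NaN count is at most x.
keep = pivot_table.isna().sum() <= x
filtered = pivot_table.loc[:, keep]

# Same filter via dropna: thresh is the minimum number of non-NaN values
# a column needs in order to be kept.
filtered_alt = pivot_table.dropna(axis=1, thresh=len(pivot_table) - x)

# Fill the remaining NaNs with each column's mean (assumed imputation strategy).
imputed = filtered.fillna(filtered.mean())
```

If a different strategy fits the data better, for example interpolation along the day/hour index, only the last line would need to change.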