Tags: pandas, list, loops, dataframe, contains

Fast removal of list elements contained in a pandas dataframe


I have a list of strings and two separate pandas dataframes, one of which contains NaNs. I am trying to find a fast way to check whether each item in the list is contained in either dataframe and, if so, to remove it from the list.

Currently, I do this with a list comprehension. I first concatenate the two dataframes, then loop through the list and check whether each item is contained in the concatenated dataframe's values:

patches = [patch for patch in patches if patch not in bad_patches.values]

Elements 1 through 4 of my list of strings:

patches[1:5]
['S2A_MSIL2A_20170613T101031_11_52',
 'S2A_MSIL2A_20170717T113321_35_89',
 'S2A_MSIL2A_20170613T101031_12_39',
 'S2A_MSIL2A_20170613T101031_11_77']

An example of one of my dataframes; the second has the same structure but fewer rows. Note that the first row contains patches[2]:

cloud_patches.head()
0  S2A_MSIL2A_20170717T113321_35_89
1  S2A_MSIL2A_20170717T113321_39_84
2   S2B_MSIL2A_20171112T114339_0_13
3   S2B_MSIL2A_20171112T114339_0_52
4   S2B_MSIL2A_20171112T114339_0_53

The concatenated dataframe:

bad_patches = pd.concat([cloud_patches, snow_patches], axis=1)
bad_patches.head()
0  S2A_MSIL2A_20170717T113321_35_89  S2B_MSIL2A_20170831T095029_27_76
1  S2A_MSIL2A_20170717T113321_39_84  S2B_MSIL2A_20170831T095029_27_85
2   S2B_MSIL2A_20171112T114339_0_13  S2B_MSIL2A_20170831T095029_29_75
3   S2B_MSIL2A_20171112T114339_0_52  S2B_MSIL2A_20170831T095029_30_75
4   S2B_MSIL2A_20171112T114339_0_53  S2B_MSIL2A_20170831T095029_30_78

and the tail, showing the NaNs of one column:

bad_patches.tail()
61702  NaN   S2A_MSIL2A_20180228T101021_43_6
61703  NaN   S2A_MSIL2A_20180228T101021_43_8
61704  NaN  S2A_MSIL2A_20180228T101021_43_11
61705  NaN  S2A_MSIL2A_20180228T101021_43_13
61706  NaN  S2A_MSIL2A_20180228T101021_43_16

Column headers are all (poorly) named 0.

The second element of patches shown above should be removed, as it appears in the first row of bad_patches. My method does work but takes absolutely ages: bad_patches has about 60,000 rows and the length of patches is variable. Right now it takes 2.04 seconds for 1,000 patches, but I need to scale up to 500k patches, so I'm hoping there is a faster way. Thanks!
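
For a sense of why this scales badly: bad_patches.values is a 2-D array, and every in test scans all of its cells, so the total work grows with len(patches) × the number of cells in bad_patches. A minimal sketch with synthetic data (the names and sizes below are illustrative stand-ins, not the real dataset) reproduces the behaviour:

    import pandas as pd
    import timeit

    # Synthetic stand-ins, sized roughly like the real data.
    bad = pd.DataFrame({0: [f"S2A_FAKE_{i}" for i in range(60_000)]})
    patches = [f"S2A_FAKE_{i}" for i in range(0, 2_000, 2)]

    # The original approach: one full scan of bad.values per patch.
    print(timeit.timeit(
        lambda: [p for p in patches if p not in bad.values], number=1))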


Solution

  • I would create a set with the values from cloud_patches and snow_patches, and also convert patches to a set:

    patch_set = set(cloud_patches[0]).union(set(snow_patches[0]))
    patches = set(patches)
    

    Now you just subtract patch_set from patches, and you are left with only the values that appear in neither cloud_patches nor snow_patches (a consolidated sketch follows below):

    cleaned_list = list(patches - patch_set)
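
    Putting it together, a self-contained sketch of this set-based approach; the dropna() guard against the NaN padding and the order-preserving variant at the end are illustrative additions, not part of the original answer:

    # Collect every "bad" value from both dataframes into one set.
    # dropna() discards the NaN padding in the shorter column.
    patch_set = set(cloud_patches[0].dropna()) | set(snow_patches[0].dropna())

    # Set difference: average O(1) membership lookups instead of scanning
    # ~60k rows per patch. Converting patches to a set loses its order.
    cleaned_list = list(set(patches) - patch_set)

    # If the original ordering of patches matters, a comprehension over
    # the set keeps it and stays fast:
    cleaned_ordered = [p for p in patches if p not in patch_set]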