
How to randomly sample balanced pairs of rows from a Pandas DataFrame


Suppose I have a dataset which contains labels, filenames, and potentially other columns of metadata. The dataset may have as many as 200,000 examples. I've provided a snippet below that simulates this setup.

import pandas as pd
import numpy as np
import IPython.display as ipd

size = 20000
df = []
rng = np.random.default_rng(0)
for i in range(size):
    l = rng.choice(('cat', 'dog', 'mouse', 'bird', 'horse', 'lion', 'rabbit'))
    fp = str(rng.integers(1e5)).zfill(6) + '.jpg'
    df.append((l, fp))
df = pd.DataFrame(df, columns=['label', 'filepath'])
ipd.display(df)
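
For larger simulated datasets (the 200,000-row case), the same setup can be built without a Python loop. This is a minimal vectorized sketch under the same assumptions: seven labels drawn uniformly and zero-padded 6-digit file names. Because the random numbers are consumed in a different order, the rows will not be identical to the loop version above.

```python
import numpy as np
import pandas as pd

# Vectorized version of the simulation: draw all labels and file
# numbers at once instead of appending rows one by one.
size = 20000
rng = np.random.default_rng(0)
labels = np.array(['cat', 'dog', 'mouse', 'bird', 'horse', 'lion', 'rabbit'])

df = pd.DataFrame({
    'label': rng.choice(labels, size=size),
    # zero-padded 6-digit file names such as '012345.jpg'
    'filepath': pd.Series(rng.integers(100_000, size=size)).astype(str).str.zfill(6) + '.jpg',
})
print(df.head())
```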

I would like to efficiently produce N randomly generated pairs of rows, with the condition that the output is balanced between positive and negative pairs. The result would have this layout:

# df_out should have N rows
df_out = pd.DataFrame([], columns=['label_1', 'label_2', 'filepath_1', 'filepath_2'])

Here I define a positive pair as one where label_1 equals label_2, and a negative pair as one where the two labels differ. So the goal is for df_out to contain roughly 50% positive and 50% negative pairs.

The first approach I tried samples 2N rows from the DataFrame and then collapses them into pairs.

N = 20
ii = rng.choice(len(df), size=2 * N, replace=False)  # 2N distinct random row positions
func = lambda x: x.dropna().astype(str).str.cat(sep=',')  # join a group's values into one 'a,b' string
df_out = df.iloc[ii].reset_index(drop=True)  # subsample
df_out = df_out.groupby(df_out.index // 2)  # collapse every two rows into one row
df_out = df_out.agg(func).reset_index(drop=True)  # use `func` to combine rows
for k in df.columns:
    # split 'a,b' back into two columns, e.g. label -> label_1, label_2
    df_out[[f'{k}_1', f'{k}_2']] = df_out[k].str.split(',', expand=True)
    del df_out[k]
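
For comparison, here is a minimal sketch of the same "sample 2N rows, then collapse every two rows into one" step that avoids joining values into comma-separated strings and splitting them again (that round-trip would break if a value ever contained a comma). It places even positions of the sample next to odd positions. The helper name `pair_rows` and the small demo frame are hypothetical, and like the snippet above it ignores whether a pair is positive or negative.

```python
import numpy as np
import pandas as pd

def pair_rows(df: pd.DataFrame, n_pairs: int, seed: int = 0) -> pd.DataFrame:
    """Hypothetical helper: draw 2*n_pairs distinct rows at random and
    place rows 0, 2, 4, ... next to rows 1, 3, 5, ... as *_1 / *_2 columns."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(df), size=2 * n_pairs, replace=False)
    sample = df.iloc[idx].reset_index(drop=True)
    first = sample.iloc[0::2].reset_index(drop=True).add_suffix('_1')
    second = sample.iloc[1::2].reset_index(drop=True).add_suffix('_2')
    # order the columns as label_1, label_2, filepath_1, filepath_2, ...
    cols = [f'{c}_{k}' for c in df.columns for k in (1, 2)]
    return pd.concat([first, second], axis=1)[cols]

# 'a,1.jpg' contains a comma, which the string-join approach would mis-split
demo = pd.DataFrame({
    'label': ['cat', 'dog', 'cat', 'bird', 'dog', 'cat'],
    'filepath': ['a,1.jpg', 'b.jpg', 'c.jpg', 'd.jpg', 'e.jpg', 'f.jpg'],
})
pairs = pair_rows(demo, n_pairs=3)
print(pairs)
```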

So this produces pairs of rows, but it doesn't take positive or negative pairs into account.

# as expected, this fraction is not 50%: with 7 equally likely labels, only about 1/7 of random pairs are positive
print(sum(df_out.eval('label_1==label_2')) / N)
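
The shortfall is predictable. If the two rows of a pair are drawn independently and label k makes up a fraction p_k of the data, the pair is positive with probability Σ_k p_k²; with 7 equally frequent labels that is 7 × (1/7)² = 1/7 ≈ 0.14. The sketch below computes this estimate from the label frequencies. `expected_positive_fraction` is a hypothetical helper, and sampling without replacement (as in the snippet above) changes the value only negligibly for a large DataFrame.

```python
import pandas as pd

def expected_positive_fraction(labels: pd.Series) -> float:
    """Probability that two independently drawn rows share a label: sum_k p_k**2."""
    p = labels.value_counts(normalize=True)
    return float((p ** 2).sum())

# seven equally frequent labels -> 7 * (1/7)**2 = 1/7
labels = pd.Series(['cat', 'dog', 'mouse', 'bird', 'horse', 'lion', 'rabbit'] * 100)
print(round(expected_positive_fraction(labels), 4))  # prints 0.1429
```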

Solution

  • Here is an approach that shuffles the data and groups the rows either:

    • into pairs with identical labels (positives), or
    • into pairs with different labels (negatives).

    It then pivots each set to one row per pair and randomly samples N/2 pairs from each.

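    N = 100  # total number of pairs to return (half positive, half negative)
    
    #### positive pairs
    
    # shuffle, then number the rows within each label (n);
    # ranks 0/1, 2/3, ... share a pair id (n2) and go to column 1 or 2 (col)
    df2 = (df.sample(frac=1)
             .assign(n=lambda d: d.groupby('label').cumcount(),
                     n2=lambda d: d['n'].floordiv(2),
                     col=lambda d: d['n'].mod(2).add(1),
                    )
           )
    
    # keep only complete pairs (drops the last row of odd-sized labels),
    # then put each pair on one row; 'index' records the original row labels
    positives = (df2[df2.duplicated(['label', 'n2'], keep=False)]
     .reset_index()
     .pivot(index=['n2', 'label'], columns='col', values=['label', 'filepath', 'index'])
     .sample(n=N//2)
     .reset_index(drop=True)
    )
    
    # original row labels used in the positive pairs
    positive_idx = positives.pop('index').stack().values
    
    #### negative pairs
    
    # rows sharing the same within-label rank n all have different labels,
    # so pairing consecutive rows inside each rank group always gives a negative pair
    negatives = (
      df.drop(positive_idx)  # remove .drop(...) to allow reusing rows already in positive pairs
        .sample(frac=1)
        .assign(n=lambda d: d.groupby('label').cumcount(),
                g=lambda d: d.groupby('n').cumcount().floordiv(2),
                col=lambda d: d.groupby('n').cumcount().mod(2).add(1),
               )
        .pivot(index=['n', 'g'], columns='col', values=['label', 'filepath'])
        .dropna()  # drop unpaired leftovers from odd-sized rank groups
        .sample(n=N//2)
        .reset_index(drop=True)
    )
    
    out = pd.concat({'positives': positives, 'negatives': negatives})
    
    print(out)
    

    Output:

                   label            filepath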
    col                1       2           1           2
    positives 0     bird    bird  095459.jpg  026617.jpg
              1    horse   horse  062451.jpg  027905.jpg
              2   rabbit  rabbit  067629.jpg  065238.jpg
              3    horse   horse  024818.jpg  026751.jpg
              4      cat     cat  007291.jpg  048994.jpg
    ...              ...     ...         ...         ...
    negatives 45  rabbit     cat  010290.jpg  044769.jpg
              46   mouse    bird  016260.jpg  098423.jpg
              47   mouse   horse  044362.jpg  065754.jpg
              48     dog     cat  085628.jpg  058504.jpg
              49   horse    bird  061706.jpg  025309.jpg
    
    [100 rows x 4 columns]
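
The same two tricks (pair consecutive ranks within a label for positives; pair rows that share a within-label rank for negatives) can be wrapped into one reusable function. This is a sketch of a restructuring, not the answer's exact code: `balanced_pairs` and its helper columns `row`, `grp`, `key` and `side` are hypothetical names; it assumes a unique index, `label` and `filepath` columns, and no existing columns with those helper names. It swaps `pivot` for a self-merge so the result has the flat `label_1`, `label_2`, `filepath_1`, `filepath_2` columns the question asked for, and it passes `random_state` so runs are reproducible.

```python
import numpy as np
import pandas as pd


def balanced_pairs(df: pd.DataFrame, n_pairs: int, seed: int = 0) -> pd.DataFrame:
    """Hypothetical helper: n_pairs // 2 positive and n_pairs // 2 negative pairs
    with flat *_1 / *_2 columns; no row of df is used more than once."""
    n_each = n_pairs // 2
    cols = [f'{c}_{k}' for c in df.columns for k in (1, 2)]

    def self_merge(frame: pd.DataFrame) -> pd.DataFrame:
        # rows with side 0 and side 1 that share (grp, key) form one pair;
        # the inner join drops unpaired leftovers
        first = frame[frame['side'] == 0]
        second = frame[frame['side'] == 1]
        return first.merge(second, on=['grp', 'key'], suffixes=('_1', '_2'))

    # positives: shuffle, rank rows within each label, pair ranks (0, 1), (2, 3), ...
    s = df.sample(frac=1, random_state=seed)
    rank = s.groupby('label').cumcount()
    s = s.assign(row=s.index, grp=s['label'], key=rank // 2, side=rank % 2)
    pos = self_merge(s).sample(n=n_each, random_state=seed).reset_index(drop=True)

    # negatives: drop rows already used, then pair rows that share a
    # within-label rank; such rows always carry different labels
    used = np.concatenate([pos['row_1'].to_numpy(), pos['row_2'].to_numpy()])
    t = df.drop(index=used).sample(frac=1, random_state=seed + 1)
    rank = t.groupby('label').cumcount()
    order = t.groupby(rank).cumcount()  # position within each rank group
    t = t.assign(row=t.index, grp=rank, key=order // 2, side=order % 2)
    neg = self_merge(t).sample(n=n_each, random_state=seed).reset_index(drop=True)

    return pd.concat({'positives': pos[cols], 'negatives': neg[cols]})


rng = np.random.default_rng(0)
demo = pd.DataFrame({
    'label': rng.choice(['cat', 'dog', 'mouse', 'bird'], size=1000),
    'filepath': [f'{i:06d}.jpg' for i in range(1000)],
})
out = balanced_pairs(demo, n_pairs=100)
print(out.head())
```

Note that `DataFrame.sample(n=...)` raises a ValueError when fewer than n_pairs // 2 pairs of either kind can be formed, for example when a single label dominates the data and few negative pairs remain.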