python pandas dataframe split training-data

Split data so that each categorical value is in either the train set or the test set (Python)


I have a dataset (df) as follows:

Company  Col1  Col2  Output
AB         10    20       1
AB         20    22       1
AB         14    12       0
XZ         33    22       1
XZ         43    62       0

I want to train_test_split the data so that if a company is in the test set, it does not appear in the training set at all. For example, if the first row (AB, 10, 20, 1) is in the test set, the second row (AB, 20, 22, 1) should also be in the test set. I know that stratify=df[["Company"]] would do the exact opposite of what I want. Is there a built-in function that does this?

P.S. The Company column is a string.
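For reference, scikit-learn's `sklearn.model_selection` module has group-aware splitters (`GroupShuffleSplit`, `GroupKFold`) that take a `groups` argument for exactly this requirement. The sketch below avoids that dependency and uses only pandas and NumPy: it shuffles the unique companies and sends a fraction of them, with all of their rows, to the test set. The helper name `split_by_group` and its parameters are hypothetical, chosen for illustration; note that `test_frac` is a fraction of companies, not of rows.

```python
import numpy as np
import pandas as pd

# Example data from the question
df = pd.DataFrame({
    "Company": ["AB", "AB", "AB", "XZ", "XZ"],
    "Col1": [10, 20, 14, 33, 43],
    "Col2": [20, 22, 12, 22, 62],
    "Output": [1, 1, 0, 1, 0],
})


def split_by_group(data, group_col, test_frac=0.2, seed=0):
    """Hypothetical helper: every row of a group goes either to train or to test.

    test_frac is the fraction of *groups* (not rows) sent to the test set.
    """
    rng = np.random.default_rng(seed)
    # Random order of the distinct group labels (e.g. company names)
    groups = rng.permutation(data[group_col].unique())
    # Always put at least one group in the test set
    n_test = max(1, int(round(len(groups) * test_frac)))
    in_test = data[group_col].isin(groups[:n_test])
    return data[~in_test], data[in_test]


train, test = split_by_group(df, "Company", test_frac=0.5, seed=42)
print(train)
print(test)
```

Because whole groups are moved, the share of rows in the test set can differ noticeably from `test_frac` when companies have very different row counts.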


Solution

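  • This might be a little verbose and is not a generic function, but this approach might work for you. It adds whole companies to the training set, one at a time, until roughly `frac` of the rows are covered; all remaining companies go to the test set: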

    # Number of rows per company (groupby sorts companies alphabetically)
    counts = df.groupby("Company").size()
    frac = 0.8  # Fraction of rows for the training set; only approximated
    train_count = counts.sum() * frac

    train_companies = []
    c = 0  # rows assigned to the training set so far
    i = 0  # position in counts
    # Add whole companies until the training set reaches the target size
    while c < train_count and i < len(counts):
        train_companies.append(counts.index[i])
        c += counts.iloc[i]
        i += 1

    # Every row of a company lands entirely in train or entirely in test
    df_train = df[df["Company"].isin(train_companies)]
    df_test = df[~df["Company"].isin(train_companies)]
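The same greedy idea can be wrapped in a reusable function. This is a sketch, not the answerer's code: the name `greedy_group_split` and the optional `seed` parameter are hypothetical. Since `groupby` sorts the keys, the loop above always fills the training set alphabetically; passing a `seed` shuffles the per-company counts first with `Series.sample(frac=1, ...)` so the split is random.

```python
import pandas as pd

# Example data from the question
df = pd.DataFrame({
    "Company": ["AB", "AB", "AB", "XZ", "XZ"],
    "Col1": [10, 20, 14, 33, 43],
    "Col2": [20, 22, 12, 22, 62],
    "Output": [1, 1, 0, 1, 0],
})


def greedy_group_split(data, group_col, frac=0.8, seed=None):
    """Hypothetical helper: add whole groups to train until ~frac of the rows are covered."""
    counts = data.groupby(group_col).size()  # rows per group, sorted by key
    if seed is not None:
        # Shuffle the groups so the split is not alphabetical
        counts = counts.sample(frac=1, random_state=seed)
    target = len(data) * frac
    train_groups = []
    covered = 0
    for group, n in counts.items():
        if covered >= target:
            break
        train_groups.append(group)
        covered += n
    in_train = data[group_col].isin(train_groups)
    return data[in_train], data[~in_train]


train, test = greedy_group_split(df, "Company", frac=0.6)
print(train)
print(test)
```

With only two companies the result is coarse: `frac=0.6` puts AB (3 rows) in train and XZ (2 rows) in test, while `frac=0.8` adds both companies to train and leaves the test set empty, so it is worth checking the resulting sizes.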