Tags: python, scikit-learn, keyword-argument

Wrapper for train_test_split to produce train, validation, and test splits for any number of input arrays


What is the right way to build wrappers around the train_test_split function with *args and **kwargs? To give more context, data science often requires a train-validation-test split, so I thought of building a wrapper like

def train_validate_test_split(*dataframe, **options):
   train, test = train_test_split(dataframe, options)
   train, val = train_test_split(train, options)
   return train, val, test

that gives a train, validation, test split of the dataset from one-liner calls. However, executing

train_validate_test_split(dataframe_1, test_size = 0.2)

leads to a catastrophic failure. I guess I am messing up *args and **kwargs quite spectacularly, but I still have trouble wrapping my head around them. Any suggestion would be greatly appreciated.
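To see what the wrapper in the question actually forwards, here is a minimal sketch that uses a hypothetical stand-in, show_args (not a scikit-learn function), which just reports the positional arrays and keyword options it receives:

```python
def show_args(*arrays, **options):
    # Hypothetical stand-in for train_test_split: report what it was given.
    return arrays, options

def broken_wrapper(*dataframe, **options):
    # Mirrors the question: forwards the tuple and the dict as two positionals.
    return show_args(dataframe, options)

def fixed_wrapper(*dataframe, **options):
    # Unpacks the tuple with * and the dict with ** before forwarding.
    return show_args(*dataframe, **options)

df = [1, 2, 3]  # any object works here; a list stands in for a DataFrame
print(broken_wrapper(df, test_size=0.2))
# ((([1, 2, 3],), {'test_size': 0.2}), {})
print(fixed_wrapper(df, test_size=0.2))
# (([1, 2, 3],), {'test_size': 0.2})
```

Without the * and ** unpacking, the callee sees two "arrays" (a 1-tuple holding the data, and the options dict) and receives no test_size keyword at all.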


Solution

  • The function signature is:

    train_test_split(*arrays, **options)
    

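    meaning it accepts any number of positional arrays and any number of keyword options. (Recent scikit-learn releases spell the options out as keyword-only parameters such as test_size, train_size, random_state, shuffle and stratify, but calls look the same.) Your wrapper passes the tuple dataframe and the dict options as two positional arguments instead of unpacking them with *dataframe and **options. To return train, val, test as you wish, one would proceed as follows: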

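    import numpy as np
    import pandas as pd
    from sklearn.model_selection import train_test_split

    df = pd.DataFrame({"x": np.random.randn(1000), "y": np.random.randn(1000)})

    def train_validate_test_split(dataframe, **options):
        # Unpack the keyword options with ** so train_test_split receives them by name.
        train, test = train_test_split(dataframe, **options)
        # The same options are reused, so this split's sizes are fractions of
        # the remaining train set, not of the full data.
        train, val = train_test_split(train, **options)
        return train, val, test

    a, b, c = train_validate_test_split(df, train_size=.25)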
    

    EDIT

    To accept either one or two input arrays (e.g. X alone, or X and y), use:

    from sklearn.model_selection import train_test_split

    def train_val_test_split(*arrays, **options):

        if len(arrays) == 1:
            # One array: split into train/test, then split train into train/val.
            X_train, X_test = train_test_split(*arrays, **options)
            X_train, X_val = train_test_split(X_train, **options)
            print("Unpack to X_train, X_val, X_test")
            return X_train, X_val, X_test

        elif len(arrays) == 2:
            # Two arrays (X, y): the same rows are selected from both.
            X_train, X_test, y_train, y_test = train_test_split(*arrays, **options)
            X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, **options)
            print("Unpack to X_train, X_val, X_test, y_train, y_val, y_test")
            return X_train, X_val, X_test, y_train, y_val, y_test

        else:
            raise ValueError("Only implemented for 1 or 2 arrays. "
                             f"You provided {len(arrays)} arrays")
    

    or for any number of input arrays:

    import numpy as np
    from sklearn.model_selection import train_test_split

    y = np.random.randn(1000)

    def train_val_test_split(*arrays, **options):
        '''
        inputs:
            arrays - any number of arrays to split (all of the same length),
        outputs:
            tuple, grouped per input array:
            arr1_train, arr1_val, arr1_test, arr2_train, arr2_val, arr2_test, ...
        '''
        # train_test_split returns [x1_train, x1_test, x2_train, x2_test, ...]
        out = train_test_split(*arrays, **options)
        train = out[0::2]  # x1_train, x2_train, ...
        test = out[1::2]   # x1_test, x2_test, ...
        train_val = train_test_split(*train, **options)
        train = train_val[0::2]
        val = train_val[1::2]
        print(f"Unpack to {len(arrays) * 3} arrays: arr1_train, arr1_val, arr1_test, arr2_train, ...")
        return tuple(split for tuple_ in zip(train, val, test) for split in tuple_)

    x = train_val_test_split(y, y, y)

    for item in x:
        print(item.shape, end=", ")

    Unpack to 9 arrays: arr1_train, arr1_val, arr1_test, arr2_train, ...
    (562,), (188,), (250,), (562,), (188,), (250,), (562,), (188,), (250,),
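Reusing the same **options for both calls means the second split's sizes are relative to the remaining train set: with test_size=0.2, the validation set ends up as 0.2 of 0.8, i.e. 16% of the data. A minimal sketch, assuming you instead want val_size and test_size to be fractions of the full data, rescales the second split. The helper name train_val_test_split_sized and its val_size parameter are hypothetical, not part of scikit-learn; train_size should not be passed through **options here.

```python
import numpy as np
from sklearn.model_selection import train_test_split

def train_val_test_split_sized(*arrays, val_size=0.2, test_size=0.2, **options):
    """Hypothetical helper: val_size and test_size are fractions of the FULL data.

    Returns arr1_train, arr1_val, arr1_test, arr2_train, ... (grouped per array).
    Extra keyword arguments (e.g. random_state, shuffle) are passed to both splits;
    stratify is NOT handled, since its labels would not match the second split.
    """
    if val_size <= 0 or test_size <= 0 or val_size + test_size >= 1:
        raise ValueError("val_size and test_size must be positive and sum to less than 1")
    # First split: hold out the test set.
    out = train_test_split(*arrays, test_size=test_size, **options)
    train, test = out[0::2], out[1::2]
    # Second split: rescale val_size to a fraction of what remains after the test split.
    rel_val = val_size / (1 - test_size)
    out = train_test_split(*train, test_size=rel_val, **options)
    train, val = out[0::2], out[1::2]
    return tuple(split for group in zip(train, val, test) for split in group)

X = np.random.randn(1000, 3)
y = np.random.randn(1000)
X_tr, X_val, X_te, y_tr, y_val, y_te = train_val_test_split_sized(X, y, random_state=0)
print(X_tr.shape, X_val.shape, X_te.shape)
```

The rescaling follows from val_size = rel_val * (1 - test_size), so with the defaults the three parts come out at roughly 60/20/20 of the rows, and the same rows are selected from every input array.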