Tags: python, machine-learning, split, training-data

Split into training set and test set with specific attribute values for rows


My input file has the following form:

gold,Attribute1,Attribute2
T,1,1
T,1,2
T,1,1
N,1,2
N,2,1
T,2,1
T,2,2
N,2,2
T,3,1
N,3,2
N,3,1
T,3,2
N,3,3
N,3,3

I am trying to predict the first column from the second and third columns. I would like to split this data randomly into a training set and a test set such that all rows sharing a given combination of <Attribute1, Attribute2> values fall entirely in the training set or entirely in the test set. For example, all rows with values <1,1>, <1,2>, <2,1> would go to the training set and all rows with values <2,2>, <3,1>, <3,2>, <3,3> to the test set. That assignment is only an example; the actual assignment of combinations should be random. How can I make such a split?


Solution

  • A simple way of splitting this is to filter on a condition rather than use a pre-defined method.

    Code:

    import pandas as pd

    # read_csv already returns a DataFrame
    df = pd.read_csv('test.csv')

    print(df.head())
    print(df.describe())
    print(type(df['Attribute1']))

    # Training set: rows where both attributes are less than or equal to 2
    df_Condition1 = df[df['Attribute1'] <= 2]
    Train_Set = df_Condition1[df_Condition1['Attribute2'] <= 2]

    # Test set: the remaining rows. Masking blanks the training rows to NaN
    # (which turns the integer columns into floats); dropna() then removes them.
    Test_Set = df[df.isin(Train_Set) == False]
    Test_Set = Test_Set.dropna()

    print(Train_Set)
    print(Test_Set)
    

    Output:

      gold  Attribute1  Attribute2
    0    T           1           1
    1    T           1           2
    2    T           1           1
    3    N           1           2
    4    N           2           1

           Attribute1  Attribute2
    count   14.000000   14.000000
    mean     2.142857    1.714286
    std      0.864438    0.726273
    min      1.000000    1.000000
    25%      1.250000    1.000000
    50%      2.000000    2.000000
    75%      3.000000    2.000000
    max      3.000000    3.000000
    <class 'pandas.core.series.Series'>

      gold  Attribute1  Attribute2
    0    T           1           1
    1    T           1           2
    2    T           1           1
    3    N           1           2
    4    N           2           1
    5    T           2           1
    6    T           2           2
    7    N           2           2

       gold  Attribute1  Attribute2
    8     T         3.0         1.0
    9     N         3.0         2.0
    10    N         3.0         1.0
    11    T         3.0         2.0
    12    N         3.0         3.0
    13    N         3.0         3.0
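
The condition-based split above is fixed, whereas the question asks for a random assignment of whole <Attribute1, Attribute2> combinations. One possible way to do that, as a minimal sketch rather than a definitive implementation: collect the distinct combinations, sample a fraction of them for the test set, then send every row to the side its combination landed on. The helper name `split_by_combination` and the parameters `test_fraction` and `seed` are hypothetical, introduced here for illustration only. The data is inlined so the sketch runs without `test.csv`.

```python
import io

import pandas as pd

# Inline copy of the question's data so the sketch runs without test.csv.
CSV = """gold,Attribute1,Attribute2
T,1,1
T,1,2
T,1,1
N,1,2
N,2,1
T,2,1
T,2,2
N,2,2
T,3,1
N,3,2
N,3,1
T,3,2
N,3,3
N,3,3
"""


def split_by_combination(df, keys, test_fraction=0.5, seed=None):
    """Hypothetical helper: randomly send each distinct combination of
    `keys` (and therefore all of its rows) to either train or test."""
    # One row per distinct combination of the key columns.
    combos = df[keys].drop_duplicates()
    # Randomly pick a fraction of the combinations for the test set.
    test_combos = combos.sample(frac=test_fraction, random_state=seed)
    # Left-merge the chosen combinations back onto the rows; rows whose
    # combination was not picked get NaN in the marker column.
    flagged = df[keys].merge(
        test_combos.assign(_in_test=True), on=keys, how="left"
    )
    is_test = flagged["_in_test"].notna().to_numpy()
    # Boolean masks keep the original index and the integer dtypes.
    return df[~is_test], df[is_test]


df = pd.read_csv(io.StringIO(CSV))
train_set, test_set = split_by_combination(
    df, ["Attribute1", "Attribute2"], test_fraction=0.5, seed=0
)

print(train_set)
print(test_set)
```

Because the sampling is done over combinations rather than rows, the share of rows that ends up in the test set only approximates `test_fraction` when the combinations have different sizes. Passing a different `seed` (or `None`) gives a different random assignment.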