I have a DataFrame named dataframe with one column called "label", which represents a binary feature (0 or 1).
The dataframe is imbalanced, with more 0s than 1s. To build a good estimator, I want to split the data into training and test subsets in which the training subset is well balanced. I could use a resampling algorithm such as SMOTE, but I decided to go with the following strategy:
First, I selected all rows of dataframe with label 1 and took a random 80% subset of them:
train_class1=dataframe[dataframe["label"]==1].iloc[np.random.randint(0, len(dataframe[dataframe["label"]==1]), len(dataframe[dataframe["label"]==1])*80//100)]
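This line already contains a likely problem: np.random.randint draws positions with replacement, so the same row can be selected more than once. The minimal sketch below contrasts it with DataFrame.sample, which draws without replacement by default. The 20-row frame and its "feature" column are hypothetical toy data, not the real dataset.

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame standing in for the real data: 12 rows of label 1, 8 of label 0.
dataframe = pd.DataFrame({"feature": np.arange(20), "label": [1] * 12 + [0] * 8})

class1 = dataframe[dataframe["label"] == 1]
n_train1 = len(class1) * 80 // 100  # same size formula as in the question: 9 here

# Question's approach: random positions drawn WITH replacement, so repeats are possible.
np.random.seed(0)
with_replacement = class1.iloc[np.random.randint(0, len(class1), n_train1)]
print("repeated rows (randint):", with_replacement.index.duplicated().sum())

# DataFrame.sample draws WITHOUT replacement by default (replace=False).
train_class1 = class1.sample(n=n_train1, random_state=0)
print("repeated rows (sample):", train_class1.index.duplicated().sum())  # always 0
```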
Then, from the rows with label 0, I took a random subset of the same size as train_class1 and called it train_class0:
train_class0=dataframe[dataframe["label"]==0].iloc[np.random.randint(0, len(dataframe[dataframe["label"]==0]), len(dataframe[dataframe["label"]==1])*80//100)]
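If class 0 has fewer rows than the number requested here, randint can only fill the request by repeating rows (pigeonhole principle). This sketch, again on hypothetical toy data, shows that effect and shows that sample() raises a ValueError instead of silently duplicating. Capping the size at the smaller class is one possible workaround, assumed here for illustration.

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame: label 0 has only 5 rows.
dataframe = pd.DataFrame({"feature": np.arange(15), "label": [1] * 10 + [0] * 5})
class0 = dataframe[dataframe["label"] == 0]
n_requested = 8  # more rows than class 0 actually has

# randint never complains: 8 draws from 5 rows must repeat at least 3 positions.
repeated = class0.iloc[np.random.randint(0, len(class0), n_requested)]
print(len(repeated), repeated.index.duplicated().sum() >= 3)  # 8 True

# sample() without replacement refuses to draw more rows than exist.
try:
    class0.sample(n=n_requested)
except ValueError as err:
    print("ValueError:", err)

# One workaround: never request more rows than the smaller class has.
n_per_class = min(len(class0), n_requested)
train_class0 = class0.sample(n=n_per_class, random_state=0)
```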
I then planned to concatenate both dataframes row-wise to form my training subset:
train_class=pd.concat([train_class1,train_class0])
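A sketch of the concatenation step with hypothetical stand-in frames. pd.concat keeps the original index labels, which is what makes it possible to find the remaining test rows later. The shuffle with sample(frac=1) is an optional extra, not part of the original code.

```python
import pandas as pd

# Hypothetical stand-ins for train_class1 and train_class0 (3 rows each, original labels kept).
train_class1 = pd.DataFrame({"feature": [10, 11, 12], "label": [1, 1, 1]}, index=[0, 1, 2])
train_class0 = pd.DataFrame({"feature": [20, 21, 22], "label": [0, 0, 0]}, index=[5, 6, 7])

# Row-wise concatenation; the index labels from the original dataframe are preserved.
train_class = pd.concat([train_class1, train_class0])

# Optional: shuffle so the two classes are interleaved instead of stacked in blocks.
train_class = train_class.sample(frac=1, random_state=0)

print(train_class["label"].value_counts())  # 3 rows of each label
```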
For the test subset, I want the rest of the original dataframe, that is, all rows of dataframe that don't belong to train_class. I tried the following:
test_class = pd.concat([dataframe, train_class]).drop_duplicates()
which concatenates the original dataframe with train_class and removes the duplicate rows.
This looked fine to me at this point, but when I check the shapes of dataframe, train_class, and test_class, I get:
dataframe.shape=(257673, 208)
train_class.shape=(263476, 208)
test_class.shape=(257673, 208)
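which is clearly contradictory: the training subset has more rows than the whole dataframe, and the test subset has exactly as many rows as the whole dataframe.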
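The test_class shape follows from how drop_duplicates works: with the default keep="first", the copies coming from train_class are the ones dropped, while the original rows from dataframe survive. The sketch below reproduces this on a hypothetical four-row frame and shows an index-based alternative, dataframe.drop(train_class.index), which assumes dataframe has a unique index. drop_duplicates(keep=False) would also remove the training rows, but it would additionally drop any rows that are genuinely duplicated inside dataframe.

```python
import pandas as pd

# Hypothetical four-row frame; rows 0 and 2 go to training.
dataframe = pd.DataFrame({"feature": [1, 2, 3, 4], "label": [1, 1, 0, 0]})
train_class = dataframe.loc[[0, 2]]

# keep="first" keeps the first copy of each duplicated row, i.e. the one from dataframe.
test_wrong = pd.concat([dataframe, train_class]).drop_duplicates()
print(len(test_wrong))  # 4: nothing was removed

# Index-based complement: remove the training rows by their index labels.
test_class = dataframe.drop(train_class.index)
print(test_class.index.tolist())  # [1, 3]
```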
What am I doing wrong in the code?
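A training subset larger than the whole dataframe is only possible if some rows were picked more than once. A small diagnostic can confirm this, along with any overlap between the subsets. check_split is a hypothetical helper name, and the toy example mimics randint drawing row 0 twice.

```python
import pandas as pd


def check_split(dataframe: pd.DataFrame, train: pd.DataFrame, test: pd.DataFrame) -> dict:
    """Hypothetical helper: summarize problems that would explain inconsistent shapes."""
    return {
        # Rows selected more than once (what sampling with replacement produces).
        "repeated_train_rows": int(train.index.duplicated().sum()),
        # Training rows whose index label also appears in the test subset.
        "overlap": int(train.index.isin(test.index).sum()),
        # A clean split uses every original row exactly once.
        "sizes_add_up": len(train) + len(test) == len(dataframe),
    }


dataframe = pd.DataFrame({"feature": [1, 2, 3, 4], "label": [1, 1, 0, 0]})
train = dataframe.iloc[[0, 0, 2]]  # row 0 drawn twice, as randint can do
test = pd.concat([dataframe, train]).drop_duplicates()
print(check_split(dataframe, train, test))
# {'repeated_train_rows': 1, 'overlap': 3, 'sizes_add_up': False}
```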
I actually solved the problem. It was in the definitions of train_class1 and train_class0, which I changed to:
train_class1=dataframe[dataframe["label"]==1].sample(len(dataframe[dataframe["label"]==0])*80//100)
train_class0=dataframe[dataframe["label"]==0].sample(len(dataframe[dataframe["label"]==0])*80//100)
by using the built-in pandas method DataFrame.sample(), which draws rows without replacement by default.
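Putting the pieces together, here is a minimal sketch of the whole split, assuming a unique index and a binary label column. Both classes are sampled without replacement to 80% of the smaller class's size (the same *80//100 formula used above), and the test set is built by index. This matters because concat plus drop_duplicates() with the default keep="first" would still return every row of dataframe. balanced_split and its parameters are hypothetical names.

```python
import numpy as np
import pandas as pd


def balanced_split(dataframe, label_col="label", percent=80, random_state=None):
    """Hypothetical helper: balanced training set plus all remaining rows as test set.

    Assumes dataframe has a unique index and a binary label column.
    """
    counts = dataframe[label_col].value_counts()
    # Same size for every class: percent% of the smaller class, without replacement.
    n_per_class = int(counts.min()) * percent // 100
    train = pd.concat(
        [
            dataframe[dataframe[label_col] == value].sample(n=n_per_class, random_state=random_state)
            for value in counts.index
        ]
    )
    # Test set = every row whose index label is not in the training set.
    test = dataframe.drop(train.index)
    return train, test


# Hypothetical toy data: 10 rows of label 1, 5 of label 0.
dataframe = pd.DataFrame({"feature": np.arange(15), "label": [1] * 10 + [0] * 5})
train_class, test_class = balanced_split(dataframe, random_state=0)
print(train_class["label"].value_counts().to_dict())  # 4 rows of each label
print(len(train_class) + len(test_class) == len(dataframe))  # True
```

With the toy data, the training set gets 4 rows of each label and the test set keeps the remaining 7 rows, so the test set stays imbalanced, as intended in the question.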