I want to apply leave-one-pair-out cross-validation (LPOCV) to a binary classification problem. Each pair of samples held out as the test set should contain exactly one sample from each class.
My code is:
from sklearn.model_selection import LeavePOut
import numpy as np

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([0, 1, 1, 0, 0])

lpo = LeavePOut(2)
print(lpo.get_n_splits(X))
print(lpo)

# Enumerate every possible pair of samples as the test set
for train_index, test_index in lpo.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
The output is:
LeavePOut(p=2)
TRAIN: [2 3 4] TEST: [0 1]
TRAIN: [1 3 4] TEST: [0 2]
TRAIN: [1 2 4] TEST: [0 3]
TRAIN: [1 2 3] TEST: [0 4]
TRAIN: [0 3 4] TEST: [1 2]
TRAIN: [0 2 4] TEST: [1 3]
TRAIN: [0 2 3] TEST: [1 4]
TRAIN: [0 1 4] TEST: [2 3]
TRAIN: [0 1 3] TEST: [2 4]
TRAIN: [0 1 2] TEST: [3 4]
The test pairs [0 3] and [0 4] both belong to class 0. Is there any way to split X so that every test pair contains one sample from class 0 and one from class 1?
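For this y, fewer folds qualify than LeavePOut(2) produces: with n0 samples of class 0 and n1 of class 1, only n0 * n1 of the C(n, 2) possible pairs mix the two classes. The sketch below, which uses only the standard library and NumPy, counts both and lists the same-class pairs that would have to be dropped.

```python
from itertools import combinations
from math import comb

import numpy as np

y = np.array([0, 1, 1, 0, 0])

n0 = int(np.sum(y == 0))  # number of class-0 samples
n1 = int(np.sum(y == 1))  # number of class-1 samples

total_pairs = comb(len(y), 2)  # number of folds LeavePOut(2) enumerates
mixed_pairs = n0 * n1          # pairs with one sample from each class

# Pairs whose two samples share a label: these are the unwanted folds
same_class = [
    pair for pair in combinations(range(len(y)), 2)
    if y[pair[0]] == y[pair[1]]
]

print(total_pairs, mixed_pairs, same_class)
```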
You could keep LeavePOut and simply skip every fold whose test set contains samples from only one class (e.g. only class 0):
from sklearn.model_selection import LeavePOut
import numpy as np

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([0, 1, 1, 0, 0])

lpo = LeavePOut(2)
print(lpo.get_n_splits(X))
print(lpo)

for train_index, test_index in lpo.split(X):
    # Skip folds whose test pair contains only one class
    if len(np.unique(y[test_index])) < 2:
        continue
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
Then the output is:
LeavePOut(p=2)
TRAIN: [2 3 4] TEST: [0 1]
TRAIN: [1 3 4] TEST: [0 2]
TRAIN: [0 2 4] TEST: [1 3]
TRAIN: [0 2 3] TEST: [1 4]
TRAIN: [0 1 4] TEST: [2 3]
TRAIN: [0 1 3] TEST: [2 4]
The removed folds are those whose test set contained only one class, namely the pairs [0 3], [0 4], [1 2] and [3 4].
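Instead of generating all C(n, 2) pairs and filtering, the mixed pairs can also be built directly as the Cartesian product of the class-0 and class-1 indices. The sketch below assumes labels are exactly 0 and 1; `leave_one_pair_out` is a hypothetical helper, not part of scikit-learn. Because scikit-learn's `cv` parameter accepts an iterable of (train, test) index arrays, the resulting list can be passed straight to `cross_val_score` (LogisticRegression is used here only as an example estimator).

```python
from itertools import product

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score


def leave_one_pair_out(y):
    """Hypothetical helper: yield (train, test) index arrays in which the
    test set holds one class-0 sample and one class-1 sample.

    Assumes binary labels 0 and 1.
    """
    y = np.asarray(y)
    idx0 = np.flatnonzero(y == 0)  # indices of class-0 samples
    idx1 = np.flatnonzero(y == 1)  # indices of class-1 samples
    all_idx = np.arange(len(y))
    # Sort the pairs so the folds come out in the same order as LeavePOut
    pairs = sorted((min(i, j), max(i, j)) for i, j in product(idx0, idx1))
    for i, j in pairs:
        test = np.array([i, j])
        train = np.setdiff1d(all_idx, test)  # every other sample
        yield train, test


X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([0, 1, 1, 0, 0])

# Materialize the splits so they can be reused (a generator is consumed once)
splits = list(leave_one_pair_out(y))
for train_index, test_index in splits:
    print("TRAIN:", train_index, "TEST:", test_index)

scores = cross_val_score(LogisticRegression(), X, y, cv=splits)
print(scores.mean())
```

This yields the same six folds as the filtered LeavePOut loop, without ever creating the same-class pairs.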