I am trying to build a multi-label (multi-output) regression model with Keras. The labels contain a number of NaN values, i.e. not every instance was tested for every label. Here is a sample of my code:
import numpy as np
import pandas as pd
from sklearn.datasets import make_multilabel_classification
# Synthetic data: 1000 samples with 40 features, plus 10 integer targets in
# [0, 100) of which roughly half are replaced by NaN to simulate labels that
# were never measured.
X, _ = make_multilabel_classification(n_samples = 1000, sparse = True, n_features = 40,
                                      return_indicator = 'sparse', allow_unlabeled = False)
# With sparse=True, X is a scipy sparse matrix. A plain Dense input (and
# Keras' validation_split) may not accept scipy sparse input depending on the
# Keras version, so densify it here; 1000 x 40 is small.
X = X.toarray()
y = pd.DataFrame(np.random.randint(0, 100, (1000, 10)))
na_ = np.random.choice([True, False], size=y.shape)
# Guarantee every row keeps at least one label: for rows that would be
# entirely NaN, un-mask the last column.
na_[na_.all(1),-1] = 0
y = y.mask(na_)
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.callbacks import ReduceLROnPlateau
from keras import regularizers
from keras.optimizers import RMSprop, Adam, SGD
# `learning_rate` replaces the deprecated `lr` keyword (a warning since
# Keras 2.3, and rejected by Keras 3).
sgd = SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
model = Sequential()
model.add(Dense(100, input_dim=40))
model.add(Activation('relu'))
model.add(Dense(10))  # one linear output per label
# NOTE: because y contains NaN, the built-in 'mean_squared_error' loss (and
# the 'mae' metric) evaluate to NaN; a loss that masks NaN targets is needed.
model.compile(loss='mean_squared_error', optimizer=sgd, metrics=['mae'])
hist = model.fit(X_train, y_train, epochs=500, verbose=1, validation_split=0.2)
scores = model.evaluate(X_test, y_test)
The multi-label target to be predicted (y) looks like this (first five rows; columns 0–9 are the labels):
0 NaN NaN 4.0 NaN NaN NaN NaN 35.0 NaN 98.0
1 NaN NaN 70.0 17.0 NaN NaN 4.0 69.0 33.0 NaN
2 14.0 NaN NaN 65.0 NaN NaN NaN 50.0 64.0 55.0
3 78.0 NaN 2.0 NaN 44.0 79.0 67.0 43.0 3.0 64.0
4 NaN 54.0 NaN NaN NaN 67.0 18.0 39.0 3.0 41.0
I need to modify the loss function to use a mask that excludes all NaN labels, but I have been unable to implement this. How can I do it?
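You can mask the labels with `tf.where()` inside a custom loss function: wherever the target is NaN, both the target and the prediction are replaced by zero, so those positions contribute nothing to the loss. Then pass the function to `model.compile(loss=mse_nan, ...)`. For example: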
import tensorflow as tf
from keras import backend as K
def mse_nan(y_true, y_pred):
    """Return the per-sample mean squared error, ignoring NaN targets.

    Positions where ``y_true`` is NaN are treated as missing labels. There,
    both the target and the prediction are replaced by zero, so the squared
    error is exactly 0 instead of NaN. The average is taken over the valid
    (non-NaN) labels of each sample only, so samples with many missing labels
    are not down-weighted.

    Args:
        y_true: Target tensor; NaN marks a missing label.
        y_pred: Prediction tensor with the same shape as ``y_true``.

    Returns:
        Tensor of per-sample losses (reduced over the last axis). A sample
        with no valid labels gets a loss of 0.
    """
    # Compare and subtract in the prediction's dtype (NaN survives the cast).
    y_true = tf.cast(y_true, y_pred.dtype)
    # True where a label is present; computed once and reused below.
    valid = tf.math.logical_not(tf.math.is_nan(y_true))
    masked_true = tf.where(valid, y_true, tf.zeros_like(y_true))
    masked_pred = tf.where(valid, y_pred, tf.zeros_like(y_pred))
    squared_error = tf.square(masked_pred - masked_true)
    # Normalize by the number of valid labels, not the total label count;
    # the maximum guards against 0 / 0 for an all-NaN row.
    n_valid = tf.reduce_sum(tf.cast(valid, squared_error.dtype), axis=-1)
    return tf.reduce_sum(squared_error, axis=-1) / tf.maximum(n_valid, 1.0)