python pandas tensorflow keras tensorflow-datasets

tf.int32 being interpreted as tf.string when constructing a TensorFlow dataset


I have a Pandas dataframe whose data I intend to export to a TensorFlow dataset. The dataframe has 4 columns: two hold lists of strings and the other two hold lists of integers. For now, the most important columns are input_ids and attention_mask, which make up the model's input data.

train_input_ids = train_df["input_ids"].values.tolist()
train_attention_mask = train_df["attention_mask"].values.tolist()

As the head() method shows, these columns store lists of ints. print(train_df["input_ids"].head(3)) returns the following:

0    [101, 24918, 7821, 5983, 46106, 21905, 10789...
1    [101, 33198, 10173, 14657, 25287, 55610, 10789...
2    [101, 10109, 19217, 34768, 16294, 17953, 51733...
Name: input_ids, dtype: object
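
Note that dtype: object only says the cells hold Python objects, and a string such as "[101, 24918, 7821]" prints the same way as a real list in head(). A minimal sketch of a check on the element types follows; the sample values and the helper name element_types are hypothetical and are not part of the original code:

```python
from collections import Counter

# Hypothetical sample values: the first cell is a real list of ints,
# the second is a string that merely looks like a list.
sample_input_ids = [[101, 24918, 7821], "[101, 33198, 10173]"]


def element_types(values):
    """Count the Python type name of each cell in a list of cells."""
    return Counter(type(v).__name__ for v in values)


type_counts = element_types(sample_input_ids)
print(type_counts)
```

With the real data, the same helper could be called on train_input_ids (the plain Python list produced by .values.tolist()); any count under str would suggest the cells need parsing before being handed to TensorFlow.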

One of the other columns, codes, stores a list of strings containing codes to be one-hot encoded:

from sklearn.preprocessing import LabelEncoder
label_encoder = LabelEncoder()
y_train = label_encoder.fit_transform(train_df["codes"].values)

I created the dataset as follows:

    train_dataset = (
        tf.data.Dataset
        .from_tensor_slices((
            (tf.convert_to_tensor(train_input_ids), tf.convert_to_tensor(train_attention_mask)),
            y_train,
        ))
        .repeat()
        .shuffle(2048)
        .batch(BATCH_SIZE)
        .prefetch(BATCH_SIZE * 2)
    )

However, after building the dataset from tensor slices and printing it, I found that both train_input_ids and train_attention_mask are interpreted as tf.string:

<PrefetchDataset shapes: (((None,), (None,)), (None,)), types: ((tf.string, tf.string), tf.int64)>

As far as I know, they should be inferred as (tf.int32, tf.int32), because the dataframe contains lists of ints rather than lists of strings (this is also how the model's input layers are defined to accept them). What am I missing?


Solution

  • I eventually found that my issue was very similar to the one described in this question: the list-valued cells had apparently been read from the .tsv file as strings that only look like lists, which would explain why TensorFlow inferred tf.string. Since all of the columns of this dataframe store lists, I adapted one of the proposed solutions by defining a function to be called after loading the dataframe from its .tsv file:

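    from ast import literal_eval

    def _remove_extra_quotes(df):
        # Every column is expected to hold the string form of a list
        # (as read from the .tsv file); literal_eval parses each cell
        # back into a real Python list.
        for column_name in df:
            df[column_name] = df[column_name].apply(lambda x: literal_eval(str(x)))
        return df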
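
To illustrate why this helps, here is a minimal sketch of the round trip using only the standard library csv module instead of pandas (an assumption made to keep the example self-contained). The sample row is hypothetical. It shows a list being written to a tab-separated buffer, coming back as a string, and being restored with literal_eval:

```python
import csv
import io
from ast import literal_eval

# Hypothetical cell value, similar in shape to an input_ids entry.
original = [101, 24918, 7821]

# Write the cell to an in-memory TSV; csv.writer stores str(original).
buffer = io.StringIO()
csv.writer(buffer, delimiter="\t").writerow([original])

# Read it back: the cell is now a string, not a list.
buffer.seek(0)
cell = next(csv.reader(buffer, delimiter="\t"))[0]
print(type(cell).__name__, repr(cell))

# literal_eval parses the string back into a list of ints.
restored = literal_eval(cell)
print(type(restored).__name__, restored)
```

Once the cells are real lists of ints again, tf.convert_to_tensor should no longer see Python strings in those columns, provided all rows have the same length.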