
What is the best way to check correct dtypes in a pandas dataframe as part of testing?


Before pre-processing and training a model on some data, I want to check that each feature (each column) of a dataframe has the correct data type. For example, if a dataframe has columns col1, col2, col3, they should have types int, float, string respectively, as I have defined them (col1 can't be of type string; the order matters).

What is the best way to do this if

  1. The columns have various types - int, float, timestamp, string
  2. There are too many columns (>500) to manually write out / label each column data type

Something like

types = df.dtypes # returns a pandas series
if types != correct_types:
    raise TypeError("Some of the columns do not have the correct type")

Where correct_types are the known data types of each column - these would need to be in the same order as types to ensure each column type is correctly matched. It would also be good to know which column is throwing the error (so maybe a for loop over the columns is more appropriate?)

Is there any way to achieve this, and if so what is the best way to achieve this? Maybe I am looking at the issue the wrong way - more generally, how do I ensure that the columns of df are of the correct data type as I have defined them?


Solution

  • You can use pd.DataFrame.dtypes to return a series mapping column name to data type:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame([[1, True, 'dsfasd', 51.314],
                       [51, False, '56345', 56.1234]],
                      columns=['col1', 'col2', 'col3', 'col4'])
    
    res = df.dtypes
    
    print(res)
    
    col1      int64
    col2       bool
    col3     object
    col4    float64
    dtype: object
    

    The values of this series are dtype objects:

    print(res.iloc[0])
    
    dtype('int64')
    

    As a series, you can filter by index or by value. For example, to filter for int64 type:

    print(res[res == np.dtype('int64')])
    
    col1    int64
    dtype: object
    

    You can also compare the series to another via series1 == series2 to create a Boolean series. Both series must carry the same index labels in the same order, otherwise pandas raises a ValueError. A trivial example checking the series against itself:

    # in reality, you will check res versus a custom series_valid
    print(res == res)
    
    col1    True
    col2    True
    col3    True
    col4    True
    dtype: bool
    

    If any dtype differs from the expected one, you can raise an error:

    mismatched = res != series_valid
    if mismatched.any():
        # keep only the labels where the dtypes differ
        indices = mismatched[mismatched].index.tolist()
        raise TypeError("Some columns have incorrect type: {0}".format(indices))
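
With several hundred columns, writing the expected types out by hand is impractical. One option is to take them from a known-good sample frame and compare column by column. Below is a minimal sketch of that idea, not a definitive implementation. The names find_dtype_mismatches, reference and incoming are hypothetical and do not come from the original answer, and the sketch assumes a trusted sample of the data is available.

```python
import pandas as pd


def find_dtype_mismatches(df, expected):
    """Return {column: (actual_dtype, expected_dtype)} for every mismatch.

    `expected` maps column name -> dtype (a dict or a Series such as
    `some_df.dtypes`). A column missing from `df` is reported with an
    actual dtype of None. Hypothetical helper, for illustration only.
    """
    actual = df.dtypes
    mismatches = {}
    for col, exp in expected.items():
        if col not in actual.index:
            mismatches[col] = (None, exp)
        elif actual[col] != exp:
            mismatches[col] = (actual[col], exp)
    return mismatches


# Assumption: `reference` is a small, trusted sample with the correct types.
reference = pd.DataFrame([[1, True, 'dsfasd', 51.314]],
                         columns=['col1', 'col2', 'col3', 'col4'])
expected = reference.dtypes

# `incoming` has col1 as text instead of integers.
incoming = pd.DataFrame([['1', False, '56345', 56.1234]],
                        columns=['col1', 'col2', 'col3', 'col4'])

mismatches = find_dtype_mismatches(incoming, expected)
for col, (act, exp) in mismatches.items():
    print("{0}: got {1}, expected {2}".format(col, act, exp))
```

Because the lookup is by column name, the column order of the two frames does not have to match. In a test, you could raise a TypeError (or use assert) whenever mismatches is non-empty.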