swift, macos-mojave, createml

Swift MLDataTable - how to remove rows?


I'm creating an MLDataTable from a .csv file and would like to remove some rows, e.g. all rows where a specific column has a specific value. Is this possible?


Solution

  • I know I'm somewhat late with my answer, but hopefully someone else will find it useful.

    You can't remove rows from a given table in place, but you can create a new table with some rows filtered out.

    Here's an example table:

    import CreateML

    let employeesDict: [String: MLDataValueConvertible] = [
        "First Name": ["Alice", "Bob", "Charlie", "Dave", "Eva"],
        "Years of experience": [10, 1, 8, 5, 3],
        "Gender": ["female", "male", "male", "male", "female"],
    ]
    
    let employeesTable = try! MLDataTable(dictionary: employeesDict)
    

    Filtering is done by passing an MLDataColumn<Bool> to the table's subscript operator; Apple calls this a 'row mask'. Here's a row mask, built by hand, that filters out the female employees:

    let maleEmployeesMaskByHand = MLDataColumn([false, true, true, true, false])
    

    Passing it as an argument to employeesTable's subscript operator yields the following table:

    let maleEmployeesTable = employeesTable[maleEmployeesMaskByHand]
    print(maleEmployeesTable)
    +----------------+----------------+---------------------+
    | Gender         | First Name     | Years of experience |
    +----------------+----------------+---------------------+
    | male           | Bob            | 1                   |
    | male           | Charlie        | 8                   |
    | male           | Dave           | 5                   |
    +----------------+----------------+---------------------+
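The subscript leaves employeesTable untouched and returns a new table. As a rough, Create ML-free illustration of that semantics, here's a plain-Swift sketch that applies a row mask to a single array: row i is kept only if mask[i] is true. `applyRowMask(_:mask:)` is a hypothetical helper written for this illustration; it is not part of Create ML.

```swift
// Illustration only: this does NOT use Create ML.
// It mimics what a row mask does conceptually: keep row i if mask[i] is true.

/// Hypothetical helper (not a Create ML API): returns the elements of `rows`
/// whose corresponding entry in `mask` is `true`.
func applyRowMask<T>(_ rows: [T], mask: [Bool]) -> [T] {
    precondition(rows.count == mask.count, "mask must have one entry per row")
    return zip(rows, mask).filter { $0.1 }.map { $0.0 }
}

let names = ["Alice", "Bob", "Charlie", "Dave", "Eva"]
let maleMask = [false, true, true, true, false]

// `names` is not modified; a new, filtered array is returned instead.
let maleNames = applyRowMask(names, mask: maleMask)
print(maleNames) // ["Bob", "Charlie", "Dave"]
```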
    

    Here's another way to build the same row mask:

    let genderColumn: MLDataColumn<String> = employeesTable["Gender"]
    let maleEmployeesMask = genderColumn != "female"
    print(employeesTable[maleEmployeesMask])
    

    First the desired column is retrieved; then, thanks to operator overloading, the row mask is built by applying the != operator to the whole column.
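To show what that operator overloading amounts to, here's a plain-Swift sketch. `Column` is a hypothetical stand-in for MLDataColumn, and this is not how Create ML actually implements its operators; it only demonstrates the idea that comparing a whole column against a single value yields one Bool per row, i.e. a row mask.

```swift
// Illustration only: `Column` is a hypothetical stand-in, not Create ML's
// MLDataColumn, and this is not Create ML's implementation.
struct Column<Element: Equatable> {
    var values: [Element]
}

// Overloading != so that "column != value" compares element-wise and
// yields a Bool column: a row mask with one entry per row.
func != <Element: Equatable>(lhs: Column<Element>, rhs: Element) -> Column<Bool> {
    Column<Bool>(values: lhs.values.map { $0 != rhs })
}

let gender = Column(values: ["female", "male", "male", "male", "female"])
let notFemaleMask = gender != "female"
print(notFemaleMask.values) // [false, true, true, true, false]
```

The resulting mask matches the hand-built one from earlier, which is why the one-liner below produces the same table.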

    Here's a way to achieve the same in one line:

    print(employeesTable[ employeesTable["Gender"] != "female" ])
    

    A link to relevant documentation: https://developer.apple.com/documentation/createml/mldatatable/3006094-subscript