scala, pandas, apache-spark, dataset, apache-spark-ml

Spark Dataset: filter columns with conditions like pandas .loc


I am new to Spark/Scala and I do not know how to use a Spark Dataset to filter columns the way pandas .loc does.

pandas code:

data_fact = pd.read_excel(path, sheetname=sheetname_factor)
# drop the columns that have too many null values
data_fact_v1 = data_fact.loc[:, (data_fact > 0).sum() > len(data_fact) * 0.7]

Your help is very much appreciated!


Solution

  • I would use an RDD for this because the API is more flexible. In the following code, I map each row to a list of Tuple2 pairs of the form (columnName, 1), one for every field whose value is null. Then I flatten everything and count the number of null values per column with reduceByKey. Finally, I drop from the original dataframe the columns that do not meet your requirement (at least 70% non-null values).

    var data = ...
    val cols = data.columns
    val total = data.count
    
    val nullMap = data.rdd
        .flatMap{ row => cols.filter(col => row.getAs[Any](col) == null).map(col => (col, 1)) }
        .reduceByKey(_ + _)
        .collectAsMap
    
    // drop every column whose fraction of nulls is 30% or more,
    // i.e. fewer than 70% non-null values, as in the pandas version
    for(col <- cols)
        if(nullMap.getOrElse(col, 0).toDouble / total >= 0.3)
            data = data.drop(col)
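
    For completeness, here is a minimal, hypothetical setup for data (the column names and values are made up for illustration), so the snippet above can be run locally:

    import org.apache.spark.sql.SparkSession
    
    val spark = SparkSession.builder()
        .appName("drop-sparse-columns")
        .master("local[*]")
        .getOrCreate()
    import spark.implicits._
    
    // three columns; "y" and "z" each contain one null out of three rows
    var data = Seq[(Option[Int], Option[Int], Option[String])](
        (Some(1), None,    Some("a")),
        (Some(2), Some(5), None),
        (Some(3), Some(6), Some("c"))
    ).toDF("x", "y", "z")

    With these rows, total is 3 and nullMap ends up as Map(y -> 1, z -> 1); y and z are each about 33% null (fewer than 70% non-null values), so both get dropped and only x survives.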
    

    Edit (another method): to avoid flattening the data, you can use the aggregate function:

    // merge two per-column null-count maps by summing the counts key by key
    def combine(map1: Map[String, Int], map2: Map[String, Int]): Map[String, Int] = 
        map1.keySet
            .union(map2.keySet)
            .map(k => (k, map1.getOrElse(k, 0) + map2.getOrElse(k, 0)))
            .toMap
    
    val nullMap = data.rdd.aggregate(Map[String, Int]())(
         // seqOp: add the nulls found in one row to the running per-column counts
         (map, row) => combine(map, cols.filter(col => row.getAs[Any](col) == null).map(col => (col, 1)).toMap),
         // combOp: merge the partial maps produced by different partitions
         combine)
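
    As a quick illustration of how combine behaves (the counts here are made up), it merges two per-column null-count maps by summing the counts for each key:

    val fromPartition1 = Map("x" -> 2, "y" -> 1)
    val fromPartition2 = Map("y" -> 3, "z" -> 4)
    combine(fromPartition1, fromPartition2)
    // -> Map(x -> 2, y -> 4, z -> 4)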
    

    And then drop the columns in the same way:

    for(col <- cols) 
        if(nullMap.getOrElse(col, 0).toDouble / total >= 0.3)
            data = data.drop(col)
    

    Or, collecting the columns to drop and removing them in a single call:

    val columnsToDrop = cols
        .filter(col => nullMap.getOrElse(col, 0).toDouble / total >= 0.3)
    data = data.drop(columnsToDrop: _*)
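
    As a side note (not part of the original answer), the same per-column null counts can also be computed with the DataFrame API in a single aggregation, which avoids the round trip through the RDD API. A minimal sketch, reusing data, cols and total from above:

    import org.apache.spark.sql.functions.{col, count, when}
    
    // count() ignores nulls, and when() without otherwise() yields null when the
    // condition is false, so this counts exactly the rows where the column is null
    val nullCountsRow = data
        .select(cols.map(c => count(when(col(c).isNull, c)).alias(c)): _*)
        .first()
    
    val sparseCols = cols.filter(c => nullCountsRow.getAs[Long](c).toDouble / total >= 0.3)
    data = data.drop(sparseCols: _*)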