apache-spark, pyspark

Filter by MapType value in a PySpark DataFrame


Spark dataframe:

Text_col             Maptype_col
what is SO           {3:1, 5:1, 1:1}
what is spark        {3:2, 5:1}

I want to filter out (remove) rows where at least one entry in Maptype_col has a value greater than 1.

I have written the following code, but the value in filtered_map is null for every row:

from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

@udf(returnType=BooleanType())
def filter_map(col_map):
    retval = 0
    for k in col_map:
        if col_map[k] > 1: retval = 1
    return retval

newudf = origudf.withColumn("filtered_map", filter_map(F.col("Maptype_col")))

The output is:

Text_col             Maptype_col       filtered_map
what is SO           {3:1, 5:1, 1:1}   null
what is spark        {3:2, 5:1}        null

Expected output:

Text_col             Maptype_col       filtered_map
what is SO           {3:1, 5:1, 1:1}   0
what is spark        {3:2, 5:1}        1

Solution

  • You can use the built-in map_filter function (Spark 3.0+), as below: map_filter keeps only the entries whose value is greater than 1, and size counts what remains, so filtered_map is 0 exactly when no value exceeds 1.

    from pyspark.sql.functions import expr

    df.withColumn(
        "filtered_map",
        expr("size(map_filter(Maptype_col, (k, v) -> v > 1))")
    ).show(truncate=False)
    
    +-------------+------------------------+------------+
    |Text_col     |Maptype_col             |filtered_map|
    +-------------+------------------------+------------+
    |what is SO   |{3 -> 1, 5 -> 1, 1 -> 1}|0           |
    |what is spark|{3 -> 2, 5 -> 1}        |1           |
    +-------------+------------------------+------------+
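
  • Since the stated goal is to remove those rows, you can filter on the same expression directly rather than adding a column. A minimal sketch, assuming df is the original DataFrame (map_filter requires Spark 3.0+):

    from pyspark.sql.functions import expr

    # Keep only the rows in which no map value exceeds 1
    df.filter(expr("size(map_filter(Maptype_col, (k, v) -> v > 1)) == 0")).show(truncate=False)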
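
  • As an aside, the original UDF returns null because it declares returnType=BooleanType() while returning the Python ints 0 and 1; when the returned value does not match the declared type, Spark silently replaces it with null. Below is a minimal sketch of a corrected UDF, again assuming df is the original DataFrame:

    from pyspark.sql import functions as F
    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    # Declare IntegerType to match the 0/1 values actually returned
    @udf(returnType=IntegerType())
    def filter_map(col_map):
        # A MapType column arrives in a Python UDF as a dict;
        # flag the row if any value exceeds 1
        return int(any(v > 1 for v in col_map.values()))

    df.withColumn("filtered_map", filter_map(F.col("Maptype_col"))).show(truncate=False)

    That said, the map_filter expression above is preferable where available, since it avoids Python UDF serialization overhead.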