python-3.x, dictionary, pyspark, row, rdd

How to iterate over an RDD and remove fields that exist in a list using PySpark


I have a list that contains a few strings, which are field names. I also have a Spark RDD, and I'd like to iterate over it and remove any field whose name exists in the list. For example:

field_list = ["name_1", "name_2"]

The RDD looks like this:

[Row(field_1=1, field_2=Row(field_3=[Row(field_4=[Row(name_1='apple', name_2='banana', name_3='F'), Row(name_1='tomato', name_2='eggplant', name_3='F')])]))]

I'm not very familiar with RDDs. I understand that I can use map() to iterate, but how can I add the condition: if it finds "name_1" or "name_2" (names that exist in field_list), remove both the field and its value? The expected result is a new RDD that looks like:

[Row(field_1=1, field_2=Row(field_3=[Row(field_4=[Row(name_3='F'), Row(name_3='F')])]))]

Solution

  • You could recreate the whole structure without the fields you don't need. There may be a better method, but the Row documentation shows that Row offers very few methods; it is an immutable tuple subclass, so fields cannot be deleted from it in place.

    Inputs:

    from pyspark.sql import Row, SparkSession

    spark = SparkSession.builder.getOrCreate()
    rdd = spark.sparkContext.parallelize([
        Row(field_1=1, field_2=Row(field_3=[Row(field_4=[Row(name_1='apple', name_2='banana', name_3='F'), Row(name_1='tomato', name_2='eggplant', name_3='F')])]))
    ])
    print(rdd.collect())
    # [Row(field_1=1, field_2=Row(field_3=[Row(field_4=[Row(name_1='apple', name_2='banana', name_3='F'), Row(name_1='tomato', name_2='eggplant', name_3='F')])]))]

    field_list = ["name_1", "name_2"]

    Script:

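    # Row "classes" used to rebuild each level of the nested structure.
    F4 = Row('field_4')
    F3 = Row('field_3')
    F2 = Row('field_1', 'field_2')

    def transform(row):
        f3 = []
        for x in row['field_2']['field_3']:
            f4 = []
            for y in x['field_4']:
                # Keep only the fields not listed in field_list, in their original order
                # (a set difference would give a non-deterministic field order).
                keep = [n for n in y.asDict() if n not in field_list]
                Names = Row(*keep)
                f4.append(Names(*[y[n] for n in keep]))
            f3.append(F4(f4))
        return F2(row['field_1'], F3(f3))

    rdd = rdd.map(transform)

    print(rdd.collect())
    # [Row(field_1=1, field_2=Row(field_3=[Row(field_4=[Row(name_3='F'), Row(name_3='F')])]))]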
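The script above hardcodes the nesting (field_2, then field_3, then field_4). A schema-agnostic alternative is to convert each row to nested dicts with Row.asDict(recursive=True), drop the unwanted keys at every depth, and rebuild Rows afterwards. This is only a sketch under the assumption that the rows contain structs and arrays but no MapType columns; drop_fields and rebuild are hypothetical helper names. The example runs on plain dicts so it can be tried without Spark, and the Spark wiring appears only as a comment.

```python
from typing import Any, Callable, Iterable


def drop_fields(value: Any, names: Iterable[str]) -> Any:
    """Return a copy of `value` with every dict key in `names` removed, at any depth.

    Dicts and lists are walked recursively; any other value is returned as-is.
    Surviving keys keep their original order.
    """
    drop = frozenset(names)

    def walk(v: Any) -> Any:
        if isinstance(v, dict):
            return {k: walk(x) for k, x in v.items() if k not in drop}
        if isinstance(v, list):
            return [walk(x) for x in v]
        return v

    return walk(value)


def rebuild(value: Any, make_record: Callable[[dict], Any] = dict) -> Any:
    """Convert nested dicts back into records via `make_record`, innermost first."""
    if isinstance(value, dict):
        return make_record({k: rebuild(x, make_record) for k, x in value.items()})
    if isinstance(value, list):
        return [rebuild(x, make_record) for x in value]
    return value


# Plain-Python stand-in for row.asDict(recursive=True) on the question's data.
record = {
    "field_1": 1,
    "field_2": {"field_3": [{"field_4": [
        {"name_1": "apple", "name_2": "banana", "name_3": "F"},
        {"name_1": "tomato", "name_2": "eggplant", "name_3": "F"},
    ]}]},
}
field_list = ["name_1", "name_2"]

cleaned = drop_fields(record, field_list)
print(cleaned)
# {'field_1': 1, 'field_2': {'field_3': [{'field_4': [{'name_3': 'F'}, {'name_3': 'F'}]}]}}

# Hypothetical PySpark wiring (not run here; needs an active SparkContext):
# from pyspark.sql import Row
# new_rdd = rdd.map(lambda r: rebuild(
#     drop_fields(r.asDict(recursive=True), field_list), lambda d: Row(**d)))
```

Because this version works on dicts, it cannot tell a struct from a map: keys inside MapType values would also be dropped, and rebuild would turn those maps into Rows. If the data has map columns, the recursion would need to consult the schema instead.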