Tags: dataframe, apache-spark, pyspark, apache-spark-sql, grouping

Grouping alternative items with PySpark


A sample of the dataset I am working on:

# Creating the sample DataFrame
test = sqlContext.createDataFrame([(1, 2), (2, 1),
                                   (1, 3), (2, 3),
                                   (3, 2), (3, 1),
                                   (4, 5), (5, 4)],
                                  ['cod_item', 'alter_cod'])

test.show()
# +--------+---------+
# |cod_item|alter_cod|
# +--------+---------+
# |       1|        2|
# |       2|        1|
# |       1|        3|
# |       2|        3|
# |       3|        2|
# |       3|        1|
# |       4|        5|
# |       5|        4|
# +--------+---------+

And it looks like this after grouping the equivalent items into lists:

test.createOrReplaceTempView("teste")

teste = spark.sql("""select cod_item,
                  collect_list(alter_cod) as alternative_item 
                  from teste
                  group by cod_item""")

teste.show()
# +--------+----------------+
# |cod_item|alternative_item|
# +--------+----------------+
# |       1|          [2, 3]|
# |       2|          [1, 3]|
# |       3|          [2, 1]|
# |       4|             [5]|
# |       5|             [4]|
# +--------+----------------+

In the first column I have certain items, and in the second column I have the items that are equivalent to them. For each group of equivalent items, I would like to keep only one item that represents it.

I would like the final dataframe to look like this:

+--------+----------------+
|cod_item|alternative_item|
+--------+----------------+
|       1|          [2, 3]|
|       4|             [5]|
+--------+----------------+

where each remaining cod_item is the item representing its respective group of equivalent items.


Solution

  • After collect_list, keep only the rows where every alter_cod is greater than cod_item (i.e., filter out the rows where any alter_cod is smaller than cod_item). Since every equivalence appears in both directions in the data, exactly one row per group survives: the one whose cod_item is the smallest item of the group. This method works on strings too.

    from pyspark.sql import functions as F

    test = (test
        .groupBy('cod_item')
        # gather every equivalent item of each cod_item into a list
        .agg(F.collect_list('alter_cod').alias('alter_cod'))
        # keep only the rows where cod_item is smaller than all its alternatives
        .filter(F.forall('alter_cod', lambda x: x > F.col('cod_item')))
    )
    
    test.show()
    # +--------+---------+
    # |cod_item|alter_cod|
    # +--------+---------+
    # |       1|   [2, 3]|
    # |       4|      [5]|
    # +--------+---------+
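
    Note that F.forall was only added to the Python API in PySpark 3.1. On Spark 3.0 you can express the same predicate through F.expr with the SQL forall higher-order function; a minimal sketch of that, assuming Spark ≥ 3.0 and starting again from the original test DataFrame:

    test = (test
        .groupBy('cod_item')
        .agg(F.collect_list('alter_cod').alias('alter_cod'))
        # same filter, written as a SQL higher-order function
        .filter(F.expr('forall(alter_cod, x -> x > cod_item)'))
    )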
    

    Or add one line to your SQL:

    select cod_item,
           collect_list(alter_cod) as alternative_item 
    from teste
    group by cod_item
    having forall(alternative_item, x -> x > cod_item)
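
    Equivalently, since "every alter_cod is greater than cod_item" just means that cod_item is smaller than the minimum of its list, the filter can also be written with array_min (available since Spark 2.4). A minimal sketch of that variant, again starting from the original test DataFrame:

    test = (test
        .groupBy('cod_item')
        .agg(F.collect_list('alter_cod').alias('alter_cod'))
        # the surviving cod_item is the smallest member of its equivalence class
        .filter(F.col('cod_item') < F.array_min('alter_cod'))
    )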