Tags: python, set, pyspark, data-conversion, apache-spark-sql

How to convert a Spark Dataframe column from vector to a set?


I need to process a dataset to identify frequent itemsets, so the input column must be a vector. The original column is a string with the items separated by commas, so I did the following:

functions.split(out_1['skills'], ',')
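
For context, a minimal sketch of how that expression is typically attached to the DataFrame, assuming out_1 is the source DataFrame and skills_arr is a placeholder column name chosen here:

from pyspark.sql import functions

# split the comma-separated string into an array<string> column
out_1 = out_1.withColumn('skills_arr', functions.split(out_1['skills'], ','))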

The problem is that, for some rows, I have duplicate values in the skills, and this causes an error when trying to identify the frequent itemsets.
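
For reference, and assuming the frequent-itemset step uses pyspark.ml.fpm.FPGrowth (the question does not name the API, so this is only an illustration), duplicate items within a single transaction make the fit fail, because FPGrowth requires the items in each transaction to be unique:

from pyspark.sql import SparkSession
from pyspark.ml.fpm import FPGrowth

spark = SparkSession.builder.getOrCreate()

# hypothetical reproduction: the first transaction contains the duplicate item 'a'
transactions = spark.createDataFrame([(0, ['a', 'a', 'b']), (1, ['b', 'c'])], ['id', 'skills'])
fp = FPGrowth(itemsCol='skills', minSupport=0.5, minConfidence=0.5)
fp.fit(transactions)  # fails because items in a transaction must be unique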

I wanted to convert the vector to a set to remove the duplicate elements, something like this:

functions.to_set(functions.split(out_1['skills'], ','))

But I could not find a function to convert a column from a vector to a set, i.e., there is no to_set function.

How can I accomplish what I want, i.e., remove the duplicate elements from the vector?


Solution

  • You can turn Python's built-in set function into a UDF using functions.udf(set) and then apply it to the array column:

    df.show()
    +-------+
    | skills|
    +-------+
    |a,a,b,c|
    |  a,b,c|
    |c,d,e,e|
    +-------+
    
    import pyspark.sql.functions as F

    # wrap the built-in set in a UDF and apply it to the split skills array
    df.withColumn("unique_skills", F.udf(set)(F.split(df.skills, ","))).show()
    +-------+-------------+
    | skills|unique_skills|
    +-------+-------------+
    |a,a,b,c|    [a, b, c]|
    |  a,b,c|    [a, b, c]|
    |c,d,e,e|    [c, d, e]|
    +-------+-------------+
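
  • One caveat: functions.udf defaults to a string return type, so depending on the Spark version unique_skills may come back as a plain string rather than a true array, which a frequent-itemset API cannot consume. A minimal sketch that declares the return type explicitly, plus the built-in array_distinct (Spark 2.4+), which avoids the UDF entirely:

    from pyspark.sql.types import ArrayType, StringType

    # UDF variant: deduplicate and return a real array<string> (sorted for a deterministic order)
    dedupe = F.udf(lambda xs: sorted(set(xs)), ArrayType(StringType()))
    df.withColumn("unique_skills", dedupe(F.split(df.skills, ","))).show()

    # Spark 2.4+ only: no UDF needed
    df.withColumn("unique_skills", F.array_distinct(F.split(df.skills, ","))).show()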