Tags: dataframe, dictionary, pyspark, cosine-similarity

Create and update a MapType column in PySpark


I want to add a MapType column to an existing PySpark dataframe. The keys are strings and the values are the frequencies of those strings. For each row, the values accumulate based on the occurrences of the keys.

[Image: what I want it to be]

As you can see, the list of unique keys is fixed in length (6 in this case, from A to F), and the frequency of the keys in Strings is accumulated within each group. Each group starts out with 0 for every key. The Index column dictates which entry comes first, as I need this to be in chronological order. I will later extract the values in the Map_accumulated column and use them as vectors for a cosine-distance calculation. My understanding is that extracting and using the values is doable, similar to a Python dictionary.

So far, I have a dataframe with the first three columns, plus a fixed-length list of all keys (as strings). I converted the list into a dictionary with 0 as the starting value for every key.

[Image: existing dataframe]

and

    level_list = ['A', 'B', 'C', 'D', 'E', 'F']
    level_dict = {i:0 for i in level_list}

In the actual data, level_list is very long (300+ keys), which is why I felt I needed to create the list/dictionary with 0 as the starting value before integrating it into the PySpark dataframe.

I wish I could go into further detail about what I have tried, but I really have no idea what I'm doing. I've been trying to use ChatGPT to help with the code, but I couldn't figure it out.

    from pyspark.sql.functions import lit, col, create_map
    from itertools import chain

    my_list = list(chain(*level_dict.items()))
    my_map = create_map(my_list).alias("map")
    df = df.withColumn("map", my_map)

The code above was AI-generated, but I clearly didn't specify the prompt correctly, as I got this error:

TypeError: Invalid argument, not a string or column: 0 of type <class 'int'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' function.
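
If I understand the message correctly, the problem is that create_map expects Column objects (or column names), while my_list contains plain Python strings and integer 0s. Wrapping every element in lit seems to at least build the constant starting map, although that still only gives me a fixed map, not the accumulating counts I'm after:

    from pyspark.sql.functions import lit, create_map
    from itertools import chain

    # Wrap each key and value in lit() so create_map receives Columns rather
    # than raw Python objects (the raw 0s are what trigger the TypeError).
    my_map = create_map(*[lit(x) for x in chain(*level_dict.items())]).alias("map")
    df = df.withColumn("map", my_map)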

Any help is truly, truly appreciated. I have some experience doing data analysis in R, but trying to learn Python and Spark (along with, I guess, SQL) at the same time is very confusing.


Solution

  • One option is to explode the Strings list and pivot the dataframe so that each string becomes a column with its count as the value. Next, use a window function to compute a cumulative count for every element in level_list. Then aggregate the count values into a MapType column.

    # Import as alias F because Python's built-in sum is used later.
    from pyspark.sql import functions as F
    from pyspark.sql import Window

    level_list = ['A', 'B', 'C', 'D', 'E', 'F']
    w = Window.partitionBy('Group').orderBy('Index')

    df = (df.withColumn('Strings', F.explode('Strings'))  # one row per string occurrence
          .groupby('Group', 'Index')
          .pivot('Strings')                               # one column per distinct string
          .count()
          .fillna(0)
          # cumulative count of each key within the group, ordered by Index
          .select('*', *[F.sum(x).over(w).alias(f'{x}_cum') for x in level_list]))
    

    This will result in

    +--------+-----+---+---+---+---+---+---+-----+-----+-----+-----+-----+-----+
    |   Group|Index|  A|  B|  C|  D|  E|  F|A_cum|B_cum|C_cum|D_cum|E_cum|F_cum|
    +--------+-----+---+---+---+---+---+---+-----+-----+-----+-----+-----+-----+
    |Appricot|    1|  0|  1|  1|  0|  1|  0|    0|    1|    1|    0|    1|    0|
    |Appricot|    2|  2|  1|  1|  1|  0|  0|    2|    2|    2|    1|    1|    0|
    |   Peach|    1|  1|  1|  0|  0|  0|  0|    1|    1|    0|    0|    0|    0|
    |   Peach|    2|  1|  0|  0|  1|  0|  1|    2|    1|    0|    1|    0|    1|
    |   Peach|    3|  1|  0|  0|  2|  0|  1|    3|    1|    0|    3|    0|    2|
    +--------+-----+---+---+---+---+---+---+-----+-----+-----+-----+-----+-----+
    

    If any of the A-F columns is missing (because that key never occurs in Strings), you can add it with df.withColumn('X', F.lit(0)); with many keys, see the loop sketched below.
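
    With 300+ keys it may be easier to add the missing columns in a loop, applied after the pivot/fillna step but before the cumulative-sum select. A small sketch, assuming the pivoted dataframe is df and the full key list is level_list:

    # Add a zero-filled column for every key in level_list that never
    # appeared in Strings, so the cumulative sums and maps see all keys.
    for x in level_list:
        if x not in df.columns:
            df = df.withColumn(x, F.lit(0))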

    To aggregate into MapType,

    # Note: this sum is Python's built-in function, not pyspark.sql.functions.sum.
    df = df.select('Group', 'Index', 
                   F.create_map(*sum([[F.lit(x), F.col(x)] for x in level_list], [])).alias('Map'),
                   F.create_map(*sum([[F.lit(x), F.col(f'{x}_cum')] for x in level_list], [])).alias('Map_acc'))
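
    The question mentions using the accumulated values as vectors for a cosine-distance calculation. One way to get them out of Map_acc in a fixed order is to index the map with the keys of level_list (Vec_acc is just an illustrative column name):

    # Turn the accumulated map into a fixed-order array (one slot per key in
    # level_list), which is easier to treat as a vector than a map.
    df = df.withColumn('Vec_acc', F.array(*[F.col('Map_acc')[x] for x in level_list]))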
    

    Update

    I am not sure whether this performs any better, but I also tried it without exploding.

    # Required so that map_concat can overwrite an existing key (last write wins).
    spark.conf.set('spark.sql.mapKeyDedupPolicy', 'LAST_WIN')

    def count_strings(acc, x):
        # Increment the running count for x, starting at 1 if x is not yet in the map.
        new_val = F.coalesce(acc[x] + 1, F.lit(1))
        return F.map_concat(acc, F.create_map(F.lit(x), new_val))

    df = df.withColumn('Map', F.aggregate('Strings', F.create_map().cast("map<string,int>"), count_strings))
    

    Result

    +-------+-----+---------------+--------------------------------+
    |Group  |Index|Strings        |Map                             |
    +-------+-----+---------------+--------------------------------+
    |Peach  |1    |[A, B]         |{A -> 1, B -> 1}                |
    |Peach  |2    |[A, D, F]      |{A -> 1, D -> 1, F -> 1}        |
    |Peach  |3    |[D, F, D, A]   |{D -> 2, F -> 1, A -> 1}        |
    |Apricot|1    |[B, C, E]      |{B -> 1, C -> 1, E -> 1}        |
    |Apricot|2    |[B, C, A, A, D]|{B -> 1, C -> 1, A -> 2, D -> 1}|
    +-------+-----+---------------+--------------------------------+
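
    Note that this Map column is still a per-row count, not the running total per group. One possible way to accumulate it (an untested sketch, assuming Spark 3.1+ so that F.aggregate and F.map_zip_with are available) is to collect the maps seen so far with a window and fold them together:

    from pyspark.sql import Window

    w = Window.partitionBy('Group').orderBy('Index')

    # Add two maps key by key; keys missing from one side come through as
    # null, so coalesce them to 0 before summing.
    def merge_maps(acc, m):
        return F.map_zip_with(
            acc, m, lambda k, v1, v2: F.coalesce(v1, F.lit(0)) + F.coalesce(v2, F.lit(0)))

    # Start from an empty map, or from a zero-initialised map built from
    # level_list if every key should always appear in the result.
    df = df.withColumn(
        'Map_acc',
        F.aggregate(F.collect_list('Map').over(w),
                    F.create_map().cast('map<string,int>'),
                    merge_maps))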