Tags: arrays, apache-spark, pyspark, apache-spark-sql, multiple-columns

Create new columns based on the frequency of array elements from one column


I have a dataframe like this:

     column_1     column_2
    ['a','c']            1
    ['b','c']            2
['a','b','c']            1

Now I want to add 3 columns (a, b and c), based on the frequency of occurrence.

Desired output:

a   b    c   column_2
1   0    1          1
0   1    1          2
1   1    1          1

Solution

  • Assuming you know the names of the columns you will create beforehand (so you can store them in a list), the following approaches do it without shuffling. (If you don't know the names up front, see the shuffle-based explode/pivot sketch at the end of this answer.)

    If you just need to know whether the array contains the value:

    • Spark 3.1+

      from pyspark.sql import functions as F
      df = spark.createDataFrame(
          [(['a','c'], 1),
           (['b','c'], 2),
           (['a','b','c'], 1)],
          ['column_1', 'column_2']
      )
      cols = ['a', 'b', 'c']
      # array literal holding the names of the new columns
      arr_cols = F.array([F.lit(x) for x in cols])
      # for each name: 1 if column_1 contains it, else 0
      arr_vals = F.transform(arr_cols, lambda c: F.array_contains('column_1', c).cast('int'))
      df = df.select(
          # element_at is 1-based, hence i+1
          *[F.element_at(arr_vals, i+1).alias(c) for i, c in enumerate(cols)],
          'column_2'
      )
      df.show()
      # +---+---+---+--------+
      # |  a|  b|  c|column_2|
      # +---+---+---+--------+
      # |  1|  0|  1|       1|
      # |  0|  1|  1|       2|
      # |  1|  1|  1|       1|
      # +---+---+---+--------+
      
    • Spark 2.4+

      from pyspark.sql import functions as F
      df = spark.createDataFrame(
          [(['a','c'], 1),
           (['b','c'], 2),
           (['a','b','c'], 1)],
          ['column_1', 'column_2']
      )
      cols = ['a', 'b', 'c']
      df = df.withColumn('arr_cols', F.array([F.lit(x) for x in cols]))
      # before Spark 3.1 the higher-order function transform is only available via SQL expressions
      arr_vals = F.expr("transform(arr_cols, c -> cast(array_contains(column_1, c) as int))")
      df = df.select(
          *[F.element_at(arr_vals, i+1).alias(c) for i, c in enumerate(cols)],
          'column_2'
      )
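      # df.show() prints the same table as the Spark 3.1+ example above;
      # all steps are per-row expressions, so no shuffle is needed.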
      

    If you need to know the count of occurrences:

    • Spark 3.1+

      from pyspark.sql import functions as F
      df = spark.createDataFrame(
          [(['a','c'], 1),
           (['b','c'], 2),
           (['a','a','b','c'], 1)],
          ['column_1', 'column_2']
      )
      cols = ['a', 'b', 'c']
      arr_cols = F.array([F.lit(x) for x in cols])
      # count occurrences: mark matching elements True, remove the Falses, take the size
      arr_vals = F.transform(arr_cols, lambda c: F.size(F.array_remove(F.transform('column_1', lambda v: v == c), False)))
      df = df.select(
          *[F.element_at(arr_vals, i+1).alias(c) for i, c in enumerate(cols)],
          'column_2'
      )
      df.show()
      # +---+---+---+--------+
      # |  a|  b|  c|column_2|
      # +---+---+---+--------+
      # |  1|  0|  1|       1|
      # |  0|  1|  1|       2|
      # |  2|  1|  1|       1|
      # +---+---+---+--------+
      
    • Spark 2.4+

      from pyspark.sql import functions as F
      df = spark.createDataFrame(
          [(['a','c'], 1),
           (['b','c'], 2),
           (['a','a','b','c'], 1)],
          ['column_1', 'column_2']
      )
      cols = ['a', 'b', 'c']
      df = df.withColumn('arr_cols', F.array([F.lit(x) for x in cols]))
      # same counting logic via a SQL expression (transform has no Python API before Spark 3.1)
      arr_vals = F.expr("transform(arr_cols, c -> size(array_remove(transform(column_1, v -> v = c), false)))")
      df = df.select(
          *[F.element_at(arr_vals, i+1).alias(c) for i, c in enumerate(cols)],
          'column_2'
      )
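      # Output matches the Spark 3.1+ count example above.

  • If you don't know the column names beforehand, a shuffle-based alternative is to explode the array and pivot the distinct values into columns. This is only a minimal sketch of that route (it is not one of the approaches above); the helper column name '_row_id' is an assumption, used just to keep duplicate rows apart:

      from pyspark.sql import functions as F
      df = spark.createDataFrame(
          [(['a','c'], 1),
           (['b','c'], 2),
           (['a','a','b','c'], 1)],
          ['column_1', 'column_2']
      )
      result = (
          df.withColumn('_row_id', F.monotonically_increasing_id())  # keep duplicate rows apart
            .withColumn('value', F.explode('column_1'))              # one row per array element
            .groupBy('_row_id', 'column_2')
            .pivot('value')                                          # distinct elements become columns (shuffles)
            .count()                                                 # per-row occurrence counts
            .na.fill(0)                                              # absent elements -> 0
            .drop('_row_id')
      )
      result.show()

    The result has column_2 plus one count column per distinct element, but unlike the approaches above it incurs a shuffle and an extra pass to discover the distinct values.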