I have a dataframe like this:
column_1       column_2
['a','c']      1
['b','c']      2
['a','b','c']  1
Now I want to add 3 columns (a, b and c), based on the frequency of occurrence.
Desired output:
a  b  c  column_2
1  0  1  1
0  1  1  2
1  1  1  1
Assuming you know the names of the columns you will create beforehand (so you can store them in a list), the following approaches do it without shuffling.
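If you don't know the values in advance, one way is to collect the distinct elements of column_1 first (a sketch; unlike the approaches below, this does trigger a shuffle and runs a job over the data):

from pyspark.sql import functions as F

# Collect the distinct values appearing anywhere in column_1
cols = sorted(
    r['v'] for r in df.select(F.explode('column_1').alias('v')).distinct().collect()
)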
If you just need to know whether the array contains the value:
Spark 3.1+
from pyspark.sql import functions as F

df = spark.createDataFrame(
    [(['a','c'], 1),
     (['b','c'], 2),
     (['a','b','c'], 1)],
    ['column_1', 'column_2']
)

cols = ['a', 'b', 'c']
# Array of the target values: array('a', 'b', 'c')
arr_cols = F.array([F.lit(x) for x in cols])
# For each target value: 1 if column_1 contains it, else 0
arr_vals = F.transform(arr_cols, lambda c: F.array_contains('column_1', c).cast('int'))
df = df.select(
    *[F.element_at(arr_vals, i+1).alias(c) for i, c in enumerate(cols)],
    'column_2'
)
df.show()
# +---+---+---+--------+
# |  a|  b|  c|column_2|
# +---+---+---+--------+
# |  1|  0|  1|       1|
# |  0|  1|  1|       2|
# |  1|  1|  1|       1|
# +---+---+---+--------+
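As a side note, for the membership-only case a direct per-column select gives the same result when applied before the select above (a minimal sketch; this form works on Spark 2.4 as well):

# Applied to the dataframe while it still has column_1
df_membership = df.select(
    *[F.array_contains('column_1', c).cast('int').alias(c) for c in cols],
    'column_2'
)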
Spark 2.4+
from pyspark.sql import functions as F

df = spark.createDataFrame(
    [(['a','c'], 1),
     (['b','c'], 2),
     (['a','b','c'], 1)],
    ['column_1', 'column_2']
)

cols = ['a', 'b', 'c']
# Python lambdas for transform arrive only in Spark 3.1, so on 2.4
# the same logic goes through a SQL expression, which needs arr_cols
# to exist as a real column.
df = df.withColumn('arr_cols', F.array([F.lit(x) for x in cols]))
arr_vals = F.expr("transform(arr_cols, c -> cast(array_contains(column_1, c) as int))")
df = df.select(
    *[F.element_at(arr_vals, i+1).alias(c) for i, c in enumerate(cols)],
    'column_2'
)
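The result matches the Spark 3.1+ output above; the helper arr_cols column is dropped automatically because it is not part of the final select:

df.show()
# +---+---+---+--------+
# |  a|  b|  c|column_2|
# +---+---+---+--------+
# |  1|  0|  1|       1|
# |  0|  1|  1|       2|
# |  1|  1|  1|       1|
# +---+---+---+--------+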
If you need to know the count of occurrences:
Spark 3.1+
from pyspark.sql import functions as F

df = spark.createDataFrame(
    [(['a','c'], 1),
     (['b','c'], 2),
     (['a','a','b','c'], 1)],
    ['column_1', 'column_2']
)

cols = ['a', 'b', 'c']
arr_cols = F.array([F.lit(x) for x in cols])
# For each target value c: compare every element of column_1 to c,
# drop the False results, and take the size, which is the count of c
arr_vals = F.transform(
    arr_cols,
    lambda c: F.size(F.array_remove(F.transform('column_1', lambda v: v == c), False))
)
df = df.select(
    *[F.element_at(arr_vals, i+1).alias(c) for i, c in enumerate(cols)],
    'column_2'
)
df.show()
# +---+---+---+--------+
# |  a|  b|  c|column_2|
# +---+---+---+--------+
# |  1|  0|  1|       1|
# |  0|  1|  1|       2|
# |  2|  1|  1|       1|
# +---+---+---+--------+
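On Spark 3.1+ the inner count can also be written with F.filter instead of transform + array_remove, which some may find easier to read (an equivalent sketch):

# Keep only the elements equal to c, then count them
arr_vals = F.transform(
    arr_cols,
    lambda c: F.size(F.filter('column_1', lambda v: v == c))
)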
Spark 2.4+
from pyspark.sql import functions as F

df = spark.createDataFrame(
    [(['a','c'], 1),
     (['b','c'], 2),
     (['a','a','b','c'], 1)],
    ['column_1', 'column_2']
)

cols = ['a', 'b', 'c']
df = df.withColumn('arr_cols', F.array([F.lit(x) for x in cols]))
# Same counting logic in SQL: keep only the true comparisons and take the size
arr_vals = F.expr("transform(arr_cols, c -> size(array_remove(transform(column_1, v -> v = c), false)))")
df = df.select(
    *[F.element_at(arr_vals, i+1).alias(c) for i, c in enumerate(cols)],
    'column_2'
)
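As before, the output should match the Spark 3.1+ count example above (the last row counts 'a' twice):

df.show()
# +---+---+---+--------+
# |  a|  b|  c|column_2|
# +---+---+---+--------+
# |  1|  0|  1|       1|
# |  0|  1|  1|       2|
# |  2|  1|  1|       1|
# +---+---+---+--------+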