I have an arbitrary number of arrays of equal length in a PySpark DataFrame. I need to coalesce these, element by element, into a single list. The problem with coalesce is that it doesn't work element by element; instead it selects the entire first non-null array. Any suggestions for how to accomplish this would be appreciated. Please see the test case below for an example of expected input and output:
def test_coalesce_elements():
    """
    Test array coalescing on a per-element basis
    """
    from pyspark.sql import SparkSession
    import pyspark.sql.types as t
    import pyspark.sql.functions as f

    spark = SparkSession.builder.getOrCreate()

    data = [
        {
            "a": [None, 1, None, None],
            "b": [2, 3, None, None],
            "c": [5, 6, 7, None],
        }
    ]
    schema = t.StructType([
        t.StructField('a', t.ArrayType(t.IntegerType())),
        t.StructField('b', t.ArrayType(t.IntegerType())),
        t.StructField('c', t.ArrayType(t.IntegerType())),
    ])
    df = spark.createDataFrame(data, schema)

    # Inspect schema
    df.printSchema()
    # root
    #  |-- a: array (nullable = true)
    #  |    |-- element: integer (containsNull = true)
    #  |-- b: array (nullable = true)
    #  |    |-- element: integer (containsNull = true)
    #  |-- c: array (nullable = true)
    #  |    |-- element: integer (containsNull = true)

    # Inspect df values
    df.show(truncate=False)
    # +---------------------+------------------+---------------+
    # |a                    |b                 |c              |
    # +---------------------+------------------+---------------+
    # |[null, 1, null, null]|[2, 3, null, null]|[5, 6, 7, null]|
    # +---------------------+------------------+---------------+

    # This obviously does not work, but hopefully provides the general idea.
    # Remember: this will need to work with an arbitrary and dynamic set of columns.
    input_cols = ['a', 'b', 'c']
    df = df.withColumn('d', f.coalesce(*[f.col(i) for i in input_cols]))
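    # Note: plain coalesce picks the first column whose *whole array* is not null,
    # so 'd' here would simply equal column 'a' ([null, 1, null, null]) rather than
    # being combined element by element.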
    # This is the expected output I would like to see for the given inputs
    assert df.collect()[0]['d'] == [2, 1, 7, None]
Thanks in advance for any ideas!
Thanks to Derek and Tushar for their responses! I was able to use them as a basis to solve the issue without a UDF, join, or explode.
Generally speaking, joins are unfavorable because they are expensive in both computation and memory, UDFs can be computationally intensive, and explode can be memory intensive. Fortunately, we can avoid all of these by using transform.
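To illustrate the core idea on the three fixed columns from the example, the same thing can be written as a single SQL higher-order function (just a sketch, with f as pyspark.sql.functions as above; bracket indexing in Spark SQL is 0-based, unlike element_at):

df = df.withColumn("d", f.expr("transform(a, (x, i) -> coalesce(a[i], b[i], c[i]))"))

Generalizing this to an arbitrary, dynamic list of columns gives the helper below: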
from typing import List, Union

import pyspark.sql.functions as f
from pyspark.sql import Column, DataFrame


def add_coalesced_list_by_elements_col(
    df: DataFrame,
    cols: List[Union[Column, str]],
    col_name: str,
) -> DataFrame:
    """
    Adds a new column representing a list that is coalesced element by element from
    the input columns. Please note that this does not check that all provided
    columns are of equal length.

    Args:
        df: Input DataFrame to add the column to
        cols: List of columns to coalesce by element. All columns should be of equal length.
        col_name: The name of the new column

    Returns:
        DataFrame with the result added as a new column.
    """
    return df.withColumn(
        col_name,
        f.transform(
            # It doesn't matter which column we iterate over; we only use its index
            cols[0],
            # i + 1 because element_at is 1-based, while the transform index is 0-based
            lambda _, i: f.coalesce(*[f.element_at(c, i + 1) for c in cols]),
        ),
    )
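If the set of columns is not known ahead of time, it can be derived from the DataFrame's schema rather than hard-coded. A minimal sketch, assuming every ArrayType column should be included (array_cols is just an illustrative name):

import pyspark.sql.types as t

array_cols = [
    field.name
    for field in df.schema.fields
    if isinstance(field.dataType, t.ArrayType)
]
df = add_coalesced_list_by_elements_col(df=df, cols=array_cols, col_name="d")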
def test_collect_list_elements():
    import pyspark.sql.types as t
    from pyspark.sql import SparkSession

    # Arrange
    spark = SparkSession.builder.getOrCreate()
    data = [
        {
            "id": 1,
            "a": [None, 1, None, None],
            "b": [2, 3, None, None],
            "c": [5, 6, 7, None],
        }
    ]
    schema = t.StructType(
        [
            t.StructField("id", t.IntegerType()),
            t.StructField("a", t.ArrayType(t.IntegerType())),
            t.StructField("b", t.ArrayType(t.IntegerType())),
            t.StructField("c", t.ArrayType(t.IntegerType())),
        ]
    )
    df = spark.createDataFrame(data, schema)

    # Act
    df = add_coalesced_list_by_elements_col(df=df, cols=["a", "b", "c"], col_name="d")
    row = df.collect()[0]

    # Assert new col is correct output
    assert row["d"] == [2, 1, 7, None]

    # Assert all the other cols are not affected
    assert row["a"] == [None, 1, None, None]
    assert row["b"] == [2, 3, None, None]
    assert row["c"] == [5, 6, 7, None]