I'm working with a PySpark DataFrame that contains multiple levels of nested arrays of structs. My goal is to add, to each nested array, a hash column for that array plus the record's top-level hash column. Since I don't know in advance how deeply the arrays are nested, I need to apply this recursively to all nested arrays.
So, for this example schema:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, ArrayType

spark = SparkSession.builder.getOrCreate()

schema = StructType([
    StructField("name", StringType()),
    StructField("experience", ArrayType(StructType([
        StructField("role", StringType()),
        StructField("duration", StringType()),
        StructField("company", ArrayType(StructType([
            StructField("company_name", StringType()),
            StructField("location", StringType())
        ])))
    ])))
])
The desired output schema would look like this:
hashed_schema = StructType([
    StructField("name", StringType()),
    StructField("experience", ArrayType(StructType([
        StructField("role", StringType()),
        StructField("duration", StringType()),
        StructField("experience_hash", StringType()),  # added hash for the experience collection
        StructField("company", ArrayType(StructType([
            StructField("company_name", StringType()),
            StructField("location", StringType()),
            StructField("company_hash", StringType())  # added hash for the company collection
        ])))
    ]))),
    StructField("employee_hash", StringType()),  # added hash for the entire record
])
I have tried to write recursive code that iterates through each nested array and hashes its columns. While it seems to work for first-level nested arrays, the recursive part does not work: I get an error that the recursion is too deep.
from pyspark.sql.functions import col, concat_ws, lit, md5, transform
from pyspark.sql.types import ArrayType

def hash_for_level(level_path):
    return md5(concat_ws("_", *[lit(elem) for elem in level_path]))

def add_hash_columns(df, level_path, current_struct, root_hash_col=None):
    # If this is the root level, create the root hash
    if not level_path and root_hash_col is None:
        root_hash_col = 'employee_hash'
        df = df.withColumn(root_hash_col, hash_for_level(['employee']))
    # Traverse the current structure and add hash columns
    for field in current_struct.fields:
        new_level_path = level_path + [field.name]
        # If the field is an array of structs, add a hash for each element in the array
        if isinstance(field.dataType, ArrayType):
            nested_struct = field.dataType.elementType
            hash_expr = transform(
                col('.'.join(level_path + [field.name])),
                lambda x: x.withField(new_level_path[-1] + '_hash', hash_for_level(new_level_path))
                           .withField(root_hash_col, col(root_hash_col))  # include the root hash
            )
            # Add the hash column to the array elements
            df = df.withColumn('.'.join(level_path + [field.name]), hash_expr)
            # Recursive call to apply the same logic for nested arrays
            df = add_hash_columns(df, new_level_path, nested_struct, root_hash_col)
    # Add a hash column at the current level
    if level_path:
        hash_col_name = '_'.join(level_path) + '_hash'
        df = df.withColumn(hash_col_name, hash_for_level(level_path))
    if root_hash_col:
        # Ensure the root hash is included at each struct level
        df = df.withColumn(root_hash_col, col(root_hash_col))
    return df

df = spark.createDataFrame([], schema)
df = add_hash_columns(df, [], df.schema)
df
The approach in the question of using transform on the arrays worked for me. But instead of adding the hash columns with withColumn, I generate a (large) list of column expressions, including the hash columns, and use this list in a single select call.
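Since the question creates an empty DataFrame, the demonstration below assumes a couple of rows with the original schema. The answer does not show its actual input, so these values are made up and the hash values in the output are only indicative:
# assumed sample rows (not from the original answer); roles, durations and the
# second company entry are invented for illustration
sample_rows = [
    ("n1", [("r11", "d11", [("cn111", "loc111"), ("cn112", "loc112")])]),
    ("n2", [("r21", "d21", [("cn211", "loc211")])]),
]
df = spark.createDataFrame(sample_rows, schema)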
from pyspark.sql import functions as F
from pyspark.sql import types as T

# handling the first level
def add_hashes(df):
    hashcols = []
    for field in df.schema:
        hashcols.append(df[field.name])
        if isinstance(field.dataType, T.ArrayType):
            yield transform_array(field.name, df[field.name], df.schema[field.name].dataType.elementType)
        else:
            yield df[field.name]
    yield F.hash(*hashcols).alias('hash')

# special handling for adding the hash column to arrays
def transform_array(colname: str, col: F.Column, elementOfStruct: T.StructType):
    fields = elementOfStruct.fields
    hashcols = []

    def process_struct_elements(x: F.Column):
        for field in fields:
            name = field.name
            dataType = field.dataType
            hashcols.append(x[name])
            if isinstance(dataType, T.ArrayType):
                yield transform_array(name, x[name], dataType.elementType)
            else:
                yield x[name].alias(name)
        yield F.hash(*hashcols).alias('hash')

    return F.transform(col, lambda x: F.struct(*process_struct_elements(x))).alias(colname)

# starting the process
cols = list(add_hashes(df))
df.select(cols).printSchema()
df.select(cols).show(truncate=True)
Output:
root
|-- experience: array (nullable = true)
| |-- element: struct (containsNull = false)
| | |-- company: array (nullable = true)
| | | |-- element: struct (containsNull = false)
| | | | |-- company_name: string (nullable = true)
| | | | |-- location: string (nullable = true)
| | | | |-- hash: integer (nullable = false)
| | |-- duration: string (nullable = true)
| | |-- role: string (nullable = true)
| | |-- hash: integer (nullable = false)
|-- name: string (nullable = true)
|-- hash: integer (nullable = false)
+--------------------+----+-----------+
| experience|name| hash|
+--------------------+----+-----------+
|[{[{cn111, loc111...| n1|-1931528302|
|[{[{cn211, loc211...| n2| 312789015|
+--------------------+----+-----------+
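A note on hash types: F.hash returns an integer column (as the printed schema shows), while the desired schema in the question uses md5-based string hashes. If string hashes are needed, one possible substitution (an untested sketch, not part of the answer above) is to serialize the collected columns to JSON and apply F.md5 to the result:
# possible replacement for F.hash(*hashcols).alias('hash'):
F.md5(F.to_json(F.struct(*hashcols))).alias('hash')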
Remarks: