Tags: python, regex, apache-spark, pyspark, split

Conditional split based on a list of column names


I have a dataframe with two columns: "id" (int) and "values" (array of struct). I need to split the "name" field, using a list of column names as delimiters: for each row, check whether one of the names from the list occurs in "name", and if it does, split on it.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType

spark = SparkSession.builder.getOrCreate()

value_schema = ArrayType(
    StructType([
        StructField("name", StringType(), True),
        StructField("location", StringType(), True)
    ])
)

schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("values", value_schema, True)
])

data = [
    (1, [
        {"name": "col1_US", "location": "usa"},
        {"name": "col2_name_plex", "location": "usa"},
        {"name": "col4_false_val", "location": "usa"},
        {"name": "col3_name_is_fantasy", "location": "usa"}
    ])
]

df = spark.createDataFrame(data, schema)

# Explode the array so each struct becomes its own row, then keep only the name.
df = df.withColumn("values", explode(col("values")))
df = df.select(col("id"), col("values.name").alias("name"))
df.display()  # Databricks notebooks; use df.show() elsewhere


from pyspark.sql.functions import lit, split

col_names = ["col1", "col2_name", "col3_name_is", "col4"]

for c in col_names:
    # if df["name"].contains(c):  # this is not working: contains() returns
    # a Column, which Python's `if` cannot evaluate per row
    split_data = split(df["name"], f"{c}_")
    df = df.withColumns({
        "new_name": lit(c),
        "new_value": split_data.getItem(1)
    })
df.display()
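
For reference, the commented-out check fails because contains() builds a Column expression, not a Python boolean; evaluating it with a plain `if` raises a ValueError along the lines of "Cannot convert column into bool". A per-row test has to stay inside the query plan, e.g. via when(); a minimal illustrative sketch follows, where matches_col1 is a hypothetical column name:

from pyspark.sql.functions import when

# when() keeps the comparison inside the query plan and evaluates it row by row.
df_check = df.withColumn(
    "matches_col1",
    when(col("name").contains("col1"), lit(True)).otherwise(lit(False))
)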

Data after cleanup:

id  name
1   col1_US
1   col2_name_plex
1   col4_false_val
1   col3_name_is_fantasy

Final data from above script:

# returning unexpected data

id    name                new_name  new_value
1   col1_US               col4      null
1   col2_name_plex        col4      null
1   col4_false_val        col4      false_val
1   col3_name_is_fantasy  col4      null

Expected Result:

id    name                new_name          new_value
1   col1_US               col1              US
1   col2_name_plex        col2_name         plex
1   col4_false_val        col4              false_val
1   col3_name_is_fantasy  col3_name_is      fantasy

Solution

  • The loop overwrites new_name and new_value on every pass: lit(c) is assigned to every row whether or not the name matches, so after the last iteration each row carries "col4", and the split is null wherever "col4_" never occurs. Instead, we can use a regular expression to create the field new_name from your list of desired col_names:

    from pyspark.sql.functions import regexp_extract

    col_names = ["col1", "col2_name", "col3_name_is", "col4"]
    pattern = "|".join(col_names)
    
    df = df.withColumn("new_name", regexp_extract("name", pattern, 0))
    
    +---+--------------------+------------+
    | id|                name|    new_name|
    +---+--------------------+------------+
    |  1|             col1_US|        col1|
    |  1|      col2_name_plex|   col2_name|
    |  1|      col4_false_val|        col4|
    |  1|col3_name_is_fantasy|col3_name_is|
    +---+--------------------+------------+
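
    One caveat, hypothetical for this data: regex alternation takes the first alternative that matches, so if one candidate name were a prefix of another (say "col3_name" alongside "col3_name_is"), the shorter name could win. Escaping the names and trying the longest ones first guards against that:

    import re

    # Longest names first, metacharacters escaped, so "col3_name_is" is
    # tried before a hypothetical shorter "col3_name".
    pattern = "|".join(re.escape(c) for c in sorted(col_names, key=len, reverse=True))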
    

    Then we can split using another regular expression, passing new_name as the pattern to match, and trim the leading "_" from what follows. (Stripping every non-alphanumeric character would also destroy legitimate underscores inside values such as "false_val".)

    from pyspark.sql.functions import expr

    df = df.withColumn(
        "new_value", 
        expr("regexp_replace(split(name, new_name)[1], '^_', '')")
    )
    
    +---+--------------------+------------+---------+
    | id|                name|    new_name|new_value|
    +---+--------------------+------------+---------+
    |  1|             col1_US|        col1|       US|
    |  1|      col2_name_plex|   col2_name|     plex|
    |  1|      col4_false_val|        col4|false_val|
    |  1|col3_name_is_fantasy|col3_name_is|  fantasy|
    +---+--------------------+------------+---------+
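
  • As an aside, the same result can be reached without regular expressions by folding col_names into one chained conditional. A sketch, assuming each listed name is always followed by "_" in the data:

    from functools import reduce
    from pyspark.sql.functions import when, col, lit, split

    # Fold col_names into a single when(...).otherwise(...) chain, so each row
    # keeps its first matching prefix instead of being overwritten per loop pass.
    null_str = lit(None).cast("string")
    new_name = reduce(
        lambda acc, c: when(col("name").startswith(c + "_"), lit(c)).otherwise(acc),
        reversed(col_names),
        null_str,
    )
    new_value = reduce(
        lambda acc, c: when(
            col("name").startswith(c + "_"),
            split(col("name"), c + "_").getItem(1),
        ).otherwise(acc),
        reversed(col_names),
        null_str,
    )
    df = df.withColumns({"new_name": new_name, "new_value": new_value})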