python, arrays, pyspark, azure-databricks

PySpark: Subset an array based on another column's value


I use PySpark in Azure Databricks to transform data before sending it to a sink. In this sink, any array can have a length of at most 100. In my data I have an array that always has length 300 and a field (n_relevant) specifying how many of these values are relevant.

Depending on n_relevant, the array should be handled as follows:

  • below 100 -> keep the first n_relevant values
  • from 100 up to (but not including) 300 -> keep every ceil(n_relevant/100)-th value among the first n_relevant
  • 300 or above -> keep every 3rd value
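The three cases above can also be read as a single stride rule: keep position i (0-based) if i < min(n_relevant, 300) and i is a multiple of a step, where the step is ceil(n_relevant / 100) clamped to the range 1 to 3. Because the step is just one number per row, it could in principle be derived from the n_relevant column rather than hard-coded. A plain-Python sketch, assuming the fixed array length of 300 from above; the helper names stride and subsample_by_stride are hypothetical:

```python
from math import ceil

ARRAY_LENGTH = 300  # every input array has exactly 300 values
MAX_LENGTH = 100    # the sink accepts arrays of at most 100 values


def stride(n_relevant: int) -> int:
    # ceil(n/100) is 1 up to 100, 2 up to 200, 3 up to 300; clamp it to [1, 3]
    return max(1, min(3, ceil(n_relevant / 100)))


def subsample_by_stride(array: list, n_relevant: int) -> list:
    # Only the first n_relevant positions matter (never more than the whole array)
    cutoff = min(n_relevant, len(array))
    step = stride(n_relevant)
    return [x for i, x in enumerate(array) if i < cutoff and i % step == 0]


values = list(range(1, ARRAY_LENGTH + 1))  # [1, 2, ..., 300]
for n in (4, 200, 300, 800):
    result = subsample_by_stride(values, n)
    print(n, len(result), result[:3], result[-1])
# 4 4 [1, 2, 3] 4
# 200 100 [1, 3, 5] 199
# 300 100 [1, 4, 7] 298
# 800 100 [1, 4, 7] 298
```

Clamping the step at 3 is what keeps the result at 100 values or fewer for any n_relevant, since no more than 300 positions ever exist.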

E.g.:

array: [1,2,3,4,5,...300]
n_relevant: 4
desired outcome: [1,2,3,4]

array: [1,2,3,4,5,...300]
n_relevant: 200
desired outcome: [1,3,5,...199]

array: [1,2,3,4,5,...300]
n_relevant: 300
desired outcome: [1,4,7,...298]

array: [1,2,3,4,5,...300]
n_relevant: 800
desired outcome: [1,4,7,...298]

This little program reflects the desired behavior:

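from math import ceil

def subsample(array: list, n_relevant: int) -> list:
    # Fewer than 100 relevant values: keep all of them
    if n_relevant < 100:
        return array[:n_relevant]
    # 100 to 299 relevant values: keep every ceil(n_relevant/100)-th relevant value
    if 100 <= n_relevant < 300:
        mod = ceil(n_relevant / 100)
        return [x for i, x in enumerate(array) if i % mod == 0 and i < n_relevant]
    # 300 or more: the whole array is relevant, keep every 3rd value
    return [x for i, x in enumerate(array) if i % 3 == 0]

n_relevant = 200  # choose n

t1 = list(range(1, 301))  # [1, 2, ..., 300], as in the examples above

print(subsample(t1, n_relevant))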

What I have tried:

I used transform to set undesired values to 0 and then removed them with array_remove. This can subset with a fixed modulo, but it cannot adapt to n_relevant. Specifically, I could not pass a parameter to the lambda function, nor change the function dynamically.
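The attempted approach can be emulated in plain Python to see one more pitfall. This is not Spark code, and the exact Spark expressions from the attempt are not shown, so the helper mark_then_remove below is an assumed reconstruction: it marks unwanted positions with 0 (like transform) and then drops every 0 (like array_remove). Because the removal is by value, any genuine 0 in the data disappears too, for example the first element of list(range(300)), which is the test data used in the solution below. Filtering on the element index never inspects the values, so it does not have this problem.

```python
def mark_then_remove(array: list, step: int, marker: int = 0) -> list:
    # Like transform: replace values whose index is not a multiple of step with the marker
    marked = [x if i % step == 0 else marker for i, x in enumerate(array)]
    # Like array_remove: drop every element equal to the marker, wherever it came from
    return [x for x in marked if x != marker]


def filter_by_index(array: list, step: int) -> list:
    # Keep values whose index is a multiple of step; the values themselves are never compared
    return [x for i, x in enumerate(array) if i % step == 0]


data = list(range(300))  # 0, 1, ..., 299
print(mark_then_remove(data, 3)[:3])  # [3, 6, 9]: the genuine 0 was removed as well
print(filter_by_index(data, 3)[:3])   # [0, 3, 6]
```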


Solution

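  • You can filter by index: first slice the array to its first n_relevant values, then keep every 2nd or 3rd element depending on n_relevant: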

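    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F
    from pyspark.sql.types import StructField, StructType, IntegerType, ArrayType

    # In a Databricks notebook `spark` already exists; getOrCreate() returns it
    spark = SparkSession.builder.getOrCreate()

    df = spark.createDataFrame(
        [[list(range(300)), 4], [list(range(300)), 200], [list(range(300)), 300], [list(range(300)), 800]],
        schema=StructType(
            [
                StructField("array", ArrayType(IntegerType())),
                StructField("n_relevant", IntegerType()),
            ]
        ),
    )

    # Keep only the first n_relevant elements (slice is 1-based), then thin them out by index:
    # ceil(n_relevant / 100) is 1 up to 100, 2 up to 200 and 3 beyond that.
    # F.filter with an (element, index) lambda requires Spark 3.1 or later.
    relevant = F.slice("array", 1, F.col("n_relevant"))
    df = df.withColumn(
        "result",
        F.when(F.col("n_relevant") <= 100, relevant)
        .when(F.col("n_relevant") <= 200, F.filter(relevant, lambda _, index: index % 2 == 0))
        .otherwise(F.filter(relevant, lambda _, index: index % 3 == 0)),
    )
    display(df)  # Databricks-only; use df.show(truncate=False) elsewhere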
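Three branches are enough because ceil(n_relevant / 100) is 1 for n_relevant up to 100, 2 up to 200 and 3 above that, and once n_relevant reaches 300 the slice covers the whole array. A quick plain-Python check, which needs no Spark, compares the when/otherwise logic against the question's subsample function for n_relevant from 1 to 1000. The function when_chain is a hypothetical stand-in that mirrors the F.when branches, with array[:n] playing the role of F.slice("array", 1, n):

```python
from math import ceil


def subsample(array: list, n_relevant: int) -> list:
    # The question's reference implementation
    if n_relevant < 100:
        return array[:n_relevant]
    if n_relevant < 300:
        mod = ceil(n_relevant / 100)
        return [x for i, x in enumerate(array) if i % mod == 0 and i < n_relevant]
    return [x for i, x in enumerate(array) if i % 3 == 0]


def when_chain(array: list, n_relevant: int) -> list:
    # Mirrors F.when(...).when(...).otherwise(...) from the answer in plain Python
    head = array[:n_relevant]  # stands in for F.slice("array", 1, n_relevant)
    if n_relevant <= 100:
        return head
    if n_relevant <= 200:
        return [x for i, x in enumerate(head) if i % 2 == 0]
    return [x for i, x in enumerate(head) if i % 3 == 0]


values = list(range(300))
mismatches = [n for n in range(1, 1001) if subsample(values, n) != when_chain(values, n)]
print(mismatches)  # []
```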