
Pyspark - find the index of first positive number in an array column


I've got an array column in a pyspark dataframe, and I want to find the index of the first positive number in each array. The data looks like this:

id      arr
Cell 1  [-1, -1, -1, -1]
Cell 2  [-1, -1, 5, -1]
Cell 3  [-1, 3, -1, -1]

I want to get an output similar to this:

id      arr               first_positive_element_index
Cell 1  [-1, -1, -1, -1]  null
Cell 2  [-1, -1, 5, -1]   2
Cell 3  [-1, 3, -1, -1]   1
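For reference, the per-row logic that such a UDF would wrap can be written in plain Python. This is only a sketch of the desired semantics (0-based index, None/null when no element is positive); the function name first_positive_index is hypothetical. It is useful for checking Spark results on a small sample, not for running on the large data.

```python
from typing import List, Optional


def first_positive_index(arr: Optional[List[int]]) -> Optional[int]:
    """Return the 0-based index of the first element > 0, or None if there is none."""
    if arr is None:
        return None
    for i, value in enumerate(arr):
        if value > 0:
            return i
    return None


# The three sample rows from the question.
rows = {
    "Cell 1": [-1, -1, -1, -1],
    "Cell 2": [-1, -1, 5, -1],
    "Cell 3": [-1, 3, -1, -1],
}
for cell_id, arr in rows.items():
    print(cell_id, first_positive_index(arr))
```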

I can do this with a UDF, but the data is quite large, which makes that approach extremely slow. I would prefer a way to do this without a UDF.

Note: all non-positive numbers are -1


Solution

  • You can use expr with array_position:

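    from pyspark.sql import functions as func

    # Explode each array, keep only positive values, then look up each value's
    # position in the original array (array_position is 1-based, hence the -1).
    df_pos = df.select(
        'id', 'arr',
        func.explode('arr').alias('arr_explode_value')
    ).filter(
        func.col('arr_explode_value') > 0
    ).withColumn(
        'pos', func.expr('array_position(arr, arr_explode_value)') - 1
    ).groupBy(
        'id'
    ).agg(
        # The smallest position per id is the first positive element.
        func.min('pos').alias('pos')
    )
    df_pos.show(10, False)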
    +------+---+
    |id    |pos|
    +------+---+
    |Cell 2|2  |
    |Cell 3|1  |
    +------+---+
    

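    This builds a reference dataframe in three steps:

    1. Explode the array.
    2. Keep only the positive values.
    3. Take the smallest index per id.

    Then join the reference table back to the original dataframe. Rows with no positive element (Cell 1) get a null index from the left join.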

    df.select('id', 'arr').join(df_pos.select('id', 'pos'), on=['id'], how='left')
    

    Edit 1:

    If you want to avoid explode because the arrays are long, you can use transform and array_position instead:

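    from pyspark.sql import functions as func

    df.select(
        'id', 'arr',
        # Map each element to 1 if it is positive, else 0 (func.transform needs Spark 3.1+).
        func.transform(func.col('arr'), lambda value: func.when(value > 0, 1).otherwise(0)).alias('transformed_arr')
    ).withColumn(
        # array_position is 1-based and returns 0 when 1 is absent.
        'pos', func.array_position('transformed_arr', 1) - 1
    ).show(
        10, False
    )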
    +------+----------------+---------------+---+
    |id    |arr             |transformed_arr|pos|
    +------+----------------+---------------+---+
    |Cell 1|[-1, -1, -1, -1]|[0, 0, 0, 0]   |-1 |
    |Cell 2|[-1, -1, 5, -1] |[0, 0, 1, 0]   |2  |
    |Cell 3|[-1, 3, -1, -1] |[0, 1, 0, 0]   |1  |
    +------+----------------+---------------+---+
    

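    Since arr is an array column, transform applies the function to each element without exploding the rows. Note that array_position returns 0 when the value is not found, so pos is -1 (not null) for rows without a positive element, as Cell 1 shows.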