
Pyspark variance across columns using Pandas udf


I'm new to PySpark and struggling with simple things. I want to define a pandas UDF that takes multiple columns as input and calculates the row-wise variance across those columns.

But I get errors.

I tried specifying the input as a list, but that didn't work. How can this be solved?

The approach from the question "Calculate variance across columns in pyspark" didn't work for me either.

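import pandas as pd
from typing import List
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

# Expected result, computed with plain pandas
test = pd.DataFrame({'id': ['a', 'b', 'c', 'd', 'e'],
                     'feat1': [3, 4, 5, 6, 7],
                     'feat2': [6, 9, 2, 4, 5]})

test['var_pd'] = test[['feat1', 'feat2']].var(axis=1)

# `spark` is an existing SparkSession
test = spark.createDataFrame(test)

# Attempted pandas UDF (this is the version that raises errors)
@pandas_udf(returnType=DoubleType())
def variance_udf(*cols: List[pd.Series]) -> pd.Series:
    return pd.Series(cols).var(axis=1)

test = test.withColumn("variance_udf", variance_udf('feat1', 'feat2'))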


Solution

    df = spark.createDataFrame(
        [('a', 3, 6), ('b', 4, 9)],
        schema=['id', 'feat1', 'feat2']
    )
    
    df.printSchema()
    df.show(5, False)
    root
     |-- id: string (nullable = true)
     |-- feat1: long (nullable = true)
     |-- feat2: long (nullable = true)
    
    +---+-----+-----+
    |id |feat1|feat2|
    +---+-----+-----+
    |a  |3    |6    |
    |b  |4    |9    |
    +---+-----+-----+
    

    You can pass all the columns into the UDF as positional (non-keyword) arguments:

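    from statistics import variance

    from pyspark.sql import functions as func
    from pyspark.sql.types import FloatType

    # Plain (row-at-a-time) Python UDF: each call receives one row's values
    @func.udf(returnType=FloatType())
    def variance_udf(*cols):
        return float(variance([int(x) for x in cols]))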
    
    df.withColumn(
        'variance_udf',
        variance_udf(*[df[col] for col in df.columns[1:]])
    ).show(
        10, False
    )
    +---+-----+-----+------------+
    |id |feat1|feat2|variance_udf|
    +---+-----+-----+------------+
    |a  |3    |6    |4.5         |
    |b  |4    |9    |12.5        |
    +---+-----+-----+------------+
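
To see what the UDF computes for each row, the same function body can be run without Spark. This is a minimal sketch: `row_variance` and the `rows` list are hypothetical names introduced here for illustration, and the data simply mirrors the two-row example DataFrame above.

```python
from statistics import variance


def row_variance(*cols):
    """Sample variance (n - 1 denominator) of one row's values."""
    # *cols collects however many values are passed positionally,
    # mirroring variance_udf(*[df[col] for col in ...]) in the Spark call.
    return float(variance([int(x) for x in cols]))


# Same rows as the example DataFrame: (id, feat1, feat2)
rows = [("a", 3, 6), ("b", 4, 9)]
for row_id, *features in rows:
    print(row_id, row_variance(*features))
# a 4.5
# b 12.5
```

Because `statistics.variance` is the sample variance, the values match what pandas' `var(axis=1)` gives with its default `ddof=1`.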
    

    Or you can use RDD operations:

    from statistics import variance

    df.rdd.map(
        lambda row: (row[0], float(variance([int(x) for x in row[1:]])))
    ).toDF(['id', 'var']).show(5, False)
    +---+----+
    |id |var |
    +---+----+
    |a  |4.5 |
    |b  |12.5|
    +---+----+
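
The question asked specifically for a pandas UDF, which neither approach above uses. The following is a minimal sketch, assuming Spark 3.x with `pyarrow` installed; `row_variance` and `make_variance_udf` are hypothetical names introduced here. The attempt in the question probably fails for two reasons: the variadic parameter is annotated as `List[pd.Series]` rather than `pd.Series`, and `pd.Series(cols)` does not build a two-dimensional frame, so `var(axis=1)` has no column axis to work on. Concatenating the incoming Series side by side avoids the second problem.

```python
import pandas as pd


def row_variance(*cols: pd.Series) -> pd.Series:
    """Sample variance (ddof=1) across the given columns, one value per row."""
    # Each argument is one column of the batch; place them side by side
    # so that axis=1 runs across the columns of each row.
    return pd.concat(cols, axis=1).var(axis=1)


def make_variance_udf():
    """Wrap row_variance as a Series-to-Series pandas UDF (needs pyspark and pyarrow)."""
    # Imported here so the pandas-only part above can be tried without Spark.
    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import DoubleType

    return pandas_udf(row_variance, returnType=DoubleType())


# Pandas-only check with the same values as the example DataFrame
feat1 = pd.Series([3, 4])
feat2 = pd.Series([6, 9])
print(row_variance(feat1, feat2).tolist())
```

With a Spark session available, the wrapped function would be applied as `df.withColumn("variance_udf", make_variance_udf()("feat1", "feat2"))`. Unlike the plain UDF, it processes a whole batch of rows per call, which is usually the reason to prefer a pandas UDF.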