Tags: python, pyspark, databricks, user-defined-functions, pandas-udf

PySpark Pandas-Vectorized UDFs


I am trying to convert a regular UDF that takes two arguments into a single pandas UDF, so that I do not have to create two separate pandas UDFs.

Convert this:

from pyspark.sql.functions import udf

@udf("string")
def splitEmailUDF(email: str, position: int) -> str:
  return email.split("@")[position]

into a single pandas UDF like the one below. What should the type of position be: a data type, or something else?

import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("string")
def splitEmailUDFVec(email: pd.Series, position: ???????) -> pd.Series:
  return email.str.split("@").str[position]

Of course I can always create two pandas_udfs:

from pyspark.sql.functions import pandas_udf
        
@pandas_udf("string")
def splitFirstNameUDFVec(email: pd.Series) -> pd.Series:
  return email.str.split("@").str[0]
        
@pandas_udf("string")
def splitDomainUDFVec(email: pd.Series) -> pd.Series:
  return email.str.split("@").str[1]

Any help will be appreciated!


Solution

  • Setup

    df.show()
    
    +------------+
    |       email|
    +------------+
    | foo@bar.com|
    |baz@spam.com|
    +------------+
    

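    Define a wrapper function that takes the column name and the position as arguments. The position is a plain Python value that the inner pandas UDF captures through a closure, and the wrapper returns the column expression produced by applying that UDF to the email column.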

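    import pandas as pd
    from pyspark.sql import functions as F

    def split(email, pos):
        # pos is a plain Python int captured by the closure, not a Spark column
        @F.pandas_udf('string')
        def _split(email: pd.Series) -> pd.Series:
            return email.str.split('@').str[pos]

        # Apply the UDF to the named column and return the resulting Column
        return _split(email)

    # Position 1 is the part after '@', so the new column is named 'domain'
    df = df.withColumn('domain', split('email', 1))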
    

    Result

    df.show()
    
    +------------+--------+
    |       email|  domain|
    +------------+--------+
    | foo@bar.com| bar.com|
    |baz@spam.com|spam.com|
    +------------+--------+
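
The closure idea behind the wrapper can be tried outside Spark on a plain pandas Series. The following is only a sketch: `make_splitter` is a hypothetical helper name, not part of the answer's code, and the third sample value is made up to show what happens when an address has no '@'.

```python
import pandas as pd

def make_splitter(pos: int):
    """Return a Series -> Series function with `pos` captured by closure."""
    def _split(email: pd.Series) -> pd.Series:
        # Same expression as the body of the pandas UDF above
        return email.str.split("@").str[pos]
    return _split

# Sample data; "no-at-sign" is an assumed edge case, not from the question
emails = pd.Series(["foo@bar.com", "baz@spam.com", "no-at-sign"])

first = make_splitter(0)(emails)
domain = make_splitter(1)(emails)

print(first.tolist())
print(domain.tolist())
```

For a value without '@', `.str[1]` yields a missing value instead of raising an error, which is worth keeping in mind if the column may contain malformed addresses.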
    

    Alternatively, if your goal is only to split the email address into its name and domain components, a more efficient approach is to use the built-in regexp_extract function, which avoids the UDF overhead altogether.

    # Group 1 is the text before '@', group 2 is the text after it
    f = lambda n: F.regexp_extract('email', '(.*)@(.*)', n)
    df = df.select('*', f(1).alias('firstname'), f(2).alias('domain'))
    

    Result

    df.show()
    
    +------------+---------+--------+
    |       email|firstname|  domain|
    +------------+---------+--------+
    | foo@bar.com|      foo| bar.com|
    |baz@spam.com|      baz|spam.com|
    +------------+---------+--------+
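
To see which capture group each index returns, the same pattern can be tried in plain Python with the `re` module. This is an illustration under assumptions: Spark evaluates the pattern with Java's regex engine, and while this simple pattern should behave the same way in both, that is not guaranteed for other patterns. The `extract` helper is hypothetical and only mimics the documented behaviour of returning an empty string when nothing matches.

```python
import re

PATTERN = re.compile(r"(.*)@(.*)")

def extract(email: str, group: int) -> str:
    """Rough stand-in for regexp_extract: the matched group, or '' if no match."""
    m = PATTERN.search(email)
    return m.group(group) if m else ""

print(extract("foo@bar.com", 1))  # foo
print(extract("foo@bar.com", 2))  # bar.com

# The first group is greedy, so with several '@' signs it runs up to the last one
print(extract("a@b@c", 1))        # a@b
print(extract("a@b@c", 2))        # c
```

Note that this differs from the split-based UDF when a value contains more than one '@': splitting and taking index 0 gives only the text before the first '@', whereas the greedy group takes everything up to the last one.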