python  snowflake-cloud-data-platform  user-defined-functions

Snowpark UDF with Row input type


I would like to define a Snowpark UDF whose input type is snowflake.snowpark.Row. The reason is that I want to mimic the pandas DataFrame.apply approach: define my business logic in a class, then apply that logic to each row of the Snowpark DataFrame. Each column can then be mapped to a class attribute with Row.asDict().

For example (running from the Snowflake Python worksheet):

import snowflake.snowpark as snowpark
from snowflake.snowpark.functions import udf
from snowflake.snowpark import Row
from snowflake.snowpark.types import IntegerType


from dataclasses import dataclass

@dataclass
class MyEvent:
    attribute1: str = 'dummy'
    attribute2: str = 'unknown'
    def someCalculation(self) -> int:
        return len(self.attribute1) + len(self.attribute2.strip())

def testSomeCalculation():
    inputDict = {'attribute1': 'foo',
                 'attribute2': 'baz'}
    event = MyEvent(**inputDict)
    print(event.someCalculation())


def main(session: snowpark.Session):

    some_logic = udf(lambda row: MyEvent(**(row.asDict())).someCalculation()
              , return_type=IntegerType()
              , input_types=[Row])

However, when I try to use snowpark.Row as the input type, I get an "Unsupported data type" error:

File "snowflake/snowpark/_internal/udf_utils.py", line 972, in create_python_udf_or_sp
    input_sql_types = [convert_sp_to_sf_type(arg.datatype) for arg in input_args]
  File "snowflake/snowpark/_internal/udf_utils.py", line 972, in <listcomp>
    input_sql_types = [convert_sp_to_sf_type(arg.datatype) for arg in input_args]
  File "snowflake/snowpark/_internal/type_utils.py", line 195, in convert_sp_to_sf_type
    raise TypeError(f"Unsupported data type: {datatype.__class__.__name__}")
TypeError: Unsupported data type: type
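
The "type" at the end of the message can be read literally: the message is built from datatype.__class__.__name__, and the class of any Python class object is type. The sketch below reproduces that with stand-in classes. FakeDataType, FakeIntegerType, FakeRow, describe and convert are hypothetical names for illustration, not Snowpark's own code, and the check in convert is a simplification.

```python
# Stand-in classes only; these are NOT the real Snowpark types.
class FakeDataType:
    """Plays the role of a Snowpark data type base class."""


class FakeIntegerType(FakeDataType):
    pass


class FakeRow(tuple):
    """Plays the role of a row container, which is not a data type."""


def describe(datatype) -> str:
    # Same expression the traceback uses to build its message.
    return datatype.__class__.__name__


def convert(datatype) -> str:
    # Simplified assumption: only data type *instances* are accepted.
    if isinstance(datatype, FakeDataType):
        return describe(datatype)
    raise TypeError(f"Unsupported data type: {describe(datatype)}")


instance_name = describe(FakeIntegerType())  # an instance, like IntegerType()
class_name = describe(FakeRow)               # a class object, like Row

try:
    convert(FakeRow)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

So the message says less about Row itself than about what was passed: a class object where an instance of a data type was expected.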

I see that all the UDF examples use basic types from snowpark.types. Is there any fundamental reason why the input type cannot be a snowpark.Row?

I know I could explicitly list all MyEvent attributes in input_types=[], but that would be error-prone and would defeat the purpose of designing my code around a class that represents my business object.
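
One way to keep an explicit list from drifting out of sync is to derive it from the dataclass itself with dataclasses.fields. This is a minimal sketch, assuming every attribute is annotated as a plain str or int. PY_TO_TYPE_NAME and columns_for are hypothetical helpers, and the type names are placeholders for the corresponding snowflake.snowpark.types instances (for example StringType()), which are deliberately not imported here.

```python
from dataclasses import dataclass, fields


@dataclass
class MyEvent:
    attribute1: str = 'dummy'
    attribute2: str = 'unknown'


# Hypothetical lookup: Python annotation -> name of a Snowpark type class.
PY_TO_TYPE_NAME = {str: "StringType", int: "IntegerType"}


def columns_for(cls):
    """Return (column_name, type_name) pairs in declaration order."""
    return [(f.name, PY_TO_TYPE_NAME[f.type]) for f in fields(cls)]


event_columns = columns_for(MyEvent)
```

The same pairs could then drive both the input_types list and the columns passed to the UDF call, so adding an attribute to MyEvent would only need to happen in one place.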


Solution

  • Solution tested in a Snowflake Python worksheet, based on the suggestion by @felipe-hoffa above:

    import snowflake.snowpark as snowpark
    from snowflake.snowpark.functions import col, udf
    from snowflake.snowpark.types import IntegerType, VariantType
        
    from dataclasses import dataclass
    import json
    
    def main(session: snowpark.Session):
        
        @dataclass
        class MyEvent:
            attribute1: str = 'dummy'
            attribute2: str = 'unknown'
            def someText(self) -> str:
                return f"someText {len(self.attribute1)} : {self.attribute1=}, {self.attribute2=}"
    
        def wrap_some_text(x) -> str:
            return MyEvent(**json.loads(x)).someText()
        
        my_event_get_text = udf(lambda x: wrap_some_text(x), return_type=VariantType(), input_types=[VariantType()])
    
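        df = session.create_dataframe(['{"attribute1":"value1", "attribute2":"value20"}']).to_df("col1")
        df = df.select(col("col1"),
                       my_event_get_text(col("col1").astype("variant")).as_("my_event_get_text")
                      )
        # show() prints the rows and returns None, so call it separately
        # and return the DataFrame itself from the handler.
        df.show()
        return df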
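
The wrapper above assumes the VARIANT always arrives as a JSON string. A slightly more defensive sketch of the same idea is shown here in plain Python so it can be tried outside Snowflake. It accepts either a JSON string or an already-parsed dict (which, as far as I know, is how a VARIANT holding an object is handed to Python code) and ignores keys that are not MyEvent attributes. event_from_variant is a hypothetical helper name.

```python
import json
from dataclasses import dataclass, fields


@dataclass
class MyEvent:
    attribute1: str = 'dummy'
    attribute2: str = 'unknown'

    def someText(self) -> str:
        return f"someText {len(self.attribute1)} : {self.attribute1=}, {self.attribute2=}"


def event_from_variant(value) -> MyEvent:
    # Parse JSON text; otherwise assume a mapping and copy it.
    data = json.loads(value) if isinstance(value, str) else dict(value)
    # Keep only keys that match dataclass attributes; the rest are dropped
    # and missing attributes fall back to their defaults.
    known = {f.name for f in fields(MyEvent)}
    return MyEvent(**{k: v for k, v in data.items() if k in known})


from_text = event_from_variant('{"attribute1":"value1", "attribute2":"value20"}')
from_dict = event_from_variant({"attribute1": "value1", "extra_column": 42})
```

Inside the UDF, wrap_some_text could call event_from_variant(x).someText() instead of building MyEvent directly, so an unexpected extra column would not raise a TypeError.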