Tags: pyspark, user-defined-functions, pickle, azure-databricks, azure-machine-learning-service

Databricks UDF calling an external web service cannot be serialised (PicklingError)


I am using Databricks and have a column in a dataframe that I need to update for every record with an external web service call. In this case it uses the Azure Machine Learning Service SDK to make the service call. This code works fine when not run as a UDF in Spark (i.e., just Python); however, it throws a serialization error when I try to call it as a UDF. The same happens if I use a lambda and a map with an RDD.

The model uses fastText and can be invoked fine from Postman, from Python via a normal HTTP call, or through the Webservice SDK from AMLS. It only fails when it is called from a UDF, with this message:

TypeError: can't pickle _thread._local objects

The only workaround I can think of is to loop through the dataframe sequentially and update each record with a call, but this is not very efficient. I don't know if this is a Spark error or happens because the service is loading a fastText model. When I use the UDF and mock the return value, it works, though.

The code is below; the full error is at the bottom:

from azureml.core.webservice import Webservice, AciWebservice
from azureml.core import Workspace

def predictModelValue2(summary, modelName, modelLabel):
    raw_data = '[{"label": "' + modelLabel + '", "model": "' + modelName + '", "as_full_account": "' + summary + '"}]'
    # service is the Webservice object created earlier in the notebook (a global captured by the UDF)
    prediction = service.run(raw_data)
    return prediction

from pyspark.sql.types import FloatType
from pyspark.sql.functions import udf

predictModelValueUDF = udf(predictModelValue2)

DVIRCRAMFItemsDFScored1 = DVIRCRAMFItemsDF.withColumn("Result", predictModelValueUDF("Summary", "ModelName", "ModelLabel"))

TypeError: can't pickle _thread._local objects

During handling of the above exception, another exception occurred:

PicklingError                             Traceback (most recent call last)
in
----> 2 x = df.withColumn("Result", predictModelValueUDF("Summary", "ModelName", "ModelLabel"))

/databricks/spark/python/pyspark/sql/udf.py in wrapper(*args)
    194         @functools.wraps(self.func, assigned=assignments)
    195         def wrapper(*args):
--> 196             return self(*args)
    197
    198         wrapper.__name__ = self._name

/databricks/spark/python/pyspark/sql/udf.py in __call__(self, *cols)
    172
    173     def __call__(self, *cols):
--> 174         judf = self._judf
    175         sc = SparkContext._active_spark_context
    176         return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

/databricks/spark/python/pyspark/sql/udf.py in _judf(self)
    156         # and should have a minimal performance impact.
    157         if self._judf_placeholder is None:
--> 158             self._judf_placeholder = self._create_judf()
    159         return self._judf_placeholder
    160

/databricks/spark/python/pyspark/sql/udf.py in _create_judf(self)
    165         sc = spark.sparkContext
    166
--> 167         wrapped_func = _wrap_function(sc, self.func, self.returnType)
    168         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
    169         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(

/databricks/spark/python/pyspark/sql/udf.py in _wrap_function(sc, func, returnType)
     33 def _wrap_function(sc, func, returnType):
     34     command = (func, returnType)
---> 35     pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
     36     return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
     37                                   sc.pythonVer, broadcast_vars, sc._javaAccumulator)

/databricks/spark/python/pyspark/rdd.py in _prepare_for_python_RDD(sc, command)
   2461     # the serialized command will be compressed by broadcast
   2462     ser = CloudPickleSerializer()
-> 2463     pickled_command = ser.dumps(command)
   2464     if len(pickled_command) > sc._jvm.PythonUtils.getBroadcastThreshold(sc._jsc):  # Default 1M
   2465         # The broadcast will have same life cycle as created PythonRDD

/databricks/spark/python/pyspark/serializers.py in dumps(self, obj)
    709             msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
    710             cloudpickle.print_exec(sys.stderr)
--> 711             raise pickle.PicklingError(msg)
    712
    713

PicklingError: Could not serialize object: TypeError: can't pickle _thread._local objects
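The TypeError at the root of this traceback comes from pickle itself, not from Spark or fastText: objects of type threading.local cannot be pickled, and Spark has to pickle the UDF together with the service global it references. The sketch below reproduces that with only the standard library. It assumes, hypothetically, that something inside the Webservice object's state is a thread-local (which is what the error names; the question does not show which attribute it is), so client_state is an illustrative stand-in, not the real object. This is also consistent with the mocked UDF working, since the mock never touches service.

```python
import pickle
import threading


def try_pickle(obj) -> str:
    """Return 'ok' if obj can be pickled, otherwise a short description of the TypeError."""
    try:
        pickle.dumps(obj)
        return "ok"
    except TypeError as exc:
        return f"TypeError: {exc}"


# Hypothetical stand-in for a client object's internal state:
# plain configuration plus a thread-local (the type named in the error).
client_state = {"scoring_uri": "http://example.invalid/score", "tls": threading.local()}

# Only plain strings, which is what capturing key1 / scoring_uri amounts to.
plain_state = {"scoring_uri": client_state["scoring_uri"], "key": "<key>"}

print(try_pickle(client_state))  # a TypeError mentioning _thread._local
print(try_pickle(plain_state))   # ok
```

The exact wording of the message differs between Python versions, but in both cases the offending type is _thread._local.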


Solution

  • I am not an expert in Databricks or Spark, but pickling functions from the local notebook context is often problematic when they touch complex objects like the service object: the function is serialized together with the globals it uses, so the service object and everything it holds must be pickled too. In this particular case, I would recommend removing the dependency on the Azure ML service object and just using requests to call the service.

    Pull the keys and the scoring URI from the service on the driver, outside the UDF:

    # retrieve the API keys. two keys were generated.
    key1, key2 = service.get_keys()
    scoring_uri = service.scoring_uri
    

    These are plain strings, so you should be able to use them in the UDF directly without pickling issues. Here is your UDF rewritten to call the service with just requests:

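    import json
    import requests

    def predictModelValue2(summary, modelName, modelLabel):
      # same payload format the original code passed to service.run
      input_data = json.dumps([{"label": modelLabel, "model": modelName, "as_full_account": summary}])

      headers = {'Content-Type': 'application/json', 'Authorization': 'Bearer ' + key1}

      # call the service for scoring; key1 and scoring_uri are plain strings, so they pickle fine
      resp = requests.post(scoring_uri, data=input_data, headers=headers)
      resp.raise_for_status()

      # raw response body; parse it (e.g. with resp.json()) to match your service's output format
      return resp.text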
    
    

    On a side note, though: your UDF will be called for each row in your data frame, and each call makes a network request, which will be very slow. I would recommend looking for ways to batch the execution. As you can see from the JSON you construct, service.run accepts an array of items, so you should send the rows in batches of a few hundred or so.
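The batching suggested above can be sketched as follows. This is a minimal sketch, not part of the Azure ML SDK: it sends one JSON array per batch of rows from inside mapPartitions, reusing the payload keys from the question. It assumes, hypothetically, that the endpoint replies with a JSON list holding one prediction per input item, in input order; check what your scoring script actually returns before relying on that. The helper names chunked, post_json and score_partition are made up for illustration, and the standard-library urllib.request stands in for requests so the sketch has no extra dependencies.

```python
import json
import urllib.request
from itertools import islice
from typing import Callable, Iterable, Iterator


def chunked(items: Iterable, size: int) -> Iterator[list]:
    """Yield successive lists of at most `size` items from `items`."""
    it = iter(items)
    while True:
        batch = list(islice(it, size))
        if not batch:
            return
        yield batch


def post_json(scoring_uri: str, key: str, payload: list) -> list:
    """POST a JSON array to the scoring endpoint and decode the JSON reply.

    Assumption: the service answers with a JSON list containing one
    prediction per input item, in the same order.
    """
    req = urllib.request.Request(
        scoring_uri,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json", "Authorization": "Bearer " + key},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))


def score_partition(
    rows: Iterable,
    scoring_uri: str,
    key: str,
    batch_size: int = 100,
    post: Callable[[str, str, list], list] = post_json,
) -> Iterator[tuple]:
    """Score rows with one HTTP call per batch instead of one per row."""
    for batch in chunked(rows, batch_size):
        payload = [
            {"label": r["ModelLabel"], "model": r["ModelName"], "as_full_account": r["Summary"]}
            for r in batch
        ]
        results = post(scoring_uri, key, payload)
        if len(results) != len(batch):
            raise ValueError("expected one prediction per input row")
        for r, result in zip(batch, results):
            yield (r["Summary"], r["ModelName"], r["ModelLabel"], str(result))


# Hypothetical usage on Databricks; key1 and scoring_uri come from
# service.get_keys() and service.scoring_uri on the driver, as above:
# scored = DVIRCRAMFItemsDF.rdd.mapPartitions(
#     lambda rows: score_partition(rows, scoring_uri, key1)
# ).toDF(["Summary", "ModelName", "ModelLabel", "Result"])
```

The post parameter exists so the batching logic can be exercised with a fake endpoint; only two strings and plain functions end up in the pickled closure, which avoids the original error.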