Tags: apache-spark, pyspark, spark-3

Pyspark.ml - Error when loading model and Pipeline


I want to import a trained pyspark model (or pipeline) into a pyspark script. I trained a decision tree model like so:

from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StringIndexer

# Create assembler and labeller for spark.ml format preparation
assembler = VectorAssembler(inputCols=requiredFeatures, outputCol='features')
label_indexer = StringIndexer(inputCol='measurement_status', outputCol='indexed_label')

# Apply transformations
eq_df_labelled = label_indexer.fit(eq_df).transform(eq_df)
eq_df_labelled_featured = assembler.transform(eq_df_labelled)

# Split into training and testing datasets
(training_data, test_data) = eq_df_labelled_featured.randomSplit([0.75, 0.25])

# Create a decision tree algorithm
dtree = DecisionTreeClassifier(
    labelCol='indexed_label',
    featuresCol='features',
    maxDepth=5,
    minInstancesPerNode=1,
    impurity='gini',
    maxBins=32,
    seed=None
)

# Fit classifier object to training data
dtree_model = dtree.fit(training_data)

# Save model to given directory
dtree_model.save("models/dtree")

All of the code above runs without any errors. The problem is that when I try to load this model (in the same or another pyspark application), using:

from pyspark.ml.classification import DecisionTreeClassifier

imported_model = DecisionTreeClassifier()
imported_model.load("models/dtree")

I get the following error:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-4-b283bc2da75f> in <module>
      2 
      3 imported_model = DecisionTreeClassifier()
----> 4 imported_model.load("models/dtree")
      5 
      6 #lodel = DecisionTreeClassifier.load("models/dtree-test/")

~/.local/lib/python3.6/site-packages/pyspark/ml/util.py in load(cls, path)
    328     def load(cls, path):
    329         """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
--> 330         return cls.read().load(path)
    331 
    332 

~/.local/lib/python3.6/site-packages/pyspark/ml/util.py in load(self, path)
    278         if not isinstance(path, basestring):
    279             raise TypeError("path should be a basestring, got type %s" % type(path))
--> 280         java_obj = self._jread.load(path)
    281         if not hasattr(self._clazz, "_from_java"):
    282             raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"

~/.local/lib/python3.6/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1303         answer = self.gateway_client.send_command(command)
   1304         return_value = get_return_value(
-> 1305             answer, self.gateway_client, self.target_id, self.name)
   1306 
   1307         for temp_arg in temp_args:

~/.local/lib/python3.6/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
    126     def deco(*a, **kw):
    127         try:
--> 128             return f(*a, **kw)
    129         except py4j.protocol.Py4JJavaError as e:
    130             converted = convert_exception(e.java_exception)

~/.local/lib/python3.6/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o39.load.
: java.lang.UnsupportedOperationException: empty collection
    at org.apache.spark.rdd.RDD.$anonfun$first$1(RDD.scala:1439)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:388)
    at org.apache.spark.rdd.RDD.first(RDD.scala:1437)
    at org.apache.spark.ml.util.DefaultParamsReader$.loadMetadata(ReadWrite.scala:587)
    at org.apache.spark.ml.util.DefaultParamsReader.load(ReadWrite.scala:465)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)

I went for this approach because it also didn't work using a Pipeline object. Any ideas about what is happening?

UPDATE

I have realised that this error only occurs when I work with my Spark cluster (one master, two workers using Spark's standalone cluster manager). If I set up the Spark session like so (with the master set to local):

spark = SparkSession\
    .builder\
    .config(conf=conf)\
    .appName("MachineLearningTesting")\
    .master("local[*]")\
    .getOrCreate()

I do not get the above error.

Also, I am using Spark 3.0.0; could it be that model importing and exporting in Spark 3 still has bugs?


Solution

  • There were two problems:

    1. SSH-authenticated communication must be enabled between all nodes in the cluster. Even though all nodes in my Spark cluster are on the same network, only the master had SSH access to the workers, not vice versa.

    2. The model must be available to all nodes in the cluster. This may sound obvious, but I thought the model files only needed to be available to the master, which would then distribute them to the worker nodes. In other words, when you load the model like so:

    from pyspark.ml.classification import DecisionTreeClassifier
    
    imported_model = DecisionTreeClassifier()
    imported_model.load("models/dtree")
    

    The file /absolute_path/models/dtree must exist on every machine in the cluster. This made me realise that in production contexts, models are probably accessed via an external shared file system.

    These two steps solved my problem of loading pyspark models into a Spark application running on a cluster.