numpy, pyspark, rdd, numpy-ndarray

pyspark RDDs strip attributes of numpy subclasses


I've been fighting an unexpected behavior when attempting to construct a subclass of numpy ndarray within a map call to a pyspark RDD. Specifically, the attribute that I added within the ndarray subclass appears to be stripped from the resulting RDD.

The following snippets contain the essence of the issue.

import numpy as np

class MyArray(np.ndarray):
    def __new__(cls, shape, extra=None, *args):
        # Build the underlying ndarray, then attach the custom attribute.
        obj = super().__new__(cls, shape, *args)
        obj.extra = extra
        return obj
    def __array_finalize__(self, obj):
        # Propagate the attribute to views, slices, and other derived arrays.
        if obj is None:
            return
        self.extra = getattr(obj, "extra", None)

def shape_to_array(shape):
    rval = MyArray(shape, extra=shape)
    rval[:] = np.arange(np.prod(shape)).reshape(shape)
    return rval

If I invoke shape_to_array directly (not under pyspark), it behaves as expected:

x = shape_to_array((2,3,5))
print(x.extra)

outputs:

(2, 3, 5)

But, if I invoke shape_to_array via a map to an RDD of inputs, it goes wonky:

from pyspark.sql import SparkSession
sc = SparkSession.builder.appName("Steps").getOrCreate().sparkContext

rdd = sc.parallelize([(2,3,5),(2,4),(2,5)])
result = rdd.map(shape_to_array).cache()
print(result.map(lambda t:type(t)).collect())
print(result.map(lambda t:t.shape).collect())
print(result.map(lambda t:t.extra).collect())

Outputs:

[<class '__main__.MyArray'>, <class '__main__.MyArray'>, <class '__main__.MyArray'>]

[(2, 3, 5), (2, 4), (2, 5)]

22/10/15 15:48:02 ERROR Executor: Exception in task 7.0 in stage 2.0 (TID 23)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/local/Cellar/apache-spark/3.3.0/libexec/python/lib/pyspark.zip/pyspark/worker.py", line 686, in main
    process()
  File "/usr/local/Cellar/apache-spark/3.3.0/libexec/python/lib/pyspark.zip/pyspark/worker.py", line 678, in process
    serializer.dump_stream(out_iter, outfile)
  File "/usr/local/Cellar/apache-spark/3.3.0/libexec/python/lib/pyspark.zip/pyspark/serializers.py", line 273, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/usr/local/Cellar/apache-spark/3.3.0/libexec/python/lib/pyspark.zip/pyspark/util.py", line 81, in wrapper
    return f(*args, **kwargs)
  File "/var/folders/w7/42_p7mcd1y91_tjd0jzr8zbh0000gp/T/ipykernel_94831/2519313465.py", line 1, in <lambda>
AttributeError: 'MyArray' object has no attribute 'extra'

What happened to the extra attribute of the MyArray instances?

Thanks much for any/all suggestions

EDIT: A bit of additional info. If I add logging inside the shape_to_array function just before the return, I can verify that the extra attribute does exist on the MyArray object being returned. But when I access those MyArray elements of the RDD from the driver, the attribute is gone.
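
Concretely, the logging check looks something like this (the print line is the only addition):

def shape_to_array(shape):
    rval = MyArray(shape, extra=shape)
    rval[:] = np.arange(np.prod(shape)).reshape(shape)
    # Worker-side check: the attribute is still present just before returning.
    print(f"extra before return: {rval.extra!r}")
    return rval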


Solution

  • After a night of sleeping on this, I remembered that I have often had issues with pyspark RDDs where the error message had to do with the return type not working with pickle.

    I wasn't getting that error message this time because numpy.ndarray does work with pickle. BUT... the __reduce__ and __setstate__ methods of numpy.ndarray know nothing about the added extra attribute on the MyArray subclass. This is where extra was being stripped.
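
    The same stripping can be reproduced without Spark at all. A minimal sketch, using the MyArray class as originally defined (i.e. without the pickle support added below):

    import pickle

    x = shape_to_array((2, 3, 5))
    y = pickle.loads(pickle.dumps(x))      # the same round trip Spark performs
    print(type(y))                         # <class '__main__.MyArray'>
    print(y.shape)                         # (2, 3, 5)
    print(getattr(y, "extra", "MISSING"))  # MISSING -- the attribute is gone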

    Adding the following two methods to MyArray solved everything.

    def __reduce__(self):
        # Append the custom attribute to the state tuple produced by ndarray.
        reconstruct, reconstruct_args, state = super().__reduce__()
        return reconstruct, reconstruct_args, state + (self.extra,)

    def __setstate__(self, state):
        # Restore the ndarray state, then recover the custom attribute.
        super().__setstate__(state[:-1])
        self.extra = state[-1]
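
    With those two methods in place, a quick local check (a sketch of the same pickle round trip as above) shows the attribute surviving, and the previously failing result.map(lambda t: t.extra).collect() call now succeeds as well:

    import pickle

    x = shape_to_array((2, 3, 5))
    y = pickle.loads(pickle.dumps(x))
    print(y.extra)   # (2, 3, 5) -- preserved through the round trip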
    

    Thank you to anyone who took some time to think about my question.