Tags: python, dataframe, inheritance, pyspark, subclass

How to correctly subclass the Pyspark DataFrame class by inheritance? Pyspark warns that calling __init__ via super() is not supported


I would like to create a class that directly inherits Pyspark's DataFrame, instead of simply containing a DataFrame as an attribute:

from pyspark.sql import DataFrame


# Desired
class ADataFrame(DataFrame):
    ...


# Not Desired
class NotADataFrame:
    def __init__(self, df: DataFrame):
        self.df = df
        ...
    ...

The accepted answer to this StackOverflow question, as well as the accepted answer (though not the one with the most upvotes) to this other StackOverflow question, were both great starting points and led me to this code:

from pyspark.sql import DataFrame


class CustomDataFrame(DataFrame):
    def __init__(self, df: DataFrame) -> None:
        super().__init__(df._jdf, df.sql_ctx)
        ...
    ...

This runs without error (for now), but Pyspark warns me that I'm doing something that is not supported and may break in the future:

UserWarning: DataFrame.sql_ctx is an internal property, and will be removed in future releases. Use DataFrame.sparkSession instead.

UserWarning: DataFrame constructor is internal. Do not directly use it.

I've looked at the source code for DataFrame's __init__, and I can see where the warning is coming from, but I don't understand how I'm supposed to use DataFrame.sparkSession instead, as the warning and the source code instruct me to, when I'm trying to subclass.
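
For reference, here is roughly the branch I'm looking at (a condensed paraphrase of recent Pyspark 3.x source, not the verbatim code):

# Condensed paraphrase of pyspark.sql.dataframe.DataFrame.__init__
# in recent Pyspark 3.x -- see the actual source for the real thing.
import warnings

from pyspark.sql.context import SQLContext


class DataFrame:
    def __init__(self, jdf, sql_ctx):
        if isinstance(sql_ctx, SQLContext):
            # Legacy path: a SQLContext as the second argument
            # triggers the "constructor is internal" UserWarning.
            warnings.warn("DataFrame constructor is internal. Do not directly use it.")
            session = sql_ctx.sparkSession
        else:
            # A SparkSession is accepted here without complaint.
            session = sql_ctx
        self._jdf = jdf
        self._session = session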

Is there a correct/supported way to create a subclass of Pyspark's DataFrame class?

In the second question I linked above, the answer with the most upvotes (not the accepted answer) describes a complex method with multiple decorators for adding attributes to any class. Is that going to be my best bet?

Any insight would be appreciated.

Thank you.


Solution

  • Solved after another, more careful read-through of the source code.

    I just needed to change the line with super() from:

    super().__init__(df._jdf, df.sql_ctx)
    

    to:

    super().__init__(df._jdf, df.sparkSession)
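
    As far as I can tell from the source, the first warning comes from
    accessing the df.sql_ctx property itself, and the second is raised
    inside __init__ only when the second argument is a SQLContext. The
    constructor's parameter is still named sql_ctx, but in recent
    Pyspark 3.x it also accepts a SparkSession, so passing
    df.sparkSession silences both warnings.

    For anyone landing here later, a minimal self-contained version of
    the subclass (with_greeting is just an illustrative method of my
    own, not anything from Pyspark):

    from pyspark.sql import DataFrame, SparkSession
    from pyspark.sql import functions as F


    class CustomDataFrame(DataFrame):
        def __init__(self, df: DataFrame) -> None:
            # Pass the SparkSession, not df.sql_ctx, to avoid both
            # UserWarnings on recent Pyspark 3.x.
            super().__init__(df._jdf, df.sparkSession)

        def with_greeting(self) -> "CustomDataFrame":
            # Re-wrap the result, since inherited methods return plain
            # DataFrames; this keeps chained calls a CustomDataFrame.
            return CustomDataFrame(self.withColumn("greeting", F.lit("hello")))


    spark = SparkSession.builder.getOrCreate()
    cdf = CustomDataFrame(spark.range(3))
    cdf.with_greeting().show()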