Search code examples
pythonapache-sparknltkpysparkapache-spark-ml

Create a custom Transformer in PySpark ML


I am new to Spark SQL DataFrames and ML on them (PySpark). How can I create a custom tokenizer which, for example, removes stop words and uses some libraries from NLTK? Can I extend the default one?


Solution

  • Can I extend the default one?

    Not really. The default Tokenizer is a subclass of pyspark.ml.wrapper.JavaTransformer and, like the other transformers and estimators from pyspark.ml.feature, delegates the actual processing to its Scala counterpart. Since you want to use Python, you should extend pyspark.ml.pipeline.Transformer directly.

    import nltk
    
    from pyspark import keyword_only  ## < 2.0 -> pyspark.ml.util.keyword_only
    from pyspark.ml import Transformer
    from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param, Params, TypeConverters
    # Available in PySpark >= 2.3.0 
    from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable  
    from pyspark.sql.functions import udf
    from pyspark.sql.types import ArrayType, StringType
    
    class NLTKWordPunctTokenizer(
            Transformer, HasInputCol, HasOutputCol,
            # Credits https://stackoverflow.com/a/52467470
            # by https://stackoverflow.com/users/234944/benjamin-manns
            DefaultParamsReadable, DefaultParamsWritable):
        """Pure-Python ML Transformer that tokenizes a string column with NLTK.

        Every value of ``inputCol`` is split with
        ``nltk.tokenize.wordpunct_tokenize``. Tokens whose lower-cased form
        appears in ``stopwords`` are dropped, and the remaining tokens are
        written to ``outputCol`` as an array of strings.

        Params:
            inputCol: name of the column holding the text to tokenize.
            outputCol: name of the column that receives the token arrays.
            stopwords: list of strings to filter out (default: empty list).
                Tokens are lower-cased before the lookup; the stop words
                themselves are used as given.
        """

        # Declared once, at class level. PySpark (2.0+) binds class-level
        # Params to each instance inside ``Params.__init__``, so the
        # converter below is what ``_set`` uses to validate new values.
        stopwords = Param(Params._dummy(), "stopwords", "stopwords",
                          typeConverter=TypeConverters.toListString)

        @keyword_only
        def __init__(self, inputCol=None, outputCol=None, stopwords=None):
            """Create the transformer; all arguments are keyword-only."""
            super(NLTKWordPunctTokenizer, self).__init__()
            # NOTE: do not re-create ``self.stopwords`` here. Rebinding it to
            # a fresh ``Param(self, "stopwords", "")`` would shadow the
            # class-level declaration and discard its typeConverter and doc.
            self._setDefault(stopwords=[])
            # ``keyword_only`` stores only the explicitly passed arguments.
            kwargs = self._input_kwargs
            self.setParams(**kwargs)

        @keyword_only
        def setParams(self, inputCol=None, outputCol=None, stopwords=None):
            """Set any subset of the params and return ``self``."""
            kwargs = self._input_kwargs
            return self._set(**kwargs)

        def setStopwords(self, value):
            """Set :py:attr:`stopwords` from any iterable of strings."""
            return self._set(stopwords=list(value))

        def getStopwords(self):
            """Return the value of :py:attr:`stopwords` or its default."""
            return self.getOrDefault(self.stopwords)

        # Required in Spark >= 3.0
        def setInputCol(self, value):
            """
            Sets the value of :py:attr:`inputCol`.
            """
            return self._set(inputCol=value)

        # Required in Spark >= 3.0
        def setOutputCol(self, value):
            """
            Sets the value of :py:attr:`outputCol`.
            """
            return self._set(outputCol=value)

        def _transform(self, dataset):
            """Return ``dataset`` with the tokenized ``outputCol`` added.

            Null input values produce a null output value rather than
            failing inside the UDF.
            """
            # Build the lookup set once, outside the per-row function.
            stopwords = set(self.getStopwords())

            def f(s):
                # A SQL NULL reaches the Python UDF as None; the tokenizer
                # cannot handle it, so propagate the null instead.
                if s is None:
                    return None
                tokens = nltk.tokenize.wordpunct_tokenize(s)
                return [t for t in tokens if t.lower() not in stopwords]

            t = ArrayType(StringType())
            out_col = self.getOutputCol()
            in_col = dataset[self.getInputCol()]
            return dataset.withColumn(out_col, udf(f, t)(in_col))
    

    Example usage (data from ML - Features):

    # Small demo frame: an integer label plus one raw sentence per row.
    rows = [
        (0, "Hi I heard about Spark"),
        (0, "I wish Java could use case classes"),
        (1, "Logistic regression models are neat"),
    ]
    sentenceDataFrame = spark.createDataFrame(rows, ["label", "sentence"])

    # Filter NLTK's English stop-word list out of the tokens.
    english_stopwords = nltk.corpus.stopwords.words('english')
    tokenizer = NLTKWordPunctTokenizer(
        inputCol="sentence", outputCol="words", stopwords=english_stopwords)

    # Print the frame with the new "words" column appended.
    tokenizer.transform(sentenceDataFrame).show()
    

    For custom Python Estimator see How to Roll a Custom Estimator in PySpark mllib

    ⚠ This answer depends on internal API and is compatible with Spark 2.0.3, 2.1.1, 2.2.0 or later (SPARK-19348). For code compatible with previous Spark versions please see revision 8.