I want to make my own transformer of features in a `DataFrame`, so that I add a column which is, for example, the difference between two other columns. I followed this question, but the transformer there operates on one column only. A `pyspark.ml.Transformer` built with the `HasInputCol` mixin takes a single string for `inputCol`, so of course I cannot specify multiple columns.
So basically, what I want to achieve is a `_transform()` method that resembles this one:
```python
def _transform(self, dataset):
    out_col = self.getOutputCol()
    in_col = dataset.select([self.getInputCol()])

    # Define transformer logic
    def f(col1, col2):
        return col1 - col2

    t = IntegerType()
    return dataset.withColumn(out_col, udf(f, t)(in_col))
```
How can I do this?
I managed to solve the problem by first creating a `Vector` out of the set of features that I want to operate on, and then applying the transform to the newly generated vector feature. Below is example code that makes a new feature which is the difference of two other features:
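```python
from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType


class MeasurementDifferenceTransformer(Transformer, HasInputCol, HasOutputCol):
    """Adds outputCol = v[0] - v[1], where v is the vector in inputCol."""

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super(MeasurementDifferenceTransformer, self).__init__()
        # Newer PySpark stores the kwargs on the instance;
        # older releases used self.__init__._input_kwargs instead.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        out_col = self.getOutputCol()
        in_col = dataset[self.getInputCol()]

        # Transformer logic: difference of the first two vector elements
        def f(vector):
            return float(vector[0] - vector[1])

        return dataset.withColumn(out_col, udf(f, FloatType())(in_col))
```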
To use it, we first instantiate a `VectorAssembler` to create the vector feature:

```python
from pyspark.ml.feature import VectorAssembler

pair_assembler = VectorAssembler(inputCols=["col1", "col2"], outputCol="cols_vector")
```
Then we instantiate the transformer:
```python
pair_transformer = MeasurementDifferenceTransformer(inputCol="cols_vector", outputCol="col1_minus_col2")
```
Finally, we transform the data:

```python
pairfeats = pair_assembler.transform(df)
difffeats = pair_transformer.transform(pairfeats)
```