Search code examples
Tags: apache-spark, pyspark, apache-kafka, protocol-buffers, streaming

PySpark: How To Deserialise A Proto Payload From A Kafka Message With Variable Message Type


I am trying to read from a Kafka topic that contains messages with different Protobuf payloads. The name of each payload's message type (messageName) is set in the Kafka message key.

But when I try to:

# Question code: reproduces the error described below.
df = spark.readStream.format(constants.KAFKA_INPUT_FORMAT) \
        .options(**options) \
        .load()
# NOTE(review): selectExpr projects only the cast key, so the `value` column
# used on the next statement is no longer in `df`. Also, `.alias('key')` here
# aliases the DataFrame, not the column.
df = df.selectExpr("CAST(key AS STRING)").alias('key')
# from_protobuf takes `messageName` as a plain Python string. Passing the
# Column `df.key` instead is what leads to PySparkTypeError [NOT_ITERABLE].
df = df.select(from_protobuf('value', df.key, desc_file_path).alias('value'))

I get the pyspark.errors.exceptions.base.PySparkTypeError: [NOT_ITERABLE] Column is not iterable error.

How can I dynamically set the messageName parameter of the from_protobuf function from the key of each Kafka message?


Solution

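  • Step 1: extract the Kafka key and value. Step 3 below expects a string column named messageName (the Kafka key cast to STRING) and the binary value column; this part is left to you.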

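    Step 2: define a Python function that deserializes a payload using a message type chosen at runtime. from_protobuf cannot do this because its messageName argument must be a plain string, not a Column, which is why passing df.key raises NOT_ITERABLE.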

    import functools

    from google.protobuf import descriptor_pool, message_factory
    from google.protobuf.descriptor_pb2 import FileDescriptorSet
    
    @functools.lru_cache(maxsize=8)
    def _load_descriptor_pool(descriptor_file):
        """Build a DescriptorPool from a serialized FileDescriptorSet file.

        The result is cached per process and per path, so the file is read and
        parsed once rather than once per deserialized record. Changes to the
        file on disk are not picked up until the process restarts.
        """
        with open(descriptor_file, "rb") as f:
            file_descriptor_set = FileDescriptorSet.FromString(f.read())

        pool = descriptor_pool.DescriptorPool()
        # pool.Add needs a file's dependencies added before the file itself.
        # Presumably the set lists them in that order (e.g. generated with
        # `protoc --include_imports`); confirm for your descriptor file.
        for file_descriptor_proto in file_descriptor_set.file:
            pool.Add(file_descriptor_proto)
        return pool


    def _get_message_class(pool, message_descriptor):
        """Return the generated message class for *message_descriptor*.

        Prefers ``message_factory.GetMessageClass`` (newer protobuf releases) and
        falls back to the older, deprecated ``MessageFactory.GetPrototype``.
        """
        get_message_class = getattr(message_factory, "GetMessageClass", None)
        if get_message_class is not None:
            return get_message_class(message_descriptor)
        return message_factory.MessageFactory(pool).GetPrototype(message_descriptor)


    def deserialize_protobuf(message_name, serialized_payload, descriptor_file):
        """Parse *serialized_payload* as the Protobuf message type *message_name*.

        Args:
            message_name: Fully qualified message type name as registered in the
                descriptor file, e.g. ``"my.package.MyMessage"``.
            serialized_payload: Wire-format bytes of the message.
            descriptor_file: Path to a serialized ``FileDescriptorSet``.

        Returns:
            The parsed message instance.

        Raises:
            ValueError: if *message_name* is not defined in the descriptor file.
        """
        pool = _load_descriptor_pool(descriptor_file)
        try:
            # FindMessageTypeByName raises KeyError for unknown names; it never
            # returns a falsy value, so the error must be caught, not tested.
            message_descriptor = pool.FindMessageTypeByName(message_name)
        except KeyError as err:
            raise ValueError(f"Message type {message_name} not found in descriptor.") from err

        message_class = _get_message_class(pool, message_descriptor)
        return message_class.FromString(serialized_payload)
    

    Step 3: use foreachBatch to apply the dynamic deserialization to each micro-batch.

    def process_batch(batch_df, batch_id):
        """Deserialize one micro-batch of Kafka records and append it to Parquet.

        Intended as the ``foreachBatch`` callback: Spark passes the micro-batch
        DataFrame and its id (the id is not used here).

        Assumes step 1 produced a string ``messageName`` column (the Kafka key)
        and a binary ``value`` column (the Protobuf payload); adjust the column
        names if yours differ. Each record is decoded with the message type named
        in its own ``messageName`` and rendered as a JSON string in a new
        ``deserialized_value`` column. Records that cannot be decoded get null.
        """
        import pandas as pd
        from google.protobuf.json_format import MessageToJson
        from pyspark.sql.functions import col, pandas_udf
        from pyspark.sql.types import StringType

        # Placeholder: point this at your serialized FileDescriptorSet.
        descriptor_file = "path/to/descriptor/file"

        @pandas_udf(StringType())
        def protobuf_deserializer(message_names: pd.Series, payloads: pd.Series) -> pd.Series:
            results = []
            for message_name, payload in zip(message_names, payloads):
                try:
                    message = deserialize_protobuf(
                        message_name=message_name,
                        serialized_payload=payload,
                        descriptor_file=descriptor_file,
                    )
                except Exception:
                    # Deliberate best-effort: an unknown message type or corrupt
                    # payload yields null instead of failing the whole batch.
                    results.append(None)
                else:
                    # The column is StringType, so emit readable JSON rather than
                    # re-serialized wire bytes.
                    results.append(MessageToJson(message))
            return pd.Series(results, dtype=object)

        deserialized_df = batch_df.withColumn(
            "deserialized_value",
            protobuf_deserializer(col("messageName"), col("value")),
        )

        deserialized_df.write.format("parquet").mode("append").save("/path/to/output")
    
    # Hand every micro-batch to process_batch. The checkpoint directory stores
    # the query's progress (including consumed Kafka offsets), so a restarted
    # query resumes from where it left off. Replace the placeholder path.
    query = (
        df.writeStream
        .foreachBatch(process_batch)
        .outputMode("append")
        .option("checkpointLocation", "/path/to/checkpoint")
        .start()
    )

    # Block until the streaming query stops or fails.
    query.awaitTermination()
    

    Good luck!