java, apache-spark, apache-kafka, spark-streaming

Get the number of partitions for a topic before creating a direct stream using Kafka and Spark Streaming?


I have the following code that creates a direct stream using the Kafka connector for Spark.

import java.util.Collection;

import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.spark.streaming.api.java.JavaInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import org.apache.spark.streaming.kafka010.ConsumerStrategies;
import org.apache.spark.streaming.kafka010.KafkaUtils;
import org.apache.spark.streaming.kafka010.LocationStrategies;

public abstract class MessageConsumer<T>
{
    public JavaInputDStream<ConsumerRecord<String, T>> createConsumer(final JavaStreamingContext jsc,
        final Collection<String> topics, final String servers)
    {
        // Subscribe to the topics and let Kafka decide which partitions this consumer group gets.
        return KafkaUtils.createDirectStream(
            jsc,
            LocationStrategies.PreferConsistent(),
            ConsumerStrategies.<String, T>Subscribe(topics,
                ConsumerUtils.getKafkaParams(servers, getGroupId(), getDeserializerClassName())));
    }

    protected abstract String getDeserializerClassName();

    protected abstract String getGroupId();
}

This works fine, but now I want to change the logic so that the consumer reads from a specific partition of a topic, as opposed to letting Kafka decide which partitions to assign. I do this by using the same algorithm as the default Kafka partitioner to determine which partition a message is sent to based on its key: Utils.toPositive(Utils.murmur2(keyBytes)) % numPartitions. I then simply assign my consumer to that partition. For this to work, I need to know the total number of partitions available for the topic, but I do not know how to get this information through the Kafka/Spark Streaming API.

I have been able to get this working in other parts of my application that don't use Spark, but I am unclear how to achieve it when using Spark. The only way I can see is to create another consumer before creating the direct stream, use it to get the total number of partitions, and then close it. See the code below for this implementation:

import java.util.Collections;
import java.util.Map;

import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.utils.Utils;
import org.apache.spark.streaming.api.java.JavaInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import org.apache.spark.streaming.kafka010.ConsumerStrategies;
import org.apache.spark.streaming.kafka010.KafkaUtils;
import org.apache.spark.streaming.kafka010.LocationStrategies;

public abstract class MessageConsumer<T>
{
    public JavaInputDStream<ConsumerRecord<String, T>> createConsumer(final JavaStreamingContext jsc,
        final String topic, final String servers, final String groundStation)
    {
        final Map<String, Object> kafkaParams =
            ConsumerUtils.getKafkaParams(servers, getGroupId(), getDeserializerClassName());

        // Short-lived consumer used only to look up the partition count for the topic.
        final int numPartitions;
        try (final Consumer<String, T> tempConsumer = new KafkaConsumer<>(kafkaParams))
        {
            numPartitions = tempConsumer.partitionsFor(topic).size();
        }

        final int partition = calculateKafkaPartition(groundStation.getBytes(), numPartitions);
        final TopicPartition topicPartition = new TopicPartition(topic, partition);

        // Assign the stream to the computed partition instead of subscribing to the whole topic.
        return KafkaUtils.createDirectStream(
            jsc,
            LocationStrategies.PreferConsistent(),
            ConsumerStrategies.<String, T>Assign(Collections.singletonList(topicPartition), kafkaParams));
    }

    protected abstract String getDeserializerClassName();

    protected abstract String getGroupId();

    // Same formula Kafka's DefaultPartitioner applies to a keyed message.
    private static int calculateKafkaPartition(final byte[] keyBytes, final int numberOfPartitions)
    {
        return Utils.toPositive(Utils.murmur2(keyBytes)) % numberOfPartitions;
    }
}

This doesn't seem right to me at all; surely there is a better way to do this?


Solution

  • You'd use Kafka's AdminClient to describe the topic; there is no Spark API for that information.
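
For example, here is a minimal sketch of that lookup, assuming the kafka-clients admin API (available since Kafka 0.11) is on the classpath; PartitionCounter and getPartitionCount are illustrative names, not part of the original answer:

import java.util.Collections;
import java.util.Properties;
import java.util.concurrent.ExecutionException;

import org.apache.kafka.clients.admin.AdminClient;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.TopicDescription;

public final class PartitionCounter
{
    // Hypothetical helper: ask the brokers how many partitions a topic has,
    // without creating a throwaway KafkaConsumer.
    public static int getPartitionCount(final String servers, final String topic)
        throws ExecutionException, InterruptedException
    {
        final Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, servers);
        try (final AdminClient admin = AdminClient.create(props))
        {
            final TopicDescription description = admin.describeTopics(Collections.singletonList(topic))
                .values().get(topic).get(); // KafkaFuture: blocks until the metadata arrives
            return description.partitions().size();
        }
    }
}

The returned count could then feed straight into calculateKafkaPartition, replacing the temporary consumer in createConsumer above.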