Kafka的消费者类KafkaConsumer是非线程安全的,那如何实现多线程的Consumer呢?先了解一下一般Consumer的流程。

1.jpg

如上图:

  1. 通过poll方法从kafka集群拉取数据;
  2. 处理数据
  3. 提交offset(如果开启了自动提交enable.auto.commit=true,则每次poll的时候会自动提交上一次poll的offset)

如此往复。翻译成代码类似下面这样:

    Properties properties = new Properties();
    properties.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
    properties.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "test");
        ...

    // 创建consumer实例(非线程安全)
    KafkaConsumer<String, String> consumer = new KafkaConsumer<>(properties);
    // 订阅主题
    consumer.subscribe(Collections.singletonList("test"));
    while (true) {
      // 拉取数据
      ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(10_000));
      // 处理数据
      for (ConsumerRecord<String, String> record : records) {
        ...
      }
      // 提交offset
      consumer.commitSync();
    }

目前kafka consumer的多线程方案常见的有两种:

  1. Thread per consumer model:即每个线程都有自己的consumer实例,然后在一个线程里面完成数据的获取(poll)、处理(process)、offset提交。
  2. Multi-threaded consumer model:一个线程(也可能是多个)专门用于获取数据,另外一组线程专门用于处理。这种模型没有统一的标准。

下面分别介绍。

Thread-Per-Consumer Model

这种多线程模型是利用Kafka的topic分多个partition的机制来实现并行:每个线程都有自己的consumer实例,负责消费若干个partition。各个线程之间是完全独立的,不涉及任何线程同步和通信,所以实现起来非常简单,使用也最多,像Flink里面用的就是这种模型。比如下面是2个线程消费5个分区的示例图:

2.jpg

用代码实现起来的思路是:先确定线程数,然后将分区数平均分给这些线程。下面是一个示例代码(完整代码见这里 github):

/**
 * @author NiYanchun
 **/
public class ThreadPerConsumerModel {
  public static void main(String[] args) throws InterruptedException {
    Properties config = new Properties();
    config.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
    config.put(ConsumerConfig.GROUP_ID_CONFIG, "test");
    config.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true");
    config.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG, "2000");
    config.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringDeserializer");
    config.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringDeserializer");
    config.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");

    final String topic = "test";
    int partitionNum = getPartitionNum(topic);
    final int threadNum = 4;
    int partitionNumPerThread = (partitionNum <= threadNum) ? 1 : partitionNum / threadNum + 1;

    ExecutorService threadPool = Executors.newFixedThreadPool(partitionNum);
    List<TopicPartition> topicPartitions = new ArrayList<>(partitionNumPerThread);
    for (int i = 0; i < partitionNum; i++) {
      topicPartitions.add(new TopicPartition(topic, i));
      if ((i + 1) % partitionNumPerThread == 0) {
        threadPool.submit(new Task(new ArrayList<>(topicPartitions), config));
        topicPartitions.clear();
      }
    }
    if (!topicPartitions.isEmpty()) {
      threadPool.submit(new Task(new ArrayList<>(topicPartitions), config));
    }

    threadPool.shutdown();
    threadPool.awaitTermination(Integer.MAX_VALUE, TimeUnit.SECONDS);
  }

  static class Task implements Runnable {
    private final List<TopicPartition> topicPartitions;
    private final Properties config;

    public Task(List<TopicPartition> topicPartitions, Properties config) {
      this.topicPartitions = topicPartitions;
      this.config = config;
    }

    @Override
    public void run() {
      KafkaConsumer<String, String> consumer = new KafkaConsumer<>(config);
      System.out.println(topicPartitions);
      consumer.assign(topicPartitions);
      try {
        while (true) {
          ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(10_000));
          for (ConsumerRecord<String, String> record : records) {
            System.out.println(record);
          }
        }
      } catch (Exception e) {
        e.printStackTrace();
      } finally {
        consumer.close();
      }
    }
  }
  
  public static int getPartitionNum(String topic) {
    return 8;
  }
}

上面的代码仅是作为原理说明,所以有些处理都简化或者硬编码了,核心逻辑就是将分区数均分给所有线程,这里不再赘述。

这种模型最大的优点是简单,但并非完美。在这种模型里面,获取数据、处理数据都是在一个线程里面的,如果在处理流程需要耗费很长时间,甚至是不可控的时候就会产生问题。会有什么问题?Kafka有一个参数**max.poll.interval.ms**,它表示的是两次poll调用的最大时间间隔。如果超过这个时间间隔,kafka就会认为这个consumer已经挂了,就会触发consumer group rebalance。这个参数默认值是5分钟,也就是如果某次处理超过5分钟,那就可能导致rebalance。Rebalance有什么问题呢?问题很大,对于Rebalance以后再介绍,这里先简单概括一下:

  1. 引发Consumer Group Rebalance操作主要有3种情况:(1)consumer数量发生变化(2)consumer订阅的主题发生变化(2)主题的分区数发生变化
  2. Rebalance是一个比较重的操作,它需要consumer group的所有活着的consumer参与,当consumer比较多的时候,rebalance会很耗时(比如若干小时);而且在没有完成之前,大家都是干不了活的,体现在业务上就是停止处理了。
  3. Rebalance其实就是重新划分partition,如果是自己提交offset的话,处理不好,就可能产生重复数据,后面再说。

总之,Rebalance操作能避免应该尽可能避免,特别是因为编码不合理产生的应该坚决改掉。为了避免处理时间超过这个最大时间间隔,在仍然使用这种模型的前提下,一般可以通过调整下面的参数一定程度的解决问题:

  1. 控制每次poll的数据不要太大,即修改max.poll.records参数的值,默认是500,即在给定时间内一次最多poll 500个record。
  2. max.poll.interval.ms的值改大一些。

但有些场景通过改上面的2个参数也无法解决,而且可能还挺常见,举个例子,比如消费业务数据,处理后写入到外部存储。如果外部存储挂了,在没有恢复之前一般是不应该继续消费Kafka数据的。此时通过调整max.poll.interval.ms的方法就失效了,因为事先是完全不知道应该设置多少的。

另外一个多线程模型可以解决上面这些问题,但会复杂很多。

Multi-Thread Consumer Model

先看下模型的图(注意:多线程模型的设计方式没有统一标准,下面这种只是其中一种而已):

3.jpg

具体处理流程为:poll线程专门负责拉取数据,然后将数据按partition分组,交给处理线程池,每个线程一次只处理一个分区的数据。数据交给处理线程后,poll继续拉取数据。现在有2个问题:

  1. 如何像Thread-Per-Consumer那样,保证一个分区里面的数据有序
  2. 如何提交offset

数据有序性

要保证partition内数据有序,只要避免多个线程并行处理同一个partition的数据即可。在poll线程给线程池分发数据的时候,已经按partition做了分组,也就是保证了一次拉取的数据中同一个partition的数据只会分配给一个线程。现在只要保证分区数据处理完成之前不再拉取该分区的数据,就可以保证数据的有序了。KafkaConsumer类提供了一个pauseresume方法,参数都是分区信息的集合:

public void pause(Collection<TopicPartition> partitions)
public void resume(Collection<TopicPartition> partitions)

调用pause方法后,后续poll的时候,这些被pause的分区不会再返回任何数据,直到被resume。所以,可以在本次partition的数据处理完成之前,pause该partition。

Offset提交

很显然,poll线程和处理线程解耦异步以后,就不能使用自动提交offset了,必须手动提交,否则可能offset提交了,但数据其实还没有处理。提交是poll线程做的,而offset的值则是处理线程才知道的,两者之前需要一个信息传递机制。代码实现的时候需要考虑。

这里插一个关于offset提交的话题。我们知道有2种方式提交offset:

  • 自动提交:将enable.auto.commit设置为true即可;每次poll的时候会自动提交上一次的offset。
  • 手动提交:分为同步提交commitSync和异步提交commitAsync。同步就是阻塞式提交,异步就是非阻塞式。而且二者都支持指定具体partition的offset,也就是可以精细化的自定义offset。

这里不展开介绍具体细节,只讨论一个问题:offset提交和消息传递语义(Message Delivery Semantics)的关系。Kafka提供了3种消息传递保证:

  • at most once:最多一次,即不会产生重复数据,但可能会丢数据
  • at least once:至少一次,即可能会产生重复数据,但不会丢数据
  • exactly once:准确的一次,不多也不少

首先自动提交offset提供的是at least once的保证,因为他是在poll的时候提交上次数据的offset,也就是如果处理本次poll拉取的数据的时候异常了,导致没有执行到下次poll,那这次这些数据的offset就无法提交了,但数据可能已经处理了一部分了。要实现最多一次也很简单,每次poll完数据,先提交offset,提交成功之后再开始处理即可。而要实现exactly once很难,而且这里的exactly once其实仅指消费,但我们一般想要的是全系统数据链上的exactly once。这个问题比较复杂,要实现真正的端到端exactly once,可以去查查flink、dataflow的设计,需要有很多假设和权衡,这里就不展开了。

重点说明一种情况,考虑这样一个流程:poll数据,处理数据,手动同步提交offset,然后再poll。当然也可能会有变种,就是边处理数据,边提交offset,但实质一样。这种情况会不会产生数据重复?答案是:会。其实不管你用同步提交,还是异步提交,都可能产生数据重复。因为系统随时可能因为其它原因产生Rebalance。所以要真正避免重复数据,正确的方式应该是实现ConsumerRebalanceListener接口,监听Rebalance的发生。当监听到发生Rebalance导致当前consumer的partition发生变化时,同步提交offset。但即使这样,其实也不能严格避免,因为系统还可能在任意阶段挂掉。因此这种流程实际提供的也是at least once的保证。所以这里想要表达的意思是,如果你的系统不能容忍有重复数据(在不丢数据的前提下),光靠kafka是做不到的,这里通过一些编码优化只能减少发生的概率,无法杜绝。终极方案应该是在应用层解决。比如幂等设计、使用数据的时候去重等。

代码实现

代码实现分为poll线程的实现(PollThread.java)、处理线程的实现(ProcessThread.java)、测试类Main.java。完整的代码见github。下面对关键地方进行说明。

先看主流程:

// PollThread.java    
@Override
public void run() {
    try {
        // 主流程
        while (!stopped.get()) {
            // 消费数据
            ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(10_000));
            // 按partition分组,分发数据给处理线程池
            handleFetchedRecords(records);
            // 检查正在处理的线程
            checkActiveTasks();
            // 提交offset
            commitOffsets();
        }

    } catch (WakeupException we) {
        if (!stopped.get()) {
            throw we;
        }
    } finally {
        consumer.close();
    }
}

主流程比较清晰:

  1. poll数据
  2. 按partition分组,分发数据给处理线程池
  3. 检查正在处理的线程
  4. 提交offset

poll数据没什么可说的,看下handleFetchedRecords

// PollThread.java
private void handleFetchedRecords(ConsumerRecords<String, String> records) {
    if (records.count() > 0) {
        List<TopicPartition> partitionsToPause = new ArrayList<>();
        // 按partition分组
        records.partitions().forEach(partition -> {
            List<ConsumerRecord<String, String>> consumerRecords = records.records(partition);
            // 提交一个分区的数据给处理线程池
            ProcessThread processThread = new ProcessThread(consumerRecords);
            processThreadPool.submit(processThread);
            // 记录分区与处理线程的关系,方便后面查询处理状态
            activeTasks.put(partition, processThread);
        });
        // pause已经在处理的分区,避免同个分区的数据被多个线程同时消费,从而保证分区内数据有序处理
        consumer.pause(partitionsToPause);
    }
}

该方法里面主要做了下面几件事情:

  1. 按partition将数据分组,然后每个分组交给线程池去处理
  2. 记录分区和处理线程的关系,因为后面要查询处理的状态
  3. 将正在处理的分区pause,即这些分区在后续poll中不会返回数据,直到被resume。原因前面说了,如果继续消费,可能会有多个线程同时处理一个分区来的数据,导致分区内数据的处理顺序无法保证。

接下来,先看下处理线程的实现,代码量不大,就全贴上了:

@Slf4j
public class ProcessThread implements Runnable {
  private final List<ConsumerRecord<String, String>> records;
  private final AtomicLong currentOffset = new AtomicLong();
  private volatile boolean stopped = false;
  private volatile boolean started = false;
  private volatile boolean finished = false;
  private final CompletableFuture<Long> completion = new CompletableFuture<>();
  private final ReentrantLock startStopLock = new ReentrantLock();

  public ProcessThread(List<ConsumerRecord<String, String>> records) {
    this.records = records;
  }

  @Override
  public void run() {
    startStopLock.lock();
    try {
      if (stopped) {
        return;
      }
      started = true;
    } finally {
      startStopLock.unlock();
    }

    for (ConsumerRecord<String, String> record : records) {
      if (stopped) {
        break;
      }
      // process record here and make sure you catch all exceptions;
      currentOffset.set(record.offset() + 1);
    }
    finished = true;
    completion.complete(currentOffset.get());
  }

  public long getCurrentOffset() {
    return currentOffset.get();
  }

  public void stop() {
    startStopLock.lock();
    try {
      this.stopped = true;
      if (!started) {
        finished = true;
        completion.complete(currentOffset.get());
      }
    } finally {
      startStopLock.unlock();
    }
  }

  public long waitForCompletion() {
    try {
      return completion.get();
    } catch (InterruptedException | ExecutionException e) {
      return -1;
    }
  }

  public boolean isFinished() {
    return finished;
  }
}

代码逻辑比较简单,很多设计都是在实现“如何优雅停止处理线程”的情况,这里简单说明一下。不考虑进程停止这种情况的话,其实需要停止处理线程的一个主要原因就是可能发生了rebalance,这个分区给其它consumer了。这个时候当前consumer处理线程优雅的停止方式就是完成目前正在处理的那个record,然后提交offset,然后停止。注意是完成当前正在处理的record,而不是完成本次分给该线程的所有records,因为处理完所有的records可能需要的时间会比较长。前面在offset提交的时候已经提到过,手动提交offset的时候,要想处理好rebalance这种情况,需要实现ConsumerRebalanceListener接口。这里PollThread实现了该接口:

@Slf4j
public class PollThread implements Runnable, ConsumerRebalanceListener {
  private final KafkaConsumer<String, String> consumer;
  private final Map<TopicPartition, ProcessThread> activeTasks = new HashMap<>();
  private final Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new HashMap<>();
    
  // 省略其它代码
    
   @Override
  public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
    // 1. stop all tasks handling records from revoked partitions
    Map<TopicPartition, ProcessThread> stoppedTask = new HashMap<>();
    for (TopicPartition partition : partitions) {
      ProcessThread processThread = activeTasks.remove(partition);
      if (processThread != null) {
        processThread.stop();
        stoppedTask.put(partition, processThread);
      }
    }

    // 2. wait for stopped task to complete processing of current record
    stoppedTask.forEach((partition, processThread) -> {
      long offset = processThread.waitForCompletion();
      if (offset > 0) {
        offsetsToCommit.put(partition, new OffsetAndMetadata(offset));
      }
    });

    // 3. collect offsets for revoked partitions
    Map<TopicPartition, OffsetAndMetadata> revokedPartitionOffsets = new HashMap<>();
    for (TopicPartition partition : partitions) {
      OffsetAndMetadata offsetAndMetadata = offsetsToCommit.remove(partition);
      if (offsetAndMetadata != null) {
        revokedPartitionOffsets.put(partition, offsetAndMetadata);
      }
    }

    // 4. commit offsets for revoked partitions
    try {
      consumer.commitSync(revokedPartitionOffsets);
    } catch (Exception e) {
      log.warn("Failed to commit offsets for revoked partitions!", e);
    }
  }

  @Override
  public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
    // 如果分区之前没有pause过,那执行resume就不会有什么效果
    consumer.resume(partitions);
  }   
    
}

可以看到:

  • 当检测到有分区分配给当前consumer的时候,就会尝试resume这个分区,因为这个分区可能之前被其它consumer pause过。
  • 当检测到有分区要被回收时,执行了下面几个操作:

    1. 查看activeTasks,看是否有线程正在处理这些分区的数据,有的话调用ProcessThreadstop方法将这些处理线程的stopped标志位设置成true。同时记录找到的这些线程。
    2. ProcessThreadrun方法里面每次循环处理record的时候都会检测上一步置位的这个stopped标志位,从而实现完成当前处理的record后就停止的逻辑。然后这一步就是等待这些线程处理结束,拿到处理的offset值,放到待提交offset队列offsetsToCommit里面。
    3. 从待提交队列里面找到要被回收的分区的offset,放到revokedPartitionOffsets里面。
    4. 提交revokedPartitionOffsets里面的offset。

简单总结的话,当收到有某些partition要从当前consumer回收的消息的时候,就从处理线程里面找到正在处理这些分区的线程,然后通知它们处理完正在处理的那一个record之后就退出吧。当然退出的时候需要返回已处理数据的offset,然后PollThread线程提交这些分区的offset。

然后继续回到主流程,之前讨论到了handleFetchedRecords,主要是按partition分区,并将分区数据分发给处理线程,并将这些partition pause。然后看接下来checkActiveTasks

private void checkActiveTasks() {
    List<TopicPartition> finishedTasksPartitions = new ArrayList<>();
    activeTasks.forEach((partition, processThread) -> {
        if (processThread.isFinished()) {
            finishedTasksPartitions.add(partition);
        }
        long offset = processThread.getCurrentOffset();
        if (offset > 0) {
            offsetsToCommit.put(partition, new OffsetAndMetadata(offset));
        }
    });

    finishedTasksPartitions.forEach(activeTasks::remove);
    consumer.resume(finishedTasksPartitions);
}

该方法的主要作用是:找到已经处理完成的partition,并获取对应的offset,放到待提交队列里面。同时resume这些partition。

最后就是commitOffsets了:

private void commitOffsets() {
    try {
        long currentMillis = System.currentTimeMillis();
        if (currentMillis - lastCommitTime > 5000) {
            if (!offsetsToCommit.isEmpty()) {
                consumer.commitSync(offsetsToCommit);
                offsetsToCommit.clear();
            }
            lastCommitTime = currentMillis;
        }
    } catch (Exception e) {
        log.error("Failed to commit offsets!", e);
    }
}

这个方法比较简单,就是按一定的时间间隔同步提交待提交队列里面的offset。

至此,该模型的整体流程就结束了。其实这个模型还隐式的实现了限流:如果后端线程都在处理了,虽然poll线程还会继续poll,但这些partition因为被pause了,所以不会真正返回数据了。这个功能在将poll和处理解耦,分到不同线程以后是很重要的。如果没有限流且后面处理比较慢的话,如果限制了poll和线程池之间传递records的队列大小,那poll的数据要么丢掉,要么就等待,等待的话又会碰到之前可能会超max.poll.interval.ms的问题。但如果不限制的话,队列就会一直变大,最后OOM。

总结

从提供的功能来看,两种多线程模型都实现了基本的功能:

  1. 多线程处理,提高并发
  2. 提供partition内数据有序保证
  3. 提供 at least once语义保证

但第二种模型可以更好的应对处理流程比较慢的需求场景,之所以要处理这种场景根本原因其实还是kafka的max.poll.interval.ms机制,也就是我们不能无限期的阻塞poll调用。另外,如果不需要提供partition内数据有序的话,可以对模型进行改造,改成纯基于数据拆分的多线程模式。目前的实现其实和Thread-Per-Consumer一样,都是基于partition的拆分、并发,只不过将流程从同步改成了异步而已。

总体而言,两种模型各有利弊,但如果Thread-Per-Consumer模型能满足需求的话,肯定是应该优先使用的,简单明了。毕竟在软件世界,在满足需求的前提下,系统越简单越好。如无必要,勿增实体。

参考: