
Aris Koliopoulos

Apache Flink: Towards a 20x throughput improvement using in-memory buffers

Flink, Distributed Systems, Scala · 6 min read

Counting Views

This article explores how in-memory data structures can be leveraged to achieve throughput improvements in stateful transformations in Apache Flink. More specifically, a stateful KeyedProcessFunction with in-memory buffering capabilities is shown to increase throughput by up to 45x in certain data skew scenarios. The code is hosted on GitHub:

https://github.com/ariskk/flink-memory-buffer

The article assumes some knowledge of Apache Flink and its internals. It is assumed throughout the article that the RocksDB state backend is used. A version of this operator is being used in production at DriveTribe. Recently, the Blink planner introduced a similar feature to Flink's Table/SQL API. The implementation in this article is more generalised and customisable, and works with Flink's Scala API.

Motivation

Executing keyBy on a DataStream splits the stream into a number of disjoint logical partitions: one for every key. Flink then hashes the key to guarantee that all records sharing it are processed by the same parallel operator instance, and therefore on the same physical node. This property enables Flink to leverage the local filesystem for stateful transformations. Because only one parallel operator instance processes all records for a given key, any state this operator needs can be stored locally, where it can be accessed quickly and with no network overhead.

Theoretically, each one of those logical partitions can be processed by a different node. This model can scale horizontally to as many nodes as there are keys (in practice, up to the job's maximum parallelism), and thus can achieve astronomical throughput.

As long as records are uniformly distributed over the keyspace, increasing throughput is as simple as adding another node. That said, there are scenarios where record distributions are inherently skewed:

  • Engagements on Elon Musk's tweets vs the average tweet.
  • Uber demand in the City of London vs a smaller town.

Trying to keyBy Elon's userId or London's geohash would likely yield a per-key stream orders of magnitude larger than the median. Left untreated, this can build up backpressure, which in turn leads to higher latency and a degraded user experience in what is often the most valuable subset of data.

This is a hard problem. At the very core, all records for a key will be processed sequentially by default. In a stateful transformation, all of those records will need to:

  • Read the previous keyed state from RocksDB.
  • Update the state with the new record.
  • Write the result back to RocksDB.
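To make the cost of this loop concrete, the following sketch simulates the read-update-write cycle on plain Scala collections. `CountingStore` and `processUnbuffered` are hypothetical names, not Flink APIs. The store is an in-memory stand-in for RocksDB that only counts accesses, so the sketch shows the access pattern rather than real latencies.

```scala
import scala.collection.mutable

// Hypothetical stand-in for a keyed state backend such as RocksDB.
// It keeps values in memory and only counts how often state is read and written.
final class CountingStore[K, V] {
  private val data = mutable.Map.empty[K, V]
  var reads = 0
  var writes = 0

  def get(key: K): Option[V] = { reads += 1; data.get(key) }
  def put(key: K, value: V): Unit = { writes += 1; data.update(key, value) }
}

// Unbuffered stateful update: every incoming record costs one read and one write.
def processUnbuffered[K, V](records: Seq[(K, V)], store: CountingStore[K, V])(
    merge: (V, V) => V
): Unit =
  records.foreach { case (key, value) =>
    val previous = store.get(key)                       // read the previous keyed state
    val updated = previous.fold(value)(merge(_, value)) // update it with the new record
    store.put(key, updated)                             // write the result back
  }
```

Consider `processUnbuffered(Seq.fill(100)("tweet-1" -> 1L), store)(_ + _)`. It feeds 100 records for a single hot key and results in 100 reads and 100 writes, so every record pays the full round trip to the state backend.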

To understand what this looks like in practice, let's have a look at some EBS benchmarks, using fio to execute random 8KB writes:

fio --directory=/mnt/bench --ioengine=psync --name fio_test_file --direct=1 --rw=randwrite \
  --bs=8k --size=500M --numjobs=8 --time_based --runtime=240 --group_reporting --norandommap

  write: io=498080KB, bw=8300.1KB/s, iops=1037, runt= 60003msec
  clat (usec): min=502, max=1581.3K, avg=7706.99, stdev=55840.92
  lat (usec): min=502, max=1581.3K, avg=7707.56, stdev=55840.92
  ...
  lat (usec) : 500=0.01%, 750=0.63%, 1000=0.36%
  lat (msec) : 2=0.59%, 4=87.85%, 10=9.63%, 20=0.11%, 50=0.04%
  lat (msec) : 100=0.04%, 250=0.10%, 500=0.19%, 750=0.20%, 1000=0.16%
  lat (msec) : 2000=0.06%, >=2000=0.01%

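We get an average latency of 7.7ms with a very wide standard deviation of 55ms. Looking at the latency distribution, most (88%) IO requests complete in between 2ms and 4ms. The average is pulled up by a long tail, with a small fraction of requests taking over a second. Let's now have a look at the read side.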

fio --directory=/mnt/bench --name fio_test_file --direct=1 --rw=randread \
  --bs=8k --size=500M --numjobs=8 --time_based --runtime=240 --group_reporting --norandommap

  read : io=2433.9MB, bw=10381KB/s, iops=1297, runt=240069msec
  clat (usec): min=203, max=1698.1K, avg=6160.87, stdev=36738.12
  lat (usec): min=203, max=1698.1K, avg=6161.17, stdev=36738.13
  lat (usec) : 250=0.02%, 500=0.73%, 750=5.60%, 1000=10.17%
  lat (msec) : 2=30.47%, 4=31.48%, 10=12.89%, 20=7.73%, 50=0.06%
  lat (msec) : 100=0.04%, 250=0.56%, 500=0.10%, 750=0.10%, 1000=0.05%
  lat (msec) : 2000=0.02%

Quite similar results, with a 6.1ms average and a 36ms standard deviation. The distribution is broader this time, with roughly 83% of requests taking between 1ms and 20ms.

That's a total of ~13.8ms on average, without taking into account processing time or any overhead Flink adds. I appreciate that the operating system's page cache and RocksDB's own caches absorb some of these accesses, and that some setups use instance stores, which are much faster. That said, accessing the IO subsystem frequently is expensive, and there is a maximum number of records per key per second that the pipeline can process before backpressure kicks in. In the example above, and assuming that processing takes a trivial amount of time, it is ~70 records per key per second. If the state per key was smaller, this number would be larger. If the transformation was complicated (e.g. it ran an ML model on the record), the number would be smaller.
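The ~70 figure is simple arithmetic on the mean `lat` values reported by fio above (7707.56µs for writes, 6161.17µs for reads). It assumes exactly one read and one write per record and no other cost:

```latex
t_{\text{record}} \approx \overline{\text{lat}}_{\text{write}} + \overline{\text{lat}}_{\text{read}}
  = 7.71\,\text{ms} + 6.16\,\text{ms} = 13.87\,\text{ms}
\qquad
\text{records per key per second} \approx \frac{1000\,\text{ms}}{13.87\,\text{ms}} \approx 72
```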

Elon's tweets often get orders of magnitude more engagement than a naive setup can support. A keyBy(_.tweetId) would route thousands of records per second to a single key, and the operator would struggle to keep up.

A popular solution to this problem is key salting. In this technique, a key is split into a number of sub-keys. This can be illustrated by the following example.

Let's assume we want to run a word count over a big data stream (e.g. Twitter's firehose). Each word is a key.


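import org.apache.flink.streaming.api.scala._
import org.apache.flink.streaming.api.scala.extensions._
import org.apache.flink.streaming.api.windowing.assigners.TumblingProcessingTimeWindows
import org.apache.flink.streaming.api.windowing.time.Time
import scala.util.Random

val env = StreamExecutionEnvironment.getExecutionEnvironment
// `words` is assumed to be defined elsewhere (in practice, an unbounded source such as the firehose)
env.fromCollection(words)
  .map(word => (s"$word-${Random.nextInt(10)}", 1)) // salt: "and" -> "and-0" ... "and-9"
  .keyingBy { case (saltedWord, _) => saltedWord }
  // Window the partial counts so each salted key emits one partial sum per window;
  // a plain rolling reduce would emit running totals that stage 2 would double-count.
  .window(TumblingProcessingTimeWindows.of(Time.seconds(1)))
  .reduceWith { case ((saltedWord, acc), (_, n)) =>
    (saltedWord, acc + n)
  }
  .mapWith { case (saltedWord, count) => // strip the salt, keeping words that contain "-"
    (saltedWord.take(saltedWord.lastIndexOf('-')), count)
  }
  .keyingBy { case (word, _) => word }
  .reduceWith { case ((word, total), (_, partialSum)) =>
    (word, total + partialSum)
  }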

Instead of routing all instances of the word "and" (or any word) through a single operator, we can partition the occurrences of the word into 10 buckets (e.g. "and-0", "and-1" ... "and-9"). We can then run a word count on those buckets in parallel and produce partial counts. As a next step, we can remove the salt and merge those partial results into the final result. This technique essentially parallelizes the problem, which means more hardware can be thrown at it. It has three shortcomings though:

  • It introduces a network shuffle
  • It adds a second RocksDB access for every record
  • It assumes all operations are commutative

The last point is subtle but quite important. Flink respects the per-key ordering it receives from the source. It will never process two records with the same key (in a keyed stream) in parallel; doing so might violate any causal relationship between them.

For example, a user might create an account and delete it in the next second. Imagine then trying to process a user.created and a user.deleted event in parallel. It would be impossible to account for causal ordering without a) using state and b) having a robust ordering/versioning system in the data.
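Both the salting technique and its ordering problem can be sketched on plain Scala collections. `saltedAggregate` is a hypothetical helper, not a Flink API. It salts each key into one of `buckets` sub-keys and aggregates per sub-key. It then strips the salt and merges the partial results in an arbitrary order, as a downstream operator receiving them from several upstream instances might.

```scala
import scala.util.Random

def saltedAggregate[V](records: Seq[(String, V)], buckets: Int, rng: Random)(
    merge: (V, V) => V
): Map[String, V] = {
  // Stage 1: salt each key and aggregate per salted sub-key, e.g. "and-7".
  val partials: Map[String, V] = records
    .map { case (key, v) => (s"$key-${rng.nextInt(buckets)}", v) }
    .groupBy(_._1)
    .map { case (salted, vs) => salted -> vs.map(_._2).reduce(merge) }
  // Stage 2: strip the salt and merge the partials, which may arrive in any order.
  rng
    .shuffle(partials.toSeq)
    .map { case (salted, v) => (salted.take(salted.lastIndexOf('-')), v) }
    .groupBy(_._1)
    .map { case (key, vs) => key -> vs.map(_._2).reduce(merge) }
}
```

With a commutative merge such as `_ + _`, counting 1,000 occurrences of "and" always yields 1,000. A last-write-wins merge behaves differently. Suppose it is applied to a `user-1 -> created` and a `user-1 -> deleted` event. The final state then depends on which buckets the events landed in and on the order in which the partials were merged. Over repeated runs, it is sometimes `created`.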

This article proposes a different approach: instead of accessing state for every incoming record, we can buffer those records in memory. The state on disk is then only updated once per time interval or once a maximum number of records has been buffered, whichever comes first. This approach essentially trades median latency for overall throughput. It has the following benefits:

  • Reduces total disk accesses (only accesses RocksDB once per key batch)
  • Reduces downstream pressure (by only emitting once every X updates instead of every time)
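The first benefit can be illustrated with a small simulation. `stateAccesses` is hypothetical. It counts one read plus one write per distinct buffered key each time the buffer is flushed, and `maxCount = 1` stands in for the unbuffered case. It ignores real I/O cost, and a single final flush stands in for the timer.

```scala
import scala.collection.mutable

// Hypothetical model of state-backend traffic for a buffered operator:
// each flush costs one read and one write per distinct buffered key.
// maxCount = 1 flushes after every record, i.e. the unbuffered case.
def stateAccesses[K](records: Seq[K], maxCount: Int): Int = {
  val buffer = mutable.Map.empty[K, Int] // key -> number of buffered records
  var buffered = 0
  var accesses = 0

  def flush(): Unit = {
    accesses += 2 * buffer.size
    buffer.clear()
    buffered = 0
  }

  records.foreach { key =>
    buffer.update(key, buffer.getOrElse(key, 0) + 1)
    buffered += 1
    if (buffered >= maxCount) flush() // size trigger
  }
  flush() // final flush, standing in for the periodic timer
  accesses
}
```

For 10,000 records of a single hot key, accesses drop from 20,000 (`maxCount = 1`) to 20 (`maxCount = 1000`). For 10,000 distinct keys, they stay at 20,000 either way, which is why the gains concentrate in skewed streams.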

Let's now look at the implementation.

Implementation

Assuming a KeyedStream[T, K], the signature looks like this:


// Defined on ks: KeyedStream[T, K] (e.g. via an implicit class)
def buffer[R: TypeInformation](
  triggerInterval: Time,
  maxCount: Int,
  keySelector: T => K,
  processor: Vector[T] => R
): DataStream[R] = ks.process(new KeyedProcessFunction[K, T, R] { ... })

The buffer requires the following arguments:

  • triggerInterval: How often to empty the buffer downstream.
  • maxCount: maximum number of records to hold in the buffer. Assuming the size of T is fixed (or predictable) this can be used to compute the upper bound of memory usage.
  • keySelector: extract K from T
  • processor: a processor to reduce Vector[T] into an R, if desired. Useful if the vector can be aggregated.
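A rough upper bound on buffer memory follows from two facts. The counter is shared by all keys of a parallel instance, and the buffer is emptied as soon as the count reaches maxCount. Here size(T) (the in-memory size of one record) and p (the operator's parallelism) are notation introduced for this sketch, and the overhead of the map and the Vectors is ignored:

```latex
M_{\text{instance}} \lesssim \texttt{maxCount} \times \operatorname{size}(T)
\qquad\qquad
M_{\text{job}} \lesssim p \times \texttt{maxCount} \times \operatorname{size}(T)
```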

This operator needs to hold the buffered items somewhere. To avoid persistent storage overhead, a data structure outside of Flink's RocksDB state backend is needed. ConcurrentHashMap is a good starting point.


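import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
val buffer: ConcurrentHashMap[K, Vector[T]] = new ConcurrentHashMap[K, Vector[T]]() // buffered records per key
val count: AtomicLong = new AtomicLong(0L) // records currently buffered, across all keys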

Note 1: This is operator-level state. One buffer is allocated per parallel subtask, and it buffers records for all keys handled by that subtask. The operator still has to be keyed, to guarantee that all records sharing the same key are processed by the same subtask in the correct order. That way, the processor can operate on them safely before emitting downstream.

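Note 2: Checkpointing can be implemented fairly simply using ListCheckpointed. That said, such a stream could not be safely rescaled. On restore, list state is redistributed across subtasks without regard to keys, so buffered records could end up on a different subtask from the one that now owns their key. A custom implementation of CheckpointedFunction would be needed if rescaling is desired.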

To empty the buffer downstream:


import scala.collection.JavaConverters._ // provides .asScala (scala.jdk.CollectionConverters._ on Scala 2.13+)

private def emptyBuffer(out: Collector[R]): Unit =
  Option(buffer).foreach { bu => // make sure it has been initialised
    if (!bu.isEmpty) { // no need to emit if empty
      bu.asScala.foreach {
        case (_, vector) =>
          out.collect(processor(vector)) // process and emit for every key
      }
      bu.clear() // clear the buffer
      count.set(0L) // clear the counter
    }
  }

This is a very basic implementation. For every key, it processes the accumulated vector of records with the processor and emits the result. If the number of keys is large and the number of accumulated records per key is small, this process is inefficient. That said, experimental data (see further below) show that the overhead is negligible. Future work could maintain a set of hot keys based on a rule set and emit other keys directly.
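The emission behaviour of emptyBuffer can be checked outside Flink. The sketch below is hypothetical: `BufferSketch` mirrors the `buffer`/`count` fields and `emptyBuffer` above, with a `ListBuffer` standing in for Flink's `Collector[R]`. It assumes Scala 2.13+ for `scala.jdk.CollectionConverters`.

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._

final class BufferSketch[K, T, R](processor: Vector[T] => R) {
  val buffer = new ConcurrentHashMap[K, Vector[T]]()
  val count = new AtomicLong(0L)
  val emitted: ListBuffer[R] = ListBuffer.empty[R] // stands in for Collector[R]

  // Appends the record to its key's vector and returns the new total count.
  def add(key: K, value: T): Long = {
    buffer.put(key, buffer.getOrDefault(key, Vector.empty) :+ value)
    count.incrementAndGet()
  }

  // Same steps as emptyBuffer: process and emit once per key, then reset.
  def emptyBuffer(): Unit =
    if (!buffer.isEmpty) {
      buffer.asScala.foreach { case (_, vector) => emitted += processor(vector) }
      buffer.clear()
      count.set(0L)
    }
}
```

Three buffered records for two keys produce two emitted values, one per key. Collapsing each key's records into a single emission in this way is where the reduced downstream pressure comes from.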

The process function is implemented as follows:


// Also assumes (not shown above): type KContext = KeyedProcessFunction[K, T, R],
// val started = new AtomicBoolean(false) and val triggerMillis = triggerInterval.toMilliseconds
def processElement(
  value: T,
  ctx: KContext#Context,
  out: Collector[R]
): Unit = {
  Option(buffer).foreach { bu =>
    val key = keySelector(value) // extract the key
    val previousVector = bu.getOrDefault(key, Vector.empty)
    bu.put(key, previousVector :+ value) // update the vector for that key
    val newCount = count.incrementAndGet() // increment the number of buffered records (all keys)
    if (newCount >= maxCount) emptyBuffer(out) // empty to conserve memory
  }
  if (!started.get()) { // if this is the first event, schedule the first timer
    val now = ctx.timerService().currentProcessingTime()
    ctx.timerService().registerProcessingTimeTimer(now + triggerMillis)
    started.set(true)
  }
}

The processElement method takes a value T and adds it to the correct Vector in the buffer. If this is the first element, it schedules the first timer.

Note: Vector is used here because it has effectively constant append time.

And finally the onTimer method:


override def onTimer(
  timestamp: Long,
  ctx: KContext#OnTimerContext,
  out: Collector[R]
): Unit = {
  emptyBuffer(out)
  ctx.timerService().registerProcessingTimeTimer(timestamp + triggerMillis)
}

It simply empties the buffer and registers the next timer.
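The interaction between the size trigger in processElement and the periodic timer in onTimer can be simulated with an explicit clock. `TriggerSim` is a hypothetical, single-threaded stand-in. In it, `fireTimersUpTo` plays the role of Flink's timer service calling onTimer, and timestamps are passed in by hand.

```scala
import scala.collection.mutable

// Hypothetical simulation of the two flush triggers. Time is passed in
// explicitly instead of coming from a TimerService.
final class TriggerSim[T, R](triggerMillis: Long, maxCount: Int, processor: Vector[T] => R) {
  private val buffer = mutable.LinkedHashMap.empty[String, Vector[T]]
  private var count = 0
  private var nextTimer: Option[Long] = None
  val out: mutable.ListBuffer[R] = mutable.ListBuffer.empty[R] // stands in for Collector[R]

  private def emptyBuffer(): Unit = {
    buffer.valuesIterator.foreach(vector => out += processor(vector))
    buffer.clear()
    count = 0
  }

  def processElement(key: String, value: T, now: Long): Unit = {
    fireTimersUpTo(now) // timers due before this record fire first
    buffer.update(key, buffer.getOrElse(key, Vector.empty[T]) :+ value)
    count += 1
    if (count >= maxCount) emptyBuffer() // size trigger
    if (nextTimer.isEmpty) nextTimer = Some(now + triggerMillis) // first record schedules the first timer
  }

  // Plays the role of Flink calling onTimer: flush, then register timestamp + triggerMillis.
  def fireTimersUpTo(now: Long): Unit =
    while (nextTimer.exists(_ <= now)) {
      val timestamp = nextTimer.get
      emptyBuffer()
      nextTimer = Some(timestamp + triggerMillis)
    }
}
```

With `triggerMillis = 100` and `maxCount = 3`, two records for key "a" at t=0 and t=10 stay buffered until the timer at t=100 flushes them. Later, the third buffered record (at t=130) triggers an immediate size-based flush of keys "b" and "c".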

The full implementation can be found in the GitHub repository linked at the top of this article.

Evaluation

To evaluate the solution, a simple view counting experiment was set up, using the following data model:


import scala.util.Random

type ViewId = String // assumed: ids are plain strings
type UserId = String
type TweetId = String

case class View(viewId: ViewId, userId: UserId, tweetId: TweetId)
object View {
  def generate(
    userCardinality: Int,
    tweetCardinality: Int
  ): View =
    View(
      viewId = s"view-${Random.alphanumeric.take(10).mkString}", // random, readable id
      userId = s"user-${Random.nextInt(userCardinality)}",
      tweetId = s"tweet-${Random.nextInt(tweetCardinality)}"
    )
}
case class TweetStats(tweetId: TweetId, viewCount: Long, viewerCount: Long)

The purpose was to build two different DataStream[View] => DataStream[TweetStats] functions: one buffered and one unbuffered.

The state in Flink is modeled as two HyperLogLog instances (probabilistic distinct-count estimators), one for views and one for viewers.


case class StatsState(tweetId: TweetId, viewCounter: HLL, viewerCounter: HLL)

The unbuffered implementation:


def computeUnbufferedStats(views: DataStream[View]): DataStream[TweetStats] =
  views
    .map(StatsState.fromView(_))
    .keyBy(_.tweetId)
    .reduce((acc, in) => Semigroup.plus[StatsState](acc, in))
    .map(_.toStats)

The Semigroup abstraction simply provides a function to merge two instances of StatsState, i.e. def plus(x: StatsState, y: StatsState): StatsState.
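For illustration, a minimal `Semigroup` type class and an instance for a simplified state are sketched below. These are assumptions, not the article's code; the article's `Semigroup` may come from a library. Exact `Set[String]` counters replace the HyperLogLog sketches here. Set union and HLL merging are both associative and commutative, but sets are exact and grow with cardinality.

```scala
// Minimal type class: an associative binary merge.
trait Semigroup[A] {
  def plus(x: A, y: A): A
}
object Semigroup {
  def plus[A](x: A, y: A)(implicit sg: Semigroup[A]): A = sg.plus(x, y)
}

// Exact sets stand in for the two HyperLogLog sketches.
final case class ExactStatsState(tweetId: String, viewIds: Set[String], viewerIds: Set[String])
object ExactStatsState {
  implicit val semigroup: Semigroup[ExactStatsState] = new Semigroup[ExactStatsState] {
    // Assumes both states belong to the same tweet (true after keyBy(_.tweetId)).
    def plus(x: ExactStatsState, y: ExactStatsState): ExactStatsState =
      ExactStatsState(x.tweetId, x.viewIds ++ y.viewIds, x.viewerIds ++ y.viewerIds)
  }
}
```

Merging a state that saw views 1 and 2 with one that saw views 2 and 3 yields three distinct views. The result is the same regardless of argument order.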

The buffered implementation:


def computeBufferedStats(views: DataStream[View]): DataStream[TweetStats] =
  views
    .map(StatsState.fromView(_))
    .keyBy(_.tweetId)
    .buffer(
      Time.milliseconds(100),
      maxCount = 1000,
      keySelector = _.tweetId,
      processor = _.reduce(Semigroup.plus[StatsState](_, _))
    )
    .keyBy(_.tweetId)
    .reduce { (acc, in) =>
      Semigroup.plus(acc, in)
    }
    .map(_.toStats)

State aggregation takes place in two steps: partially, when emitting from the buffer, and then the partial aggregations are merged in reduce.
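The reason the two-step aggregation gives the same answer can be sketched with a hypothetical `twoStepReduce` helper on plain collections. Each batch (as flushed by the buffer) is reduced first, then the partial results are merged. For an associative merge, such as set union or addition, this equals reducing all records in one pass. In this sketch the batches are merged in the order they were produced, so associativity, rather than commutativity, is what the merge needs.

```scala
// Hypothetical helper: two-step aggregation on plain collections.
// Step 1 reduces each batch (what the buffer's processor does on a flush);
// step 2 merges the partial results in order (what the downstream keyed reduce does).
// Requires a non-empty input.
def twoStepReduce[A](records: Seq[A], batchSize: Int)(merge: (A, A) => A): A =
  records
    .grouped(batchSize)   // batches, as flushed by the buffer
    .map(_.reduce(merge)) // step 1: partial aggregate per batch
    .reduce(merge)        // step 2: merge the partial aggregates
```

Subtraction, which is not associative, shows the failure mode. Batches [10, 3] and [2, 1] reduce to 7 and 1 and then merge to 6, while a single pass gives 4.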

To evaluate the performance of the two alternatives, throughput (records per second) was measured over different keyspace cardinalities, from 20,000 distinct keys (i.e. tweet IDs) down to 1. In all experiments, 100,000 tweet views from 25,000 users were generated. The experiment was run 5 times for each cardinality and each stream (buffered and unbuffered). The following results were produced.

Results

The smaller the keyspace, the more pronounced the benefits of buffering records. At 20 keys, the throughput improvement peaks at ~45x. Interestingly, even at higher cardinalities (10K and 20K keys), the buffered operator still outperformed the naive implementation.

Though the throughput improvement is significant, it comes with the following caveats:

  • As it stands, a record can wait in the buffer for up to triggerInterval before being emitted (unless maxCount is reached first), which adds latency. This can be resolved by making buffering dynamic based on incoming traffic.
  • If record sizes are unbounded, memory usage will also be unbounded and the task manager might be killed by the underlying resource manager (K8s/YARN/etc).
  • As it stands, the stream cannot be rescaled. There are solutions to this, but the implementation is not trivial.

Conclusion

Buffering records in a ConcurrentHashMap can be leveraged to increase throughput in scenarios with skewed data. Further improvements can be implemented to reduce the median latency penalty and improve the overall robustness of the operator.

If you have any suggestions, observations, or feedback, please feel free to reach out on Twitter.

Thanks for reading!



© 2022 by Aris Koliopoulos. All rights reserved.