Async Rust Patterns: Stream Owned Receiver
When working with asynchronous tasks in Rust, I have often encountered the pattern of having a background task whose job is, given some receiver, to loop on that receiver forever, waiting for messages and then responding to them.
Common examples of this I can think of off the top of my head are:
- Background tasks receiving messages over channels
- A PgListener waiting on NOTIFY events
- A Unix or TCP socket listener waiting for connections
The commonality here is that you have some type providing an
asynchronous method like .recv() or similar, and the Whole Job of
the thing is to sit and call that over and over, waiting until it gets
woken up by the executor, doing some processing, and then waiting
again.
The way I usually wind up handling this is with a pattern that I don't have a good name for, but for the sake of this post I'm calling "Stream Owned Receiver." If you read this and have a better name in mind, email me!
Basics: Stream Owned Receiver
So, let's consider a contrived and simple example: a main() function
that creates a receiver then, every 5 seconds, sends the time to it.
The receiver's job is to print the time. Our starting point is below.
For the sake of example, we're going to use Tokio channels rather than
the ones in futures because their interface is a bit easier, but
once you have this pattern down, it's easy to adapt it to use
try_unfold() with the futures channels:
use futures::{Stream, StreamExt, stream};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::mpsc;

struct Printer {
    rx: mpsc::Receiver<SystemTime>,
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(1);
    let printer = Printer { rx };
    loop {
        let now = SystemTime::now();
        tx.send(now).await;
        tokio::time::sleep(Duration::from_secs(5)).await;
    }
}
Of course, this compiles and runs, but it hangs on the second iteration of the loop, since the channel only has room for one message, and nobody has handled the first one.
Here we'll establish our pattern: we're going to convert the Printer
into an infinite stream, and then spawn it into the background. First,
a method to convert it to a stream, using stream::unfold():
impl Printer {
    fn into_stream(self) -> impl Stream<Item = ()> {
        // `unfold` takes an FnMut whose returned future cannot borrow from
        // the closure itself, so the state has to be passed in by value and
        // handed back out. The easiest way to manage this in this context
        // is often just to pass all of `self` into the function.
        //
        // The function is called in a loop, with its argument being the passed
        // state (`self`, which we'll call `this`).
        //
        // It should return None when the stream is exhausted, otherwise
        // Some((yielded, state)), where `yielded` is passed on as the next stream
        // result and `state` is passed to the next function call.
        stream::unfold(self, async |mut this| {
            // Receivers are nice for this pattern, since they return None when
            // all Senders are dropped, making the stream terminate naturally
            // if no one is left to send messages to it.
            this.rx.recv().await.map(|val| (val, this))
        })
        .map(|time| {
            let timestamp = time
                .duration_since(UNIX_EPOCH)
                .expect("time is after 1970")
                .as_secs();
            println!("{timestamp}");
        })
    }
}
Here, we've created a method that converts the Printer into some Stream
that calls .recv() forever, until it eventually returns None once
all Senders are dropped. This is fine, but it's a pain to test, since
it relies on a side effect. Let's make it pure real quick:
impl Printer {
    fn into_stream(self) -> impl Stream<Item = u64> {
        // `unfold` takes an FnMut whose returned future cannot borrow from
        // the closure itself, so the state has to be passed in by value and
        // handed back out. The easiest way to manage this in this context
        // is often just to pass all of `self` into the function.
        //
        // The function is called in a loop, with its argument being the passed
        // state (`self`, which we'll call `this`).
        //
        // It should return None when the stream is exhausted, otherwise
        // Some((yielded, state)), where `yielded` is passed on as the next stream
        // result and `state` is passed to the next function call.
        stream::unfold(self, async |mut this| {
            // Receivers are nice for this pattern, since they return None when
            // all Senders are dropped, making the stream terminate naturally
            // if no one is left to send messages to it.
            this.rx.recv().await.map(|val| (val, this))
        })
        .map(|time| {
            time.duration_since(UNIX_EPOCH)
                .expect("time is after 1970")
                .as_secs()
        })
    }
}
This now gives you a stream that's super easy to test. For example:
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn first_two_values() {
        let (tx, rx) = mpsc::channel(1);
        let printer = Printer { rx };
        let mut stream = std::pin::pin!(printer.into_stream());
        let (t1, t2) = (SystemTime::now(), SystemTime::now());
        tx.send(t1).await.unwrap();
        let ts1 = stream.next().await.unwrap();
        tx.send(t2).await.unwrap();
        let ts2 = stream.next().await.unwrap();
        assert_eq!(ts1, t1.duration_since(UNIX_EPOCH).unwrap().as_secs());
        assert_eq!(ts2, t2.duration_since(UNIX_EPOCH).unwrap().as_secs());
    }
}
Ideally, your .into_stream() method can always return values
instead of doing side effects, making it very easy to test.
Of course, we still can't actually do anything with it, so let's
tie it all together with a .run() method:
impl Printer {
    async fn run(self) {
        self.into_stream()
            .for_each(async |val| {
                println!("{val}");
            })
            .await
    }

    fn into_stream(self) -> impl Stream<Item = u64> {...}
}
The .run() method is super simple and only does side effects! This
lets you put any complicated logic into .into_stream() and test its
output values easily, and not worry as much about testing .run().
We can call it from main():
#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(1);
    let printer = Printer { rx };
    tokio::spawn(printer.run());
    loop {
        let now = SystemTime::now();
        println!("Sending time...");
        tx.send(now).await.unwrap();
        tokio::time::sleep(Duration::from_secs(5)).await;
    }
}
If you run your project with cargo run now, it should be doing its job:
Sending time...
1778454540
Sending time...
1778454545
Sending time...
1778454550
And that's the basics of it! In the rest of this post, we'll talk about some expansions on this basic pattern, including parallelizing your futures and handling shutdown.
Concurrency and Parallelization
One of the nicest things about this pattern is how amenable it is to adding concurrency and/or parallelization. Let's make it so that we work on up to 5 incoming messages at once, but still return them in order. I'm going to drop comments that have already been covered above for brevity, so go back up and look if you need to:
impl Printer {
    fn into_stream(self) -> impl Stream<Item = u64> {
        stream::unfold(self, async |mut this| {
            this.rx.recv().await.map(|val| (val, this))
        })
        .map(async |time| {
            time.duration_since(UNIX_EPOCH)
                .expect("time is after 1970")
                .as_secs()
        })
        // buffer up to five messages at a time, returning in the
        // order the messages were received
        .buffered(5)
    }
}
There are only two changes here: the first is the async keyword,
which turns the closure passed to StreamExt::map() into an async
closure (a function that returns a future) instead of a plain closure
(a function that returns a value), and the second is .buffered(5).
Making the map() call yield futures is necessary so that buffered() or
buffer_unordered() can push them onto a FuturesOrdered or
FuturesUnordered under the hood.
And that's all it takes to add concurrency.
Parallelism is only moderately more complicated:
impl Printer {
    fn into_stream(self) -> impl Stream<Item = u64> {
        stream::unfold(self, async |mut this| {
            this.rx.recv().await.map(|val| (val, this))
        })
        .map(async |time| {
            time.duration_since(UNIX_EPOCH)
                .expect("time is after 1970")
                .as_secs()
        })
        // Return a future that spawns the task into the background,
        // awaits it, and unwraps it.
        // See subsequent section for handling the join error
        .map(async |fut| tokio::spawn(fut).await.unwrap())
        // Buffer up to five of those futures at a time
        .buffered(5)
    }
}
And there you have it, a stream that, if it were doing something more complicated, could easily saturate five cores on Tokio's multi-threaded runtime.
Error Handling
Above, I'm just choosing to unwrap the potential error in joining
the Future, but a more idiomatic approach would be to return a TryStream,
which is a convenience trait implemented for any Stream whose items are Results.
This is a pretty easy change to make in the function signature, and you will then need to handle the error as you like when consuming it.
impl Printer {
    fn into_stream(self) -> impl Stream<Item = Result<u64, tokio::task::JoinError>> {
        stream::unfold(self, async |mut this| {
            this.rx.recv().await.map(|val| (val, this))
        })
        .map(async |time| {
            time.duration_since(UNIX_EPOCH)
                .expect("time is after 1970")
                .as_secs()
        })
        .map(tokio::spawn)
        .buffered(5)
    }
}
Cancellation and Graceful Shutdown
This whole thing is drop-safe, and the stream ends on its own once all associated Senders are dropped. However, if the task itself is dropped (for example, when the runtime shuts down as main() returns), it will not wait for any items remaining in the queue to be processed first.
If you want to handle graceful shutdown via signal handlers or similar,
you will want to retain the JoinHandle
from when you spawn the task into the background. From there, you can
drop() the sender and await the handle, preferably with a timeout:
the handle will return once any pending items are processed.
That looks something like this:
#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(1);
    let printer = Printer { rx };
    let handle = tokio::spawn(printer.run());
    let now = SystemTime::now();
    println!("Sending time...");
    for _ in 0..5 {
        tx.send(now).await.unwrap();
    }
    // Dropping the tx means the receiver will get None once it finishes anything
    // left in the queue.
    drop(tx);
    handle.await.unwrap();
}
Wrapping Up
This Stream Owned Receiver pattern may very well have already been obvious to many of you, but I thought it was a really nice pattern when I hashed it out for myself. Hopefully someone finds it useful.
It is also worth noting that much of this functionality (sending, dropping, etc.) can and probably should be wrapped in more semantically meaningful wrapper types both in order to prevent misuse and to make the code easier to read.