cancel_safe_futures/coop_cancel.rs
1//! A multi-producer, single-consumer channel for cooperative (explicit) cancellation.
2//!
3//! This is similar in nature to a [`tokio::task::AbortHandle`], except it uses a cooperative model
4//! for cancellation.
5//!
6//! # Motivation
7//!
8//! Executors like Tokio support forcible cancellation for async tasks via facilities like
//! [`tokio::task::JoinHandle::abort`]. However, such cancellation can happen at any await point,
//! wherever the task happens to be suspended. This is often undesirable because it can lead to
//! invariant violations.
11//!
12//! For example, consider this code that consists of both the cancel-safe
13//! [`AsyncWriteExt::write_buf`](tokio::io::AsyncWriteExt::write_buf) and some cancel-unsafe code:
14//!
15//! ```
16//! use bytes::Buf;
17//! use std::io::Cursor;
18//! use tokio::{io::AsyncWriteExt, sync::mpsc};
19//!
20//! struct DataWriter {
21//! writer: tokio::fs::File,
22//! bytes_written_channel: mpsc::Sender<usize>,
23//! }
24//!
25//! impl DataWriter {
26//! async fn write(&mut self, cursor: &mut Cursor<&[u8]>) -> std::io::Result<()> {
27//! // Cursor<&[u8]> implements the bytes::Buf trait, which is used by `write_buf`.
28//! while cursor.has_remaining() {
29//! let bytes_written = self.writer.write_buf(cursor).await?; // (1)
30//! self.bytes_written_channel.send(bytes_written).await; // (2)
31//! }
32//!
33//! Ok(())
34//! }
35//! }
36//! ```
37//!
38//! The invariant upheld by `DataWriter` is that if some bytes are written, the corresponding
39//! `bytes_written` is sent over `self.bytes_written_channel`. This means that cancelling at await
40//! point (1) is okay, but cancelling at await point (2) is not.
41//!
42//! If we use [`tokio::task::JoinHandle::abort`] to cancel the task, it is possible that the task is
43//! cancelled at await point (2), breaking the invariant. Instead, we can use cooperative
44//! cancellation with a `select!` loop.
45//!
46//! ```
47//! use bytes::Buf;
48//! use cancel_safe_futures::coop_cancel;
49//! use std::io::Cursor;
50//! use tokio::{io::AsyncWriteExt, sync::mpsc};
51//!
52//! struct DataWriter {
53//! writer: tokio::fs::File,
54//! bytes_written_channel: mpsc::Sender<usize>,
55//! cancel_receiver: coop_cancel::Receiver<()>,
56//! }
57//!
58//! impl DataWriter {
59//! async fn write(&mut self, cursor: &mut Cursor<&[u8]>) -> std::io::Result<()> {
60//! while cursor.has_remaining() {
61//! tokio::select! {
62//! res = self.writer.write_buf(cursor) => {
63//! let bytes_written = res?;
64//! self.bytes_written_channel.send(bytes_written).await;
65//! }
66//! Some(()) = self.cancel_receiver.recv() => {
67//! // A cancellation notice was sent over the
68//! // channel. Cancel here.
69//! println!("cancelling!");
70//! break;
71//! }
72//! }
73//! }
74//!
75//! Ok(())
76//! }
77//! }
78//! ```
79//!
80//! # Attaching a cancel message
81//!
82//! [`Canceler::cancel`] can be used to send a message of any type `T` along with the cancellation
83//! event. This message is received via the `Some` variant of [`Receiver::recv`].
84//!
85//! For a given [`Receiver`], only the first message sent via any corresponding [`Canceler`] is
86//! received. Subsequent calls to [`Receiver::recv`] will always return `None`, no matter whether
87//! further cancellation messages are sent. (This can change in the future if there's a good use
88//! case for it.)
89//!
90//! # Notes
91//!
92//! This module implements "fan-in" cancellation -- it supports many cancelers but only one
93//! receiver. For "fan-out" cancellation with one sender and many receivers, consider using the
94//! [`drain`](https://docs.rs/drain) crate. This module and `drain` can be combined: create a task
95//! that listens to a [`Receiver`], and notify downstream receivers via `drain` in that task.
96
97use crate::support::statically_unreachable;
98use core::{
99 fmt,
100 future::Future,
101 marker::PhantomData,
102 pin::Pin,
103 task::{ready, Poll},
104};
105use futures_util::FutureExt;
106use tokio::sync::{mpsc, oneshot};
107
108/// Creates and returns a cooperative cancellation pair.
109///
110/// For more information, see [the module documentation](`self`).
111pub fn new_pair<T>() -> (Canceler<T>, Receiver<T>) {
112 let (sender, receiver) = mpsc::unbounded_channel();
113 (
114 Canceler { sender },
115 Receiver {
116 receiver,
117 first_sender: None,
118 },
119 )
120}
121
/// A cooperative cancellation receiver.
///
/// Created by [`new_pair`]. There is a single `Receiver` per pair, and it observes the messages
/// sent by every associated [`Canceler`].
///
/// For more information, see [the module documentation](`self`).
pub struct Receiver<T> {
    // Incoming cancellation payloads from all `Canceler`s of this pair.
    receiver: mpsc::UnboundedReceiver<CancelPayload<T>>,
    // This is cached and stored here until `Self` is dropped. The senders are really just a way to
    // signal that the cooperative cancel has completed.
    //
    // `None` until `recv` returns a message; after that it holds that payload's oneshot sender,
    // and its presence makes every later `recv` call return `None`.
    first_sender: Option<oneshot::Sender<Never>>,
}
131
132impl<T> fmt::Debug for Receiver<T> {
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 f.debug_struct("Receiver")
135 .field("receiver", &self.receiver)
136 .field("first_sender", &self.first_sender)
137 .finish()
138 }
139}
140
141impl<T> Receiver<T> {
142 /// Receives a cancellation payload, or `None` if either:
143 ///
144 /// * a message was received in a previous attempt, or
145 /// * all [`Canceler`] instances have been dropped.
146 ///
147 /// It is expected that after the first time `recv()` returns `Some`, the receiver will be
148 /// dropped.
149 pub async fn recv(&mut self) -> Option<T> {
150 if self.first_sender.is_some() {
151 None
152 } else {
153 match self.receiver.recv().await {
154 Some(payload) => {
155 self.first_sender = Some(payload.dropped_sender);
156 Some(payload.message)
157 }
158 None => None,
159 }
160 }
161 }
162}
163
/// A cooperative cancellation sender.
///
/// Created by [`new_pair`]. A `Canceler` can be cloned, and every clone sends to the same
/// [`Receiver`].
///
/// For more information, see [the module documentation](`self`).
pub struct Canceler<T> {
    // This is an unbounded sender to make Self::cancel not async. In general we
    // don't expect too many messages to ever be sent via this channel.
    sender: mpsc::UnboundedSender<CancelPayload<T>>,
}
172
173impl<T> Clone for Canceler<T> {
174 fn clone(&self) -> Self {
175 Self {
176 sender: self.sender.clone(),
177 }
178 }
179}
180
181impl<T> fmt::Debug for Canceler<T> {
182 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183 f.debug_struct("Canceler")
184 .field("sender", &self.sender)
185 .finish()
186 }
187}
188
189impl<T> Canceler<T> {
190 /// Performs a cancellation with a message.
191 ///
192 /// This sends the message immediately, and returns a [`Waiter`] that can be optionally waited
193 /// against to block until the corresponding [`Receiver`] is dropped.
194 ///
195 /// Only the first message ever sent via any `Canceler` is received by the [`Receiver`].
196 ///
197 /// Returns `Err(message)` if the corresponding [`Receiver`] has already been dropped, which
198 /// means that the cancel operation failed.
199 pub fn cancel(&self, message: T) -> Result<Waiter<T>, T> {
200 let (message, dropped_receiver) = CancelPayload::new(message);
201 match self.sender.send(message) {
202 Ok(()) => Ok(Waiter {
203 dropped_receiver,
204 _marker: PhantomData,
205 }),
206 Err(error) => Err(error.0.message),
207 }
208 }
209}
210
// An uninhabited type used as the payload of the internal oneshot channels. No value of it can
// ever exist, so those channels can only complete by having their sender dropped.
//
// NOTE(review): `Debug` is presumably derived so that the `Debug` impls printing the oneshot
// halves (`Receiver`, `Waiter`, `CancelPayload`) compile; confirm against tokio's bounds.
#[derive(Debug)]
enum Never {}
213
/// A future which can be used to optionally wait until a [`Receiver`] is dropped.
///
/// Returned by [`Canceler::cancel`]. Awaiting it resolves to `()` once the [`Receiver`] that
/// the cancellation was sent to has been dropped.
///
/// A [`Waiter`] is purely advisory, and optional to wait on. Dropping this future does
/// not affect cancellation.
pub struct Waiter<T> {
    // dropped_receiver is just a way to signal that the Receiver has been dropped.
    dropped_receiver: oneshot::Receiver<Never>,
    // Ties the waiter to the message type `T` of the `Canceler` that created it; no `T` is stored.
    _marker: PhantomData<T>,
}
223
// oneshot::Receiver is Unpin, and PhantomData is irrelevant to the Unpin-ness of
// `Waiter`.
//
// Implementing `Unpin` unconditionally, instead of relying on the auto impl (which would require
// `T: Unpin` because of `PhantomData<T>`), lets `Waiter::poll` reach its fields mutably through
// `Pin<&mut Self>` for any `T`.
impl<T> Unpin for Waiter<T> {}
227
228impl<T> fmt::Debug for Waiter<T> {
229 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230 f.debug_struct("Waiter")
231 .field("dropped_receiver", &self.dropped_receiver)
232 .finish()
233 }
234}
235
236impl<T> Future for Waiter<T> {
237 type Output = ();
238
239 fn poll(mut self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
240 // Redundant pattern matching is required for statically_unreachable to work.
241 #[allow(clippy::redundant_pattern_matching)]
242 if let Ok(_) = ready!(self.as_mut().dropped_receiver.poll_unpin(cx)) {
243 // Never is uninhabited.
244 statically_unreachable();
245 }
246
247 Poll::Ready(())
248 }
249}
250
// What a `Canceler` sends through the mpsc channel: the user's message plus the sending half of a
// oneshot channel whose receiving half ends up in the matching `Waiter`.
struct CancelPayload<T> {
    // The user-supplied cancellation message. `Receiver::recv` returns it only if this is the
    // first payload that receiver accepts.
    message: T,
    // Dropping this completes the corresponding `Waiter`.
    dropped_sender: oneshot::Sender<Never>,
}
255
256impl<T> fmt::Debug for CancelPayload<T>
257where
258 T: fmt::Debug,
259{
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 f.debug_struct("CancelPayload")
262 .field("message", &self.message)
263 .field("dropped_sender", &self.dropped_sender)
264 .finish()
265 }
266}
267
268impl<T> CancelPayload<T> {
269 fn new(message: T) -> (Self, oneshot::Receiver<Never>) {
270 let (dropped_sender, dropped_receiver) = oneshot::channel();
271 (
272 Self {
273 message,
274 dropped_sender,
275 },
276 dropped_receiver,
277 )
278 }
279}