cancel_safe_futures/
coop_cancel.rs

1//! A multi-producer, single-consumer channel for cooperative (explicit) cancellation.
2//!
3//! This is similar in nature to a [`tokio::task::AbortHandle`], except it uses a cooperative model
4//! for cancellation.
5//!
6//! # Motivation
7//!
8//! Executors like Tokio support forcible cancellation for async tasks via facilities like
//! [`tokio::task::JoinHandle::abort`]. However, this can cancel a task at any of its await
//! points, which is often undesirable because it can lead to invariant violations.
11//!
12//! For example, consider this code that consists of both the cancel-safe
13//! [`AsyncWriteExt::write_buf`](tokio::io::AsyncWriteExt::write_buf) and some cancel-unsafe code:
14//!
15//! ```
16//! use bytes::Buf;
17//! use std::io::Cursor;
18//! use tokio::{io::AsyncWriteExt, sync::mpsc};
19//!
20//! struct DataWriter {
21//!     writer: tokio::fs::File,
22//!     bytes_written_channel: mpsc::Sender<usize>,
23//! }
24//!
25//! impl DataWriter {
26//!     async fn write(&mut self, cursor: &mut Cursor<&[u8]>) -> std::io::Result<()> {
27//!         // Cursor<&[u8]> implements the bytes::Buf trait, which is used by `write_buf`.
28//!         while cursor.has_remaining() {
29//!             let bytes_written = self.writer.write_buf(cursor).await?; // (1)
30//!             self.bytes_written_channel.send(bytes_written).await; // (2)
31//!         }
32//!
33//!         Ok(())
34//!     }
35//! }
36//! ```
37//!
38//! The invariant upheld by `DataWriter` is that if some bytes are written, the corresponding
39//! `bytes_written` is sent over `self.bytes_written_channel`. This means that cancelling at await
40//! point (1) is okay, but cancelling at await point (2) is not.
41//!
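//! If we use [`tokio::task::JoinHandle::abort`] to cancel the task, it is possible that the task is
//! cancelled at await point (2), breaking the invariant. Instead, we can use cooperative
//! cancellation with a `select!` loop. In the version below, only `write_buf` (await point (1))
//! races against the cancellation receiver. The `send` at await point (2) runs inside the body of
//! the `select!` branch, so a cancellation requested via the receiver can never interrupt it.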
45//!
46//! ```
47//! use bytes::Buf;
48//! use cancel_safe_futures::coop_cancel;
49//! use std::io::Cursor;
50//! use tokio::{io::AsyncWriteExt, sync::mpsc};
51//!
52//! struct DataWriter {
53//!     writer: tokio::fs::File,
54//!     bytes_written_channel: mpsc::Sender<usize>,
55//!     cancel_receiver: coop_cancel::Receiver<()>,
56//! }
57//!
58//! impl DataWriter {
59//!     async fn write(&mut self, cursor: &mut Cursor<&[u8]>) -> std::io::Result<()> {
60//!         while cursor.has_remaining() {
61//!             tokio::select! {
62//!                 res = self.writer.write_buf(cursor) => {
63//!                     let bytes_written = res?;
64//!                     self.bytes_written_channel.send(bytes_written).await;
65//!                 }
66//!                 Some(()) = self.cancel_receiver.recv() => {
67//!                     // A cancellation notice was sent over the
68//!                     // channel. Cancel here.
69//!                     println!("cancelling!");
70//!                     break;
71//!                 }
72//!             }
73//!         }
74//!
75//!         Ok(())
76//!     }
77//! }
78//! ```
79//!
80//! # Attaching a cancel message
81//!
82//! [`Canceler::cancel`] can be used to send a message of any type `T` along with the cancellation
83//! event. This message is received via the `Some` variant of [`Receiver::recv`].
84//!
85//! For a given [`Receiver`], only the first message sent via any corresponding [`Canceler`] is
86//! received. Subsequent calls to [`Receiver::recv`] will always return `None`, no matter whether
87//! further cancellation messages are sent. (This can change in the future if there's a good use
88//! case for it.)
89//!
90//! # Notes
91//!
92//! This module implements "fan-in" cancellation -- it supports many cancelers but only one
93//! receiver. For "fan-out" cancellation with one sender and many receivers, consider using the
94//! [`drain`](https://docs.rs/drain) crate. This module and `drain` can be combined: create a task
95//! that listens to a [`Receiver`], and notify downstream receivers via `drain` in that task.
96
97use crate::support::statically_unreachable;
98use core::{
99    fmt,
100    future::Future,
101    marker::PhantomData,
102    pin::Pin,
103    task::{ready, Poll},
104};
105use futures_util::FutureExt;
106use tokio::sync::{mpsc, oneshot};
107
108/// Creates and returns a cooperative cancellation pair.
109///
110/// For more information, see [the module documentation](`self`).
111pub fn new_pair<T>() -> (Canceler<T>, Receiver<T>) {
112    let (sender, receiver) = mpsc::unbounded_channel();
113    (
114        Canceler { sender },
115        Receiver {
116            receiver,
117            first_sender: None,
118        },
119    )
120}
121
/// A cooperative cancellation receiver.
///
/// Created by [`new_pair`]. Cancellation messages sent via any corresponding [`Canceler`] are
/// received through [`Receiver::recv`], which only ever returns the first one.
///
/// For more information, see [the module documentation](`self`).
pub struct Receiver<T> {
    // Channel over which every `Canceler` sends its `CancelPayload`.
    receiver: mpsc::UnboundedReceiver<CancelPayload<T>>,
    // This is cached and stored here until `Self` is dropped. The senders are really just a way to
    // signal that the cooperative cancel has completed. `Never` is uninhabited, so nothing is ever
    // sent over them. The `Waiter` returned by `Canceler::cancel` resolves once this sender is
    // dropped.
    first_sender: Option<oneshot::Sender<Never>>,
}
131
132impl<T> fmt::Debug for Receiver<T> {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        f.debug_struct("Receiver")
135            .field("receiver", &self.receiver)
136            .field("first_sender", &self.first_sender)
137            .finish()
138    }
139}
140
141impl<T> Receiver<T> {
142    /// Receives a cancellation payload, or `None` if either:
143    ///
144    /// * a message was received in a previous attempt, or
145    /// * all [`Canceler`] instances have been dropped.
146    ///
147    /// It is expected that after the first time `recv()` returns `Some`, the receiver will be
148    /// dropped.
149    pub async fn recv(&mut self) -> Option<T> {
150        if self.first_sender.is_some() {
151            None
152        } else {
153            match self.receiver.recv().await {
154                Some(payload) => {
155                    self.first_sender = Some(payload.dropped_sender);
156                    Some(payload.message)
157                }
158                None => None,
159            }
160        }
161    }
162}
163
/// A cooperative cancellation sender.
///
/// Cloning a `Canceler` yields another sender connected to the same [`Receiver`].
///
/// For more information, see [the module documentation](`self`).
pub struct Canceler<T> {
    // This is an unbounded sender to make Self::cancel not async. In general we
    // don't expect too many messages to ever be sent via this channel.
    sender: mpsc::UnboundedSender<CancelPayload<T>>,
}
172
173impl<T> Clone for Canceler<T> {
174    fn clone(&self) -> Self {
175        Self {
176            sender: self.sender.clone(),
177        }
178    }
179}
180
181impl<T> fmt::Debug for Canceler<T> {
182    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183        f.debug_struct("Canceler")
184            .field("sender", &self.sender)
185            .finish()
186    }
187}
188
189impl<T> Canceler<T> {
190    /// Performs a cancellation with a message.
191    ///
192    /// This sends the message immediately, and returns a [`Waiter`] that can be optionally waited
193    /// against to block until the corresponding [`Receiver`] is dropped.
194    ///
195    /// Only the first message ever sent via any `Canceler` is received by the [`Receiver`].
196    ///
197    /// Returns `Err(message)` if the corresponding [`Receiver`] has already been dropped, which
198    /// means that the cancel operation failed.
199    pub fn cancel(&self, message: T) -> Result<Waiter<T>, T> {
200        let (message, dropped_receiver) = CancelPayload::new(message);
201        match self.sender.send(message) {
202            Ok(()) => Ok(Waiter {
203                dropped_receiver,
204                _marker: PhantomData,
205            }),
206            Err(error) => Err(error.0.message),
207        }
208    }
209}
210
/// An uninhabited type: no value of it can ever be constructed.
///
/// Used as the payload of the internal oneshot channels so that nothing can ever be sent over
/// them. The only observable event on such a channel is its sender being dropped.
// NOTE(review): `Debug` is presumably derived so the `Debug` impls that print oneshot handles
// over `Never` compile — confirm against tokio's bounds.
#[derive(Debug)]
enum Never {}
213
214/// A future which can be used to optionally block until a [`Receiver`] is dropped.
215///
216/// A [`Waiter`] is purely advisory, and optional to wait on. Dropping this future does
217/// not affect cancellation.
218pub struct Waiter<T> {
219    // dropped_receiver is just a way to signal that the Receiver has been dropped.
220    dropped_receiver: oneshot::Receiver<Never>,
221    _marker: PhantomData<T>,
222}
223
// oneshot::Receiver is Unpin, and PhantomData is irrelevant to the Unpin-ness of
// `Waiter`. Implementing `Unpin` explicitly guarantees `Waiter<T>: Unpin` for every `T`.
// That is what lets `Future::poll` reach the `dropped_receiver` field through
// `Pin<&mut Self>`, because `DerefMut` on `Pin` requires `Self: Unpin`.
impl<T> Unpin for Waiter<T> {}
227
228impl<T> fmt::Debug for Waiter<T> {
229    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230        f.debug_struct("Waiter")
231            .field("dropped_receiver", &self.dropped_receiver)
232            .finish()
233    }
234}
235
/// Resolves to `()` once the oneshot sender paired with this `Waiter` has been dropped.
impl<T> Future for Waiter<T> {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
        // `ready!` returns `Poll::Pending` early until the oneshot channel resolves. Because the
        // channel carries `Never`, it can only resolve with `Err`, i.e. once the sender is dropped.
        // Redundant pattern matching is required for statically_unreachable to work.
        #[allow(clippy::redundant_pattern_matching)]
        if let Ok(_) = ready!(self.as_mut().dropped_receiver.poll_unpin(cx)) {
            // Never is uninhabited.
            statically_unreachable();
        }

        Poll::Ready(())
    }
}
250
/// A single cancellation message sent from a [`Canceler`] to the [`Receiver`].
struct CancelPayload<T> {
    // The caller-supplied cancellation message.
    message: T,
    // Paired with the `Waiter` returned by `Canceler::cancel`. Dropping this sender resolves
    // that `Waiter`. It is dropped either together with the payload or, once the payload is
    // received, when the `Receiver` that cached it is dropped.
    dropped_sender: oneshot::Sender<Never>,
}
255
256impl<T> fmt::Debug for CancelPayload<T>
257where
258    T: fmt::Debug,
259{
260    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261        f.debug_struct("CancelPayload")
262            .field("message", &self.message)
263            .field("dropped_sender", &self.dropped_sender)
264            .finish()
265    }
266}
267
268impl<T> CancelPayload<T> {
269    fn new(message: T) -> (Self, oneshot::Receiver<Never>) {
270        let (dropped_sender, dropped_receiver) = oneshot::channel();
271        (
272            Self {
273                message,
274                dropped_sender,
275            },
276            dropped_receiver,
277        )
278    }
279}