1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
//! A multi-producer, single-consumer channel for cooperative (explicit) cancellation.
//!
//! This is similar in nature to a [`tokio::task::AbortHandle`], except it uses a cooperative model
//! for cancellation.
//!
//! # Motivation
//!
//! Executors like Tokio support forcible cancellation for async tasks via facilities like
//! [`tokio::task::JoinHandle::abort`]. However, this causes cancellations at any arbitrary await
//! point. This is often not desirable because it can lead to invariant violations.
//!
//! For example, consider the following code, which mixes the cancel-safe
//! [`AsyncWriteExt::write_buf`](tokio::io::AsyncWriteExt::write_buf) with some cancel-unsafe code:
//!
//! ```
//! use bytes::Buf;
//! use std::io::Cursor;
//! use tokio::{io::AsyncWriteExt, sync::mpsc};
//!
//! struct DataWriter {
//!     writer: tokio::fs::File,
//!     bytes_written_channel: mpsc::Sender<usize>,
//! }
//!
//! impl DataWriter {
//!     async fn write(&mut self, cursor: &mut Cursor<&[u8]>) -> std::io::Result<()> {
//!         // Cursor<&[u8]> implements the bytes::Buf trait, which is used by `write_buf`.
//!         while cursor.has_remaining() {
//!             let bytes_written = self.writer.write_buf(cursor).await?; // (1)
//!             self.bytes_written_channel.send(bytes_written).await; // (2)
//!         }
//!
//!         Ok(())
//!     }
//! }
//! ```
//!
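//! The invariant upheld by `DataWriter` is that if some bytes are written, the corresponding
//! `bytes_written` is sent over `self.bytes_written_channel`. This means that cancelling at await
//! point (1) is okay, because `write_buf` is cancel-safe and writes no bytes if it is cancelled.
//! Cancelling at await point (2) is not okay: the bytes have already been written, but their count
//! would never be sent.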
//!
//! If we use [`tokio::task::JoinHandle::abort`] to cancel the task, it is possible that the task is
//! cancelled at await point (2), breaking the invariant. Instead, we can use cooperative
//! cancellation with a `select!` loop.
//!
//! ```
//! use bytes::Buf;
//! use cancel_safe_futures::coop_cancel;
//! use std::io::Cursor;
//! use tokio::{io::AsyncWriteExt, sync::mpsc};
//!
//! struct DataWriter {
//!     writer: tokio::fs::File,
//!     bytes_written_channel: mpsc::Sender<usize>,
//!     cancel_receiver: coop_cancel::Receiver<()>,
//! }
//!
//! impl DataWriter {
//!     async fn write(&mut self, cursor: &mut Cursor<&[u8]>) -> std::io::Result<()> {
//!         while cursor.has_remaining() {
//!             tokio::select! {
//!                 res = self.writer.write_buf(cursor) => {
//!                     let bytes_written = res?;
//!                     self.bytes_written_channel.send(bytes_written).await;
//!                 }
//!                 Some(()) = self.cancel_receiver.recv() => {
//!                     // A cancellation notice was sent over the
//!                     // channel. Cancel here.
//!                     println!("cancelling!");
//!                     break;
//!                 }
//!             }
//!         }
//!
//!         Ok(())
//!     }
//! }
//! ```
//!
//! # Attaching a cancel message
//!
//! [`Canceler::cancel`] can be used to send a message of any type `T` along with the cancellation
//! event. This message is received via the `Some` variant of [`Receiver::recv`].
//!
//! For a given [`Receiver`], only the first message sent via any corresponding [`Canceler`] is
//! received. Subsequent calls to [`Receiver::recv`] will always return `None`, no matter whether
//! further cancellation messages are sent. (This can change in the future if there's a good use
//! case for it.)
//!
//! # Notes
//!
//! This module implements "fan-in" cancellation -- it supports many cancelers but only one
//! receiver. For "fan-out" cancellation with one sender and many receivers, consider using the
//! [`drain`](https://docs.rs/drain) crate. This module and `drain` can be combined: create a task
//! that listens to a [`Receiver`], and notify downstream receivers via `drain` in that task.
use crate::support::statically_unreachable;
use core::{
fmt,
future::Future,
marker::PhantomData,
pin::Pin,
task::{ready, Poll},
};
use futures_util::FutureExt;
use tokio::sync::{mpsc, oneshot};
/// Creates a linked cooperative cancellation pair: a [`Canceler`] and its single [`Receiver`].
///
/// The [`Canceler`] may be cloned freely; every clone delivers to the same [`Receiver`].
///
/// For more information, see [the module documentation](`self`).
pub fn new_pair<T>() -> (Canceler<T>, Receiver<T>) {
    let (tx, rx) = mpsc::unbounded_channel();

    let canceler = Canceler { sender: tx };
    let receiver = Receiver {
        receiver: rx,
        // Nothing has been received yet.
        first_sender: None,
    };

    (canceler, receiver)
}
/// A cooperative cancellation receiver.
///
/// For more information, see [the module documentation](`self`).
pub struct Receiver<T> {
receiver: mpsc::UnboundedReceiver<CancelPayload<T>>,
// This is cached and stored here until `Self` is dropped. The senders are really just a way to
// signal that the cooperative cancel has completed.
first_sender: Option<oneshot::Sender<Never>>,
}
impl<T> fmt::Debug for Receiver<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Receiver")
.field("receiver", &self.receiver)
.field("first_sender", &self.first_sender)
.finish()
}
}
impl<T> Receiver<T> {
/// Receives a cancellation payload, or `None` if either:
///
/// * a message was received in a previous attempt, or
/// * all [`Canceler`] instances have been dropped.
///
/// It is expected that after the first time `recv()` returns `Some`, the receiver will be
/// dropped.
pub async fn recv(&mut self) -> Option<T> {
if self.first_sender.is_some() {
None
} else {
match self.receiver.recv().await {
Some(payload) => {
self.first_sender = Some(payload.dropped_sender);
Some(payload.message)
}
None => None,
}
}
}
}
/// The sending half of a cooperative cancellation pair, created by [`new_pair`].
///
/// `Canceler` can be cloned; all clones deliver to the same [`Receiver`].
///
/// For more information, see [the module documentation](`self`).
pub struct Canceler<T> {
    // Unbounded so that `Canceler::cancel` can be a plain (non-async) function. Only a handful of
    // messages are ever expected to travel over this channel.
    sender: mpsc::UnboundedSender<CancelPayload<T>>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would add an unnecessary `T: Clone`
// bound, while only the channel handle actually needs cloning.
impl<T> Clone for Canceler<T> {
    fn clone(&self) -> Self {
        let sender = self.sender.clone();
        Self { sender }
    }
}

impl<T> fmt::Debug for Canceler<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Canceler");
        builder.field("sender", &self.sender);
        builder.finish()
    }
}

impl<T> Canceler<T> {
    /// Requests cancellation, attaching `message`.
    ///
    /// This function is not `async`: the message is queued immediately. On success it returns a
    /// [`Waiter`]. Awaiting the `Waiter` is optional; it completes once the corresponding
    /// [`Receiver`] has been dropped.
    ///
    /// Only the first message ever sent through any `Canceler` of a pair is received by the
    /// [`Receiver`].
    ///
    /// # Errors
    ///
    /// Returns `Err(message)` if the corresponding [`Receiver`] has already been dropped. The
    /// error hands the undelivered message back to the caller, and no cancellation took place.
    pub fn cancel(&self, message: T) -> Result<Waiter<T>, T> {
        let (payload, dropped_receiver) = CancelPayload::new(message);

        self.sender
            .send(payload)
            .map(|()| Waiter {
                dropped_receiver,
                _marker: PhantomData,
            })
            // The channel returns the rejected payload; unwrap it back to the user's message.
            .map_err(|rejected| rejected.0.message)
    }
}
/// Uninhabited item type for the internal oneshot channels.
///
/// No value is ever sent on those channels. Only the drop of the sending half is observed.
#[derive(Debug)]
enum Never {}

/// A future which can be used to optionally block until a [`Receiver`] is dropped.
///
/// A [`Waiter`] is purely advisory, and optional to wait on. Dropping this future does
/// not affect cancellation.
pub struct Waiter<T> {
    // dropped_receiver is just a way to signal that the Receiver has been dropped.
    dropped_receiver: oneshot::Receiver<Never>,
    // `fn() -> T` rather than `T`, because a `Waiter` never owns a `T`. With `PhantomData<T>`,
    // `Waiter<T>` would be `!Send` or `!Sync` whenever `T` is, and drop check would treat it as
    // owning a `T`. `fn() -> T` is `Send + Sync + Unpin` for every `T` and is still covariant in
    // `T`, so this only relaxes constraints.
    _marker: PhantomData<fn() -> T>,
}

// oneshot::Receiver is Unpin, and the marker is irrelevant to the Unpin-ness of `Waiter`. This is
// stated explicitly so the guarantee does not silently depend on the field types.
impl<T> Unpin for Waiter<T> {}

impl<T> fmt::Debug for Waiter<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Waiter")
            .field("dropped_receiver", &self.dropped_receiver)
            .finish()
    }
}

impl<T> Future for Waiter<T> {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
        // Nothing can ever be sent on the oneshot channel (`Never` has no values). The only way
        // it completes is with an error, which signals that its sender was dropped.
        // Redundant pattern matching is required for statically_unreachable to work.
        #[allow(clippy::redundant_pattern_matching)]
        if let Ok(_) = ready!(self.as_mut().dropped_receiver.poll_unpin(cx)) {
            // Never is uninhabited.
            statically_unreachable();
        }
        Poll::Ready(())
    }
}
/// The value sent over the internal mpsc channel for each cancellation request.
struct CancelPayload<T> {
    // The user-supplied cancellation message.
    message: T,
    // Never sent on. Dropping it is what completes the `Waiter` created for this request.
    dropped_sender: oneshot::Sender<Never>,
}

impl<T: fmt::Debug> fmt::Debug for CancelPayload<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("CancelPayload");
        builder.field("message", &self.message);
        builder.field("dropped_sender", &self.dropped_sender);
        builder.finish()
    }
}

impl<T> CancelPayload<T> {
    /// Wraps `message` in a payload backed by a fresh oneshot channel.
    ///
    /// Returns the payload together with the channel's receiving half. That receiver completes
    /// once the payload's `dropped_sender` is dropped.
    fn new(message: T) -> (Self, oneshot::Receiver<Never>) {
        let (dropped_sender, dropped_receiver) = oneshot::channel();
        let payload = Self {
            message,
            dropped_sender,
        };
        (payload, dropped_receiver)
    }
}