1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
//! A multi-producer, single-consumer channel for cooperative (explicit) cancellation.
//!
//! This is similar in nature to a [`tokio::task::AbortHandle`], except it uses a cooperative model
//! for cancellation.
//!
//! # Motivation
//!
//! Executors like Tokio support forcible cancellation for async tasks via facilities like
//! [`tokio::task::JoinHandle::abort`]. However, forcible cancellation can stop a task at any
//! arbitrary await point. This is often not desirable because it can lead to invariant violations.
//!
//! For example, consider this code that consists of both the cancel-safe
//! [`AsyncWriteExt::write_buf`](tokio::io::AsyncWriteExt::write_buf) and some cancel-unsafe code:
//!
//! ```
//! use bytes::Buf;
//! use std::io::Cursor;
//! use tokio::{io::AsyncWriteExt, sync::mpsc};
//!
//! struct DataWriter {
//!     writer: tokio::fs::File,
//!     bytes_written_channel: mpsc::Sender<usize>,
//! }
//!
//! impl DataWriter {
//!     async fn write(&mut self, cursor: &mut Cursor<&[u8]>) -> std::io::Result<()> {
//!         // Cursor<&[u8]> implements the bytes::Buf trait, which is used by `write_buf`.
//!         while cursor.has_remaining() {
//!             let bytes_written = self.writer.write_buf(cursor).await?; // (1)
//!             self.bytes_written_channel.send(bytes_written).await; // (2)
//!         }
//!
//!         Ok(())
//!     }
//! }
//! ```
//!
//! The invariant upheld by `DataWriter` is that if some bytes are written, the corresponding
//! `bytes_written` is sent over `self.bytes_written_channel`. This means that cancelling at await
//! point (1) is okay, but cancelling at await point (2) is not.
//!
//! If we use [`tokio::task::JoinHandle::abort`] to cancel the task, it is possible that the task is
//! cancelled at await point (2), breaking the invariant. Instead, we can use cooperative
//! cancellation with a `select!` loop.
//!
//! ```
//! use bytes::Buf;
//! use cancel_safe_futures::coop_cancel;
//! use std::io::Cursor;
//! use tokio::{io::AsyncWriteExt, sync::mpsc};
//!
//! struct DataWriter {
//!     writer: tokio::fs::File,
//!     bytes_written_channel: mpsc::Sender<usize>,
//!     cancel_receiver: coop_cancel::Receiver<()>,
//! }
//!
//! impl DataWriter {
//!     async fn write(&mut self, cursor: &mut Cursor<&[u8]>) -> std::io::Result<()> {
//!         while cursor.has_remaining() {
//!             tokio::select! {
//!                 res = self.writer.write_buf(cursor) => {
//!                     let bytes_written = res?;
//!                     self.bytes_written_channel.send(bytes_written).await;
//!                 }
//!                 Some(()) = self.cancel_receiver.recv() => {
//!                     // A cancellation notice was sent over the
//!                     // channel. Cancel here.
//!                     println!("cancelling!");
//!                     break;
//!                 }
//!             }
//!         }
//!
//!         Ok(())
//!     }
//! }
//! ```
//!
//! # Attaching a cancel message
//!
//! [`Canceler::cancel`] can be used to send a message of any type `T` along with the cancellation
//! event. This message is received via the `Some` variant of [`Receiver::recv`].
//!
//! For a given [`Receiver`], only the first message sent via any corresponding [`Canceler`] is
//! received. Once that message has been received, subsequent calls to [`Receiver::recv`] always
//! return `None`, regardless of whether further cancellation messages are sent. (This may change
//! in the future if there's a good use case for it.)
//!
//! # Notes
//!
//! This module implements "fan-in" cancellation -- it supports many cancelers but only one
//! receiver. For "fan-out" cancellation with one sender and many receivers, consider using the
//! [`drain`](https://docs.rs/drain) crate. This module and `drain` can be combined: create a task
//! that listens to a [`Receiver`], and notify downstream receivers via `drain` in that task.

use crate::support::statically_unreachable;
use core::{
    fmt,
    future::Future,
    marker::PhantomData,
    pin::Pin,
    task::{ready, Poll},
};
use futures_util::FutureExt;
use tokio::sync::{mpsc, oneshot};

/// Builds a connected [`Canceler`] / [`Receiver`] pair for cooperative cancellation.
///
/// Both halves share one unbounded channel. The returned [`Receiver`] starts out with no
/// cancellation message recorded.
///
/// For more information, see [the module documentation](`self`).
pub fn new_pair<T>() -> (Canceler<T>, Receiver<T>) {
    let (payload_tx, payload_rx) = mpsc::unbounded_channel();

    let canceler = Canceler { sender: payload_tx };
    let receiver = Receiver {
        receiver: payload_rx,
        // Nothing has been received yet, so there is no completion sender to hold on to.
        first_sender: None,
    };

    (canceler, receiver)
}

/// A cooperative cancellation receiver.
///
/// This is the consuming end of a pair created by [`new_pair`]. Use [`Receiver::recv`] to wait
/// for a cancellation message sent through the paired [`Canceler`] or any of its clones.
///
/// For more information, see [the module documentation](`self`).
pub struct Receiver<T> {
    // Receiving half of the unbounded channel that `Canceler::cancel` sends payloads into.
    receiver: mpsc::UnboundedReceiver<CancelPayload<T>>,
    // `None` until `recv` has returned a message; afterwards it holds the oneshot sender that
    // arrived with that message. This is cached and stored here until `Self` is dropped. Nothing
    // is ever sent through it (`Never` has no values): the senders are really just a way to
    // signal that the cooperative cancel has completed.
    first_sender: Option<oneshot::Sender<Never>>,
}

// Written by hand rather than derived: `#[derive(Debug)]` would add a `T: Debug` bound, which
// this impl does not need.
impl<T> fmt::Debug for Receiver<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            receiver,
            first_sender,
        } = self;

        let mut builder = f.debug_struct("Receiver");
        builder.field("receiver", receiver);
        builder.field("first_sender", first_sender);
        builder.finish()
    }
}

impl<T> Receiver<T> {
    /// Receives a cancellation payload, or `None` if either:
    ///
    /// * a message was received in a previous attempt, or
    /// * all [`Canceler`] instances have been dropped.
    ///
    /// It is expected that after the first time `recv()` returns `Some`, the receiver will be
    /// dropped.
    pub async fn recv(&mut self) -> Option<T> {
        if self.first_sender.is_some() {
            None
        } else {
            match self.receiver.recv().await {
                Some(payload) => {
                    self.first_sender = Some(payload.dropped_sender);
                    Some(payload.message)
                }
                None => None,
            }
        }
    }
}

/// A cooperative cancellation sender.
///
/// Created by [`new_pair`]. Cloning a `Canceler` yields another handle that sends to the same
/// [`Receiver`]. Call [`Canceler::cancel`] to request cancellation.
///
/// For more information, see [the module documentation](`self`).
pub struct Canceler<T> {
    // This is an unbounded sender to make Self::cancel not async. In general we
    // don't expect too many messages to ever be sent via this channel.
    sender: mpsc::UnboundedSender<CancelPayload<T>>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would add a `T: Clone` bound, but only
// the channel handle is cloned, never a `T`.
impl<T> Clone for Canceler<T> {
    fn clone(&self) -> Self {
        let sender = self.sender.clone();
        Self { sender }
    }
}

// Written by hand rather than derived so that no `T: Debug` bound is required.
impl<T> fmt::Debug for Canceler<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { sender } = self;

        let mut builder = f.debug_struct("Canceler");
        builder.field("sender", sender);
        builder.finish()
    }
}

impl<T> Canceler<T> {
    /// Requests cancellation, attaching `message` to the request.
    ///
    /// The message is sent immediately; this method is not async. On success, the returned
    /// [`Waiter`] can optionally be awaited to block until the corresponding [`Receiver`] is
    /// dropped.
    ///
    /// Only the first message ever sent via any `Canceler` is received by the [`Receiver`].
    ///
    /// # Errors
    ///
    /// Returns `Err(message)`, handing the message back, if the corresponding [`Receiver`] has
    /// already been dropped, which means that the cancel operation failed.
    pub fn cancel(&self, message: T) -> Result<Waiter<T>, T> {
        let (payload, dropped_receiver) = CancelPayload::new(message);

        self.sender
            .send(payload)
            .map(|()| Waiter {
                dropped_receiver,
                _marker: PhantomData,
            })
            // A failed send returns the payload; unwrap it back into the caller's message.
            .map_err(|send_error| send_error.0.message)
    }
}

/// An uninhabited type: it has no variants, so no value of it can ever be constructed.
///
/// It is the item type of the oneshot channels in this module, so nothing can ever be sent over
/// them. The only outcome a oneshot receiver can observe is that its sending half went away.
#[derive(Debug)]
enum Never {}

/// A future which can be used to optionally block until a [`Receiver`] is dropped.
///
/// Returned by [`Canceler::cancel`]. A [`Waiter`] is purely advisory, and optional to wait on.
/// Dropping this future does not affect cancellation.
pub struct Waiter<T> {
    // dropped_receiver is just a way to signal that the Receiver has been dropped. It is the
    // receiving half of the oneshot channel whose sending half travels inside the
    // `CancelPayload`.
    dropped_receiver: oneshot::Receiver<Never>,
    // Carries the message type `T` in the waiter's type without storing a `T`.
    _marker: PhantomData<T>,
}

// `Waiter` never stores a `T`, only a `PhantomData<T>`. The automatically inferred `Unpin` impl
// would still require `T: Unpin` because of that marker field, so `Unpin` is implemented by hand
// for every `T`. oneshot::Receiver is Unpin, and PhantomData is irrelevant to the Unpin-ness of
// `Waiter`.
impl<T> Unpin for Waiter<T> {}

// Written by hand so that no `T: Debug` bound is required. The `_marker` field carries no data
// and is left out of the output.
impl<T> fmt::Debug for Waiter<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Waiter");
        builder.field("dropped_receiver", &self.dropped_receiver);
        builder.finish()
    }
}

// Resolves to `()` once the oneshot receiver held by this waiter completes. Because the channel's
// item type is `Never`, it can only complete with an error, never with a value.
impl<T> Future for Waiter<T> {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
        // `ready!` returns `Poll::Pending` from this function while the oneshot receiver is still
        // pending.
        //
        // Redundant pattern matching is required for statically_unreachable to work.
        // NOTE(review): `statically_unreachable` is defined in `crate::support` and not visible
        // here; presumably it enforces at build time that this branch is eliminated -- confirm
        // before changing the shape of this `if let`.
        #[allow(clippy::redundant_pattern_matching)]
        if let Ok(_) = ready!(self.as_mut().dropped_receiver.poll_unpin(cx)) {
            // Never is uninhabited.
            statically_unreachable();
        }

        // The receiver completed with an error: the sending half is gone and nothing was sent.
        Poll::Ready(())
    }
}

/// The value sent over the mpsc channel by [`Canceler::cancel`]: the caller's message bundled
/// with the sending half of a oneshot channel.
struct CancelPayload<T> {
    // The caller-supplied cancellation message, returned to the consumer by `Receiver::recv`.
    message: T,
    // Sending half of the oneshot channel whose receiving half is held by the `Waiter`. Nothing
    // is ever sent through it (`Never` has no values).
    dropped_sender: oneshot::Sender<Never>,
}

// Unlike the other `Debug` impls in this module, this one prints the message itself and so
// requires `T: Debug`.
impl<T: fmt::Debug> fmt::Debug for CancelPayload<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            message,
            dropped_sender,
        } = self;

        let mut builder = f.debug_struct("CancelPayload");
        builder.field("message", message);
        builder.field("dropped_sender", dropped_sender);
        builder.finish()
    }
}

impl<T> CancelPayload<T> {
    /// Wraps `message` in a payload together with the sending half of a fresh oneshot channel.
    ///
    /// Returns the payload and the matching oneshot receiver.
    fn new(message: T) -> (Self, oneshot::Receiver<Never>) {
        let (done_tx, done_rx) = oneshot::channel();
        let payload = Self {
            message,
            dropped_sender: done_tx,
        };

        (payload, done_rx)
    }
}