cancel_safe_futures/macros/
join_then_try.rs

/// Waits on multiple concurrent branches for **all** futures to complete, returning `Ok(_)` or an
/// error.
///
/// Unlike [`tokio::try_join`], this macro does not cancel remaining futures if one of them returns
/// an error. Instead, this macro runs all futures to completion.
///
/// If more than one future produces an error, `join_then_try!` returns the error from the first
/// future listed in the macro that produces an error.
///
/// On success, the macro evaluates to `Ok` containing a tuple of every future's success value, in
/// the order the futures were listed. Invoking the macro with no arguments evaluates to `Ok(())`.
///
/// The `join_then_try!` macro must be used inside of async functions, closures, and blocks.
///
/// # Why use `join_then_try`?
///
/// Consider what happens if you're wrapping a set of
/// [`AsyncWriteExt::flush`](tokio::io::AsyncWriteExt::flush) operations.
///
/// ```
/// use tokio::io::AsyncWriteExt;
///
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> anyhow::Result<()> {
/// let temp_dir = tempfile::tempdir()?;
/// let mut file1 = tokio::fs::File::create(temp_dir.path().join("file1")).await?;
/// let mut file2 = tokio::fs::File::create(temp_dir.path().join("file2")).await?;
///
/// // ... write some data to file1 and file2
///
/// tokio::try_join!(file1.flush(), file2.flush())?;
///
/// # Ok(()) }
/// ```
///
/// If `file1.flush()` returns an error, `file2.flush()` will be cancelled. This is not ideal, since
/// we'd like to make an effort to flush both files as far as possible.
///
/// One way to run all futures to completion is to use the [`tokio::join`] macro.
///
/// ```
/// # use tokio::io::AsyncWriteExt;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> anyhow::Result<()> {
/// # let temp_dir = tempfile::tempdir()?;
/// let mut file1 = tokio::fs::File::create(temp_dir.path().join("file1")).await?;
/// let mut file2 = tokio::fs::File::create(temp_dir.path().join("file2")).await?;
///
/// // tokio::join! is unaware of errors and runs all futures to completion.
/// let (res1, res2) = tokio::join!(file1.flush(), file2.flush());
/// res1?;
/// res2?;
/// # Ok(()) }
/// ```
///
/// This, too, is not ideal because it requires you to manually handle the results of each future.
///
/// The `join_then_try` macro behaves identically to the above `tokio::join` example, except that
/// it combines the results for you into a single `Result`.
///
/// ```
/// # use tokio::io::AsyncWriteExt;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> anyhow::Result<()> {
/// # let temp_dir = tempfile::tempdir()?;
/// let mut file1 = tokio::fs::File::create(temp_dir.path().join("file1")).await?;
/// let mut file2 = tokio::fs::File::create(temp_dir.path().join("file2")).await?;
///
/// // With join_then_try, if one of the operations errors out the other one will still be
/// // run to completion.
/// cancel_safe_futures::join_then_try!(file1.flush(), file2.flush())?;
/// # Ok(()) }
/// ```
///
/// If an error occurs, the error from the first future listed in the macro that errors out will be
/// returned.
///
/// # Notes
///
/// The supplied futures are stored inline. This macro is no-std and no-alloc compatible and does
/// not require allocating a `Vec`.
///
/// This adapter does not expose a way to gather and combine all returned errors. Implementing that
/// is a future goal, but it requires some design work for a generic way to combine errors. To
/// do that today, use [`tokio::join`] and combine errors at the end.
///
/// # Runtime characteristics
///
/// By running all async expressions on the current task, the expressions are able to run
/// **concurrently** but not in **parallel**. This means all expressions are run on the same thread
/// and if one branch blocks the thread, all other expressions will be unable to continue. If
/// parallelism is required, spawn each async expression using [`tokio::task::spawn`] and pass the
/// join handle to `join_then_try!`.
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! join_then_try {
    (@ {
        // One `_` for each branch in the `join_then_try!` macro. This is not used once
        // normalization is complete.
        ( $($count:tt)* )

        // The expression `0+1+1+ ... +1` equal to the number of branches.
        ( $($total:tt)* )

        // Normalized join_then_try! branches. Each branch carries a run of `_` tokens, one per
        // preceding branch, used below to pick that branch's slot out of the `futures` tuple.
        $( ( $($skip:tt)* ) $e:expr, )*

    }) => {{
        use $crate::macros::support::{maybe_done, poll_fn, Future, Pin};
        use $crate::macros::support::Poll::{Ready, Pending};

        // SAFETY: nothing must be moved out of `futures`. This is to satisfy
        // the requirement of `Pin::new_unchecked` called below.
        //
        // We can't use the `pin!` macro for this because `futures` is a tuple
        // and the standard library provides no way to pin-project to the fields
        // of a tuple.
        let mut futures = ( $( maybe_done($e), )* );

        // This assignment makes sure that the `poll_fn` closure only has a
        // reference to the futures, instead of taking ownership of them. This
        // mitigates the issue described in
        // <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
        //
        // Both extraction sites below reborrow through this reference with `&mut *futures`, so
        // the binding itself does not need to be mutable.
        let futures = &mut futures;

        // Each time the future created by `poll_fn` is polled, a different future is polled
        // first, so that every future passed to `join_then_try!` gets a chance to make progress
        // even if one of the futures consumes the whole budget.
        //
        // This is the number of futures that will be skipped in the first pass of the loop
        // the next time the `poll_fn` future is polled.
        let mut skip_next_time: u32 = 0;

        poll_fn(move |cx| {
            const COUNT: u32 = $($total)*;

            let mut is_pending = false;

            // Number of futures that still need to be polled during this call.
            let mut to_run = COUNT;

            // The number of futures that will be skipped in the first pass of the loop.
            let mut skip = skip_next_time;

            // Rotate the starting branch for the next call.
            skip_next_time = if skip + 1 == COUNT { 0 } else { skip + 1 };

            // The loop makes at most two passes over the branches. The first pass skips the
            // first `skip` futures and polls the rest. The second pass polls the skipped ones.
            // Once every future has been polled exactly once, `to_run` reaches 0 and the loop
            // breaks.
            loop {
            $(
                if skip == 0 {
                    if to_run == 0 {
                        // Every future has been polled.
                        break;
                    }
                    to_run -= 1;

                    // Extract the future for this branch from the tuple.
                    let ( $($skip,)* fut, .. ) = &mut *futures;

                    // SAFETY: the future is stored on the stack above
                    // and never moved.
                    let mut fut = unsafe { Pin::new_unchecked(fut) };

                    // Try polling. A completed `maybe_done` future keeps its output stored
                    // until `take_output` is called below.
                    if fut.as_mut().poll(cx).is_pending() {
                        is_pending = true;
                    }
                } else {
                    // Future skipped, one less future to skip in the next pass.
                    skip -= 1;
                }
            )*
            }

            if is_pending {
                Pending
            } else {
                // Every future has completed. Collect outputs in listed order; the first
                // `Err` encountered is returned, which gives "first listed error wins".
                Ready(Ok(($({
                    // Extract the future for this branch from the tuple.
                    let ( $($skip,)* fut, .. ) = &mut *futures;

                    // SAFETY: the future is stored on the stack above
                    // and never moved.
                    let mut fut = unsafe { Pin::new_unchecked(fut) };

                    let output = fut.take_output().expect("expected completed future");
                    match output {
                        Ok(output) => output,
                        Err(error) => return Ready(Err(error)),
                    }
                },)*)))
            }
        }).await
    }};

    // ===== Normalize =====

    // Peel off one branch at a time. Each step appends a `_` to the skip prefix, adds `+ 1` to
    // the branch count, and records the branch together with the prefix of all earlier branches.
    (@ { ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
        $crate::join_then_try!(@{ ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
    };

    // ===== Entry point =====

    ( $($e:expr),+ $(,)?) => {
        $crate::join_then_try!(@{ () (0) } $($e,)*)
    };

    () => { async { Ok(()) }.await }
}