// cancel_safe_futures/macros/join_then_try.rs
/// Waits on multiple concurrent branches for **all** futures to complete, returning `Ok(_)` with
/// a tuple of every branch's success value, or an error.
///
/// Unlike [`tokio::try_join`], this macro does not cancel remaining futures if one of them returns
/// an error. Instead, this macro runs all futures to completion.
///
/// If more than one future produces an error, `join_then_try!` returns the error from the first
/// future listed in the macro that produces an error.
///
/// The `join_then_try!` macro must be used inside of async functions, closures, and blocks.
///
/// # Why use `join_then_try`?
///
/// Consider what happens if you're wrapping a set of
/// [`AsyncWriteExt::flush`](tokio::io::AsyncWriteExt::flush) operations.
///
/// ```
/// use tokio::io::AsyncWriteExt;
///
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> anyhow::Result<()> {
/// let temp_dir = tempfile::tempdir()?;
/// let mut file1 = tokio::fs::File::create(temp_dir.path().join("file1")).await?;
/// let mut file2 = tokio::fs::File::create(temp_dir.path().join("file2")).await?;
///
/// // ... write some data to file1 and file2
///
/// tokio::try_join!(file1.flush(), file2.flush())?;
///
/// # Ok(()) }
/// ```
///
/// If `file1.flush()` returns an error, `file2.flush()` will be cancelled. This is not ideal, since
/// we'd like to make an effort to flush both files as far as possible.
///
/// One way to run all futures to completion is to use the [`tokio::join`] macro.
///
/// ```
/// # use tokio::io::AsyncWriteExt;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> anyhow::Result<()> {
/// # let temp_dir = tempfile::tempdir()?;
/// let mut file1 = tokio::fs::File::create(temp_dir.path().join("file1")).await?;
/// let mut file2 = tokio::fs::File::create(temp_dir.path().join("file2")).await?;
///
/// // tokio::join! is unaware of errors and runs all futures to completion.
/// let (res1, res2) = tokio::join!(file1.flush(), file2.flush());
/// res1?;
/// res2?;
/// # Ok(()) }
/// ```
///
/// This, too, is not ideal because it requires you to manually handle the results of each future.
///
/// The `join_then_try` macro behaves identically to the above `tokio::join` example, except it is
/// more user-friendly.
///
/// ```
/// # use tokio::io::AsyncWriteExt;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> anyhow::Result<()> {
/// # let temp_dir = tempfile::tempdir()?;
/// let mut file1 = tokio::fs::File::create(temp_dir.path().join("file1")).await?;
/// let mut file2 = tokio::fs::File::create(temp_dir.path().join("file2")).await?;
///
/// // With join_then_try, if one of the operations errors out the other one will still be
/// // run to completion.
/// cancel_safe_futures::join_then_try!(file1.flush(), file2.flush())?;
/// # Ok(()) }
/// ```
///
/// If an error occurs, the error from the first future listed in the macro that errors out will be
/// returned.
///
/// # Notes
///
/// The supplied futures are stored inline. This macro is no-std and no-alloc compatible and does
/// not require allocating a `Vec`.
///
/// This adapter does not expose a way to gather and combine all returned errors. Implementing that
/// is a future goal, but it requires some design work for a generic way to combine errors. To
/// do that today, use [`tokio::join`] and combine errors at the end.
///
/// # Runtime characteristics
///
/// By running all async expressions on the current task, the expressions are able to run
/// **concurrently** but not in **parallel**. This means all expressions are run on the same thread
/// and if one branch blocks the thread, all other expressions will be unable to continue. If
/// parallelism is required, spawn each async expression using [`tokio::task::spawn`] and pass the
/// join handle to `join_then_try!`.
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! join_then_try {
    // ===== Implementation (reached only via the normalization arm below) =====
    (@ {
        // One `_` for each branch in the `join_then_try!` macro. This is not used once
        // normalization is complete.
        ( $($count:tt)* )

        // The expression `0+1+1+ ... +1` equal to the number of branches.
        ( $($total:tt)* )

        // Normalized join_then_try! branches. Each branch carries `( _ _ ... )` with one `_`
        // per branch that precedes it; those `_`s are later spliced into a tuple pattern to
        // select this branch's slot.
        $( ( $($skip:tt)* ) $e:expr, )*

    }) => {{
        use $crate::macros::support::{maybe_done, poll_fn, Future, Pin};
        use $crate::macros::support::Poll::{Ready, Pending};

        // SAFETY (invariant for the `Pin::new_unchecked` calls below): nothing must be moved
        // out of `futures`.
        //
        // We can't use the `pin!` macro for this because `futures` is a tuple
        // and the standard library provides no way to pin-project to the fields
        // of a tuple.
        let mut futures = ( $( maybe_done($e), )* );

        // This assignment makes sure that the `poll_fn` closure only has a
        // reference to the futures, instead of taking ownership of them. This
        // mitigates the issue described in
        // <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
        let mut futures = &mut futures;

        // Each time the future created by poll_fn is polled, a different future will be polled
        // first to ensure every future passed to join_then_try! gets a chance to make progress
        // even if one of the futures consumes the whole budget.
        //
        // This is the number of futures that will be skipped in the first loop
        // iteration the next time.
        let mut skip_next_time: u32 = 0;

        poll_fn(move |cx| {
            const COUNT: u32 = $($total)*;

            let mut is_pending = false;

            let mut to_run = COUNT;

            // The number of futures that will be skipped in the first loop iteration
            let mut skip = skip_next_time;

            // Rotate the starting branch for the next poll, wrapping back to the first branch.
            skip_next_time = if skip + 1 == COUNT { 0 } else { skip + 1 };

            // This loop runs twice and the first `skip` futures
            // are not polled in the first iteration.
            loop {
            $(
                if skip == 0 {
                    if to_run == 0 {
                        // Every future has been polled
                        break;
                    }
                    to_run -= 1;

                    // Extract the future for this branch from the tuple.
                    let ( $($skip,)* fut, .. ) = &mut *futures;

                    // SAFETY: the future lives in the `futures` tuple created above, which is
                    // only ever accessed through references and never moved.
                    let mut fut = unsafe { Pin::new_unchecked(fut) };

                    // Try polling
                    if fut.as_mut().poll(cx).is_pending() {
                        is_pending = true;
                    }
                } else {
                    // Future skipped, one less future to skip in the next iteration
                    skip -= 1;
                }
            )*
            }

            if is_pending {
                Pending
            } else {
                // Every branch has completed. Take the outputs in declaration order so that the
                // first listed branch that failed determines the returned error.
                Ready(Ok(($({
                    // Extract the future for this branch from the tuple.
                    let ( $($skip,)* fut, .. ) = &mut futures;

                    // SAFETY: the future lives in the `futures` tuple created above, which is
                    // only ever accessed through references and never moved.
                    let mut fut = unsafe { Pin::new_unchecked(fut) };

                    let output = fut.take_output().expect("expected completed future");
                    match output {
                        Ok(output) => output,
                        Err(error) => return Ready(Err(error)),
                    }
                },)*)))
            }
        }).await
    }};

    // ===== Normalize =====

    // Moves one `expr,` from the input into the accumulator: appends a `_` to the skip list,
    // appends `+ 1` to the count expression, and records the branch together with the skip list
    // as it was before this branch was added.
    (@ { ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
        $crate::join_then_try!(@{ ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
    };

    // ===== Entry point =====

    // One or more comma-separated expressions, with an optional trailing comma.
    ( $($e:expr),+ $(,)?) => {
        $crate::join_then_try!(@{ () (0) } $($e,)*)
    };

    // No branches: resolves to `Ok(())`.
    () => { async { Ok(()) }.await }
}