Skip to main content

cancel_safe_futures/stream/
for_each_concurrent_then_try.rs

1use core::{fmt, num::NonZeroUsize, pin::Pin};
2use futures_core::{
3    future::{FusedFuture, Future},
4    stream::TryStream,
5    task::{Context, Poll},
6};
7use futures_util::stream::{FuturesUnordered, StreamExt};
8use pin_project_lite::pin_project;
9
10pin_project! {
11    /// Future for the
12    /// [`for_each_concurrent_then_try`](super::TryStreamExt::for_each_concurrent_then_try)
13    /// method.
14    #[must_use = "futures do nothing unless you `.await` or poll them"]
15    pub struct ForEachConcurrentThenTry<St: TryStream, Fut, F> {
16        #[pin]
17        stream: Option<St>,
18        f: F,
19        futures: FuturesUnordered<Fut>,
20        limit: Option<NonZeroUsize>,
21        first_error: Option<St::Error>,
22    }
23}
24
25impl<St, Fut, F> fmt::Debug for ForEachConcurrentThenTry<St, Fut, F>
26where
27    St: TryStream + fmt::Debug,
28    Fut: fmt::Debug,
29    St::Error: fmt::Debug,
30{
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        f.debug_struct("ForEachConcurrentThenTry")
33            .field("stream", &self.stream)
34            .field("futures", &self.futures)
35            .field("limit", &self.limit)
36            .field("first_error", &self.first_error)
37            .finish()
38    }
39}
40
41impl<St, Fut, F> FusedFuture for ForEachConcurrentThenTry<St, Fut, F>
42where
43    St: TryStream,
44    F: FnMut(St::Ok) -> Fut,
45    Fut: Future<Output = Result<(), St::Error>>,
46{
47    fn is_terminated(&self) -> bool {
48        self.stream.is_none() && self.futures.is_empty()
49    }
50}
51
52impl<St, Fut, F> ForEachConcurrentThenTry<St, Fut, F>
53where
54    St: TryStream,
55    F: FnMut(St::Ok) -> Fut,
56    Fut: Future<Output = Result<(), St::Error>>,
57{
58    pub(super) fn new(stream: St, limit: Option<usize>, f: F) -> Self {
59        Self {
60            stream: Some(stream),
61            // Note: `limit` = 0 gets ignored.
62            limit: limit.and_then(NonZeroUsize::new),
63            f,
64            futures: FuturesUnordered::new(),
65            first_error: None,
66        }
67    }
68}
69
70impl<St, Fut, F> Future for ForEachConcurrentThenTry<St, Fut, F>
71where
72    St: TryStream,
73    F: FnMut(St::Ok) -> Fut,
74    Fut: Future<Output = Result<(), St::Error>>,
75{
76    type Output = Result<(), St::Error>;
77
78    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
79        let mut this = self.project();
80        loop {
81            let mut made_progress_this_iter = false;
82
83            // Check if we've already created a number of futures greater than `limit`
84            if this
85                .limit
86                .map(|limit| limit.get() > this.futures.len())
87                .unwrap_or(true)
88            {
89                let mut stream_completed = false;
90                let elem = if let Some(stream) = this.stream.as_mut().as_pin_mut() {
91                    match stream.try_poll_next(cx) {
92                        Poll::Ready(Some(Ok(elem))) => {
93                            made_progress_this_iter = true;
94                            Some(elem)
95                        }
96                        Poll::Ready(Some(Err(error))) => {
97                            if this.first_error.is_none() {
98                                *this.first_error = Some(error);
99                            }
100                            None
101                        }
102                        Poll::Ready(None) => {
103                            stream_completed = true;
104                            None
105                        }
106                        Poll::Pending => None,
107                    }
108                } else {
109                    None
110                };
111                if stream_completed {
112                    this.stream.set(None);
113                }
114                if let Some(elem) = elem {
115                    this.futures.push((this.f)(elem));
116                }
117            }
118
119            match this.futures.poll_next_unpin(cx) {
120                Poll::Ready(Some(item)) => {
121                    made_progress_this_iter = true;
122                    if let Err(error) = item {
123                        if this.first_error.is_none() {
124                            *this.first_error = Some(error);
125                        }
126                    }
127                }
128                Poll::Ready(None) => {
129                    if this.stream.is_none() {
130                        return Poll::Ready(this.first_error.take().map_or(Ok(()), Err));
131                    }
132                }
133                Poll::Pending => {}
134            }
135
136            if !made_progress_this_iter {
137                return Poll::Pending;
138            }
139        }
140    }
141}