// reth_era_downloader/stream.rs

1use crate::{client::HttpClient, EraClient};
2use futures_util::{stream::FuturesOrdered, FutureExt, Stream, StreamExt};
3use reqwest::Url;
4use reth_fs_util as fs;
5use std::{
6    collections::VecDeque,
7    fmt::{Debug, Formatter},
8    future::Future,
9    path::Path,
10    pin::Pin,
11    task::{Context, Poll},
12};
13
/// Parameters that alter the behavior of [`EraStream`].
///
/// # Examples
/// ```
/// use reth_era_downloader::EraStreamConfig;
///
/// EraStreamConfig::default().with_max_files(10).with_max_concurrent_downloads(2);
/// ```
#[derive(Debug, Clone)]
pub struct EraStreamConfig {
    /// Upper bound on the number of downloaded ERA1 files kept in the download directory.
    max_files: usize,
    /// Upper bound on the number of downloads running at the same time.
    max_concurrent_downloads: usize,
}

impl Default for EraStreamConfig {
    /// Keeps up to 5 files in the download directory and runs up to 3 downloads at once.
    fn default() -> Self {
        let max_files = 5;
        let max_concurrent_downloads = 3;
        Self { max_files, max_concurrent_downloads }
    }
}

impl EraStreamConfig {
    /// Sets the maximum number of downloaded ERA1 files kept in the download directory.
    pub const fn with_max_files(self, max_files: usize) -> Self {
        Self { max_files, ..self }
    }

    /// Sets the maximum number of downloads happening at the same time.
    pub const fn with_max_concurrent_downloads(self, max_concurrent_downloads: usize) -> Self {
        Self { max_concurrent_downloads, ..self }
    }
}
47
48/// An asynchronous stream of ERA1 files.
49///
50/// # Examples
51/// ```
52/// use futures_util::StreamExt;
53/// use reth_era_downloader::{EraStream, HttpClient};
54///
55/// # async fn import(mut stream: EraStream<impl HttpClient + Clone + Send + Sync + 'static + Unpin>) -> eyre::Result<()> {
56/// while let Some(file) = stream.next().await {
57///     let file = file?;
58///     // Process `file: Box<Path>`
59/// }
60/// # Ok(())
61/// # }
62/// ```
63#[derive(Debug)]
64pub struct EraStream<Http> {
65    download_stream: DownloadStream,
66    starting_stream: StartingStream<Http>,
67}
68
69impl<Http> EraStream<Http> {
70    /// Constructs a new [`EraStream`] that downloads concurrently up to `max_concurrent_downloads`
71    /// ERA1 files to `client` `folder`, keeping their count up to `max_files`.
72    pub fn new(client: EraClient<Http>, config: EraStreamConfig) -> Self {
73        Self {
74            download_stream: DownloadStream {
75                downloads: Default::default(),
76                scheduled: Default::default(),
77                max_concurrent_downloads: config.max_concurrent_downloads,
78                ended: false,
79            },
80            starting_stream: StartingStream {
81                client,
82                files_count: Box::pin(async move { usize::MAX }),
83                next_url: Box::pin(async move { Ok(None) }),
84                recover_index: Box::pin(async move { 0 }),
85                fetch_file_list: Box::pin(async move { Ok(()) }),
86                state: Default::default(),
87                max_files: config.max_files,
88                index: 0,
89                downloading: 0,
90            },
91        }
92    }
93}
94
95/// Contains information about an ERA file.
96pub trait EraMeta: AsRef<Path> {
97    /// Marking this particular ERA file as "processed" lets the caller hint that it is no longer
98    /// going to be using it.
99    ///
100    /// The meaning of that is up to the implementation. The caller should assume that after this
101    /// point is no longer possible to safely read it.
102    fn mark_as_processed(self) -> eyre::Result<()>;
103}
104
/// An ERA file that is hosted remotely and represented locally by a temporary file.
#[derive(Debug)]
pub struct EraRemoteMeta {
    /// Location of the temporary local copy.
    path: Box<Path>,
}

impl EraRemoteMeta {
    /// Wraps the local path of a downloaded ERA file.
    const fn new(path: Box<Path>) -> Self {
        EraRemoteMeta { path }
    }
}

impl AsRef<Path> for EraRemoteMeta {
    /// Returns the path of the temporary local copy.
    fn as_ref(&self) -> &Path {
        &self.path
    }
}
123
124impl EraMeta for EraRemoteMeta {
125    /// Removes a temporary local file representation of the remotely hosted original.
126    fn mark_as_processed(self) -> eyre::Result<()> {
127        Ok(fs::remove_file(self.path)?)
128    }
129}
130
131impl<Http: HttpClient + Clone + Send + Sync + 'static + Unpin> Stream for EraStream<Http> {
132    type Item = eyre::Result<EraRemoteMeta>;
133
134    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
135        if let Poll::Ready(fut) = self.starting_stream.poll_next_unpin(cx) {
136            if let Some(fut) = fut {
137                self.download_stream.scheduled.push_back(fut);
138            } else {
139                self.download_stream.ended = true;
140            }
141        }
142
143        let poll = self.download_stream.poll_next_unpin(cx);
144
145        if poll.is_ready() {
146            self.starting_stream.downloaded();
147        }
148
149        poll
150    }
151}
152
153type DownloadFuture =
154    Pin<Box<dyn Future<Output = eyre::Result<EraRemoteMeta>> + Send + Sync + 'static>>;
155
156struct DownloadStream {
157    downloads: FuturesOrdered<DownloadFuture>,
158    scheduled: VecDeque<DownloadFuture>,
159    max_concurrent_downloads: usize,
160    ended: bool,
161}
162
163impl Debug for DownloadStream {
164    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
165        write!(f, "DownloadStream({})", self.downloads.len())
166    }
167}
168
169impl Stream for DownloadStream {
170    type Item = eyre::Result<EraRemoteMeta>;
171
172    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
173        for _ in 0..self.max_concurrent_downloads - self.downloads.len() {
174            if let Some(fut) = self.scheduled.pop_front() {
175                self.downloads.push_back(fut);
176            }
177        }
178
179        let ended = self.ended;
180        let poll = self.downloads.poll_next_unpin(cx);
181
182        if matches!(poll, Poll::Ready(None)) && !ended {
183            cx.waker().wake_by_ref();
184            return Poll::Pending;
185        }
186
187        poll
188    }
189}
190
191struct StartingStream<Http> {
192    client: EraClient<Http>,
193    files_count: Pin<Box<dyn Future<Output = usize> + Send + Sync + 'static>>,
194    next_url: Pin<Box<dyn Future<Output = eyre::Result<Option<Url>>> + Send + Sync + 'static>>,
195    recover_index: Pin<Box<dyn Future<Output = u64> + Send + Sync + 'static>>,
196    fetch_file_list: Pin<Box<dyn Future<Output = eyre::Result<()>> + Send + Sync + 'static>>,
197    state: State,
198    max_files: usize,
199    index: u64,
200    downloading: usize,
201}
202
203impl<Http> Debug for StartingStream<Http> {
204    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
205        write!(
206            f,
207            "StartingStream{{ max_files: {}, index: {}, downloading: {} }}",
208            self.max_files, self.index, self.downloading
209        )
210    }
211}
212
/// Phases of the [`StartingStream`] state machine.
#[derive(Default, Debug, PartialEq)]
enum State {
    /// Nothing has been requested yet; the next poll starts fetching the file list.
    #[default]
    Initial,
    /// Waiting for the remote file list.
    FetchFileList,
    /// Waiting for the index to continue downloading from.
    RecoverIndex,
    /// Waiting for the number of files already in the download folder.
    CountFiles,
    /// How many more downloads may be started before the folder is counted again.
    Missing(usize),
    /// Waiting for the next URL; carries the `Missing` budget that was in effect.
    NextUrl(usize),
}
223
224impl<Http: HttpClient + Clone + Send + Sync + 'static + Unpin> Stream for StartingStream<Http> {
225    type Item = DownloadFuture;
226
227    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
228        if self.state == State::Initial {
229            self.fetch_file_list();
230        }
231
232        if self.state == State::FetchFileList {
233            if let Poll::Ready(result) = self.fetch_file_list.poll_unpin(cx) {
234                match result {
235                    Ok(_) => self.recover_index(),
236                    Err(e) => return Poll::Ready(Some(Box::pin(async move { Err(e) }))),
237                }
238            }
239        }
240
241        if self.state == State::RecoverIndex {
242            if let Poll::Ready(index) = self.recover_index.poll_unpin(cx) {
243                self.index = index;
244                self.count_files();
245            }
246        }
247
248        if self.state == State::CountFiles {
249            if let Poll::Ready(downloaded) = self.files_count.poll_unpin(cx) {
250                let max_missing = self.max_files.saturating_sub(downloaded + self.downloading);
251                self.state = State::Missing(max_missing);
252            }
253        }
254
255        if let State::Missing(max_missing) = self.state {
256            if max_missing > 0 {
257                let index = self.index;
258                self.index += 1;
259                self.downloading += 1;
260                self.next_url(index, max_missing);
261            } else {
262                self.count_files();
263            }
264        }
265
266        if let State::NextUrl(max_missing) = self.state {
267            if let Poll::Ready(url) = self.next_url.poll_unpin(cx) {
268                self.state = State::Missing(max_missing - 1);
269
270                return Poll::Ready(url.transpose().map(|url| -> DownloadFuture {
271                    let mut client = self.client.clone();
272
273                    Box::pin(
274                        async move { client.download_to_file(url?).await.map(EraRemoteMeta::new) },
275                    )
276                }));
277            }
278        }
279
280        Poll::Pending
281    }
282}
283
284impl<Http> StartingStream<Http> {
285    const fn downloaded(&mut self) {
286        self.downloading = self.downloading.saturating_sub(1);
287    }
288}
289
290impl<Http: HttpClient + Clone + Send + Sync + 'static> StartingStream<Http> {
291    fn fetch_file_list(&mut self) {
292        let client = self.client.clone();
293
294        Pin::new(&mut self.fetch_file_list)
295            .set(Box::pin(async move { client.fetch_file_list().await }));
296
297        self.state = State::FetchFileList;
298    }
299
300    fn recover_index(&mut self) {
301        let client = self.client.clone();
302
303        Pin::new(&mut self.recover_index)
304            .set(Box::pin(async move { client.recover_index().await }));
305
306        self.state = State::RecoverIndex;
307    }
308
309    fn count_files(&mut self) {
310        let client = self.client.clone();
311
312        Pin::new(&mut self.files_count).set(Box::pin(async move { client.files_count().await }));
313
314        self.state = State::CountFiles;
315    }
316
317    fn next_url(&mut self, index: u64, max_missing: usize) {
318        let client = self.client.clone();
319
320        Pin::new(&mut self.next_url).set(Box::pin(async move { client.url(index).await }));
321
322        self.state = State::NextUrl(max_missing);
323    }
324}