From 4c0f1b6354b1f1b2f87a6876bfa5d3803804cbb9 Mon Sep 17 00:00:00 2001 From: Dario Nieuwenhuis Date: Sun, 28 Aug 2022 23:32:46 +0200 Subject: [PATCH] futures: add join_array. --- embassy-futures/src/join.rs | 68 +++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/embassy-futures/src/join.rs b/embassy-futures/src/join.rs index 7600d4b8..bc0cb530 100644 --- a/embassy-futures/src/join.rs +++ b/embassy-futures/src/join.rs @@ -1,6 +1,7 @@ //! Wait for multiple futures to complete. use core::future::Future; +use core::mem::MaybeUninit; use core::pin::Pin; use core::task::{Context, Poll}; use core::{fmt, mem}; @@ -252,3 +253,70 @@ where { Join5::new(future1, future2, future3, future4, future5) } + +// ===================================================== + +/// Future for the [`join_array`] function. +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct JoinArray { + futures: [MaybeDone; N], +} + +impl fmt::Debug for JoinArray +where + Fut: Future + fmt::Debug, + Fut::Output: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JoinArray").field("futures", &self.futures).finish() + } +} + +impl Future for JoinArray { + type Output = [Fut::Output; N]; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = unsafe { self.get_unchecked_mut() }; + let mut all_done = true; + for f in this.futures.iter_mut() { + all_done &= unsafe { Pin::new_unchecked(f) }.poll(cx); + } + + if all_done { + let mut array: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; + for i in 0..N { + array[i].write(this.futures[i].take_output()); + } + Poll::Ready(unsafe { (&array as *const _ as *const [Fut::Output; N]).read() }) + } else { + Poll::Pending + } + } +} + +/// Joins the result of an array of futures, waiting for them all to complete. +/// +/// This function will return a new future which awaits all futures to +/// complete. The returned future will finish with a tuple of all results. +/// +/// Note that this function consumes the passed futures and returns a +/// wrapped version of it. +/// +/// # Examples +/// +/// ``` +/// # embassy_futures::block_on(async { +/// +/// async fn foo(n: u32) -> u32 { n } +/// let a = foo(1); +/// let b = foo(2); +/// let c = foo(3); +/// let res = embassy_futures::join::join_array([a, b, c]).await; +/// +/// assert_eq!(res, [1, 2, 3]); +/// # }); +/// ``` +pub fn join_array(futures: [Fut; N]) -> JoinArray { + JoinArray { + futures: futures.map(MaybeDone::Future), + } +}